/images/avatar.png

[转载] 谷歌开源计算框架JAX”

相信大家对numpy, Tensorflow, Pytorch已经极其熟悉,不过,你知道JAX吗? JAX发布之后,有网友进行了测试,发现,使用JAX,Numpy运算可以快三十多倍! 下面是使用Numpy的运行情况: 1 2 3 import numpy as np # 使用标准numpy,运算将在CPU上执行。 x = np.random.random([5000, 5000]).astype(np.float32) %timeit np.matmul(x, x) 运行结果: 1 2 1 loop, best of 3: 3.9 s per loop 而下面是使用JAX的Numpy的情况: import jax.numpy as np # 使用"JAX版"的numpy from jax import random # 注意JAX下随机数API有所不同 x = random.uniform(random.PRNGKey(0), [5000, 5000]) %timeit np.matmul(x, x) 运行情况: 1 1 loop, best of 3: 109 ms per loop 我们可以发现,使用原始numpy,运行时间大概为3.9s,而使用JAX的numpy,运行时间仅仅只有0.109s,速度上直接提升了三十多倍! 是不是很神奇? 那JAX到底是什么? JAX是谷歌开源的、可以在CPU、GPU和TPU上运行的numpy,是针对机器学习研究的高性能自微分计算框架。