JAX介绍和快速入门示例

共 6647字,需浏览 14分钟

 ·

2022-06-20 02:56


来源:DeepHub IMBA

本文约3300字,建议阅读10+分钟

本文中,我们了解了 JAX 是什么,并了解了它的一些基本概念。


JAX 是一个由 Google 开发的用于优化科学计算Python 库:

  • 它可以被视为 GPU 和 TPU 上运行的NumPy , jax.numpy提供了与numpy非常相似API接口。
  • 它与 NumPy API 非常相似,几乎任何可以用 numpy 完成的事情都可以用 jax.numpy 完成。
  • 由于使用XLA(一种加速线性代数计算的编译器)将Python和JAX代码JIT编译成优化的内核,可以在不同设备(例如gpu和tpu)上运行。而优化的内核是为高吞吐量设备(例如gpu和tpu)进行编译,它与主程序分离但可以被主程序调用。JIT编译可以用jax.jit()触发。
  • 它对自动微分有很好的支持,对机器学习研究很有用。可以使用 jax.grad() 触发自动区分。
  • JAX 鼓励函数式编程,因为它是面向函数的。与 NumPy 数组不同,JAX 数组始终是不可变的。
  • JAX提供了一些在编写数字处理时非常有用的程序转换,例如JIT . JAX()用于JIT编译和加速代码,JIT .grad()用于求导,以及JIT .vmap()用于自动向量化或批处理。
  • JAX 可以进行异步调度。所以需要调用 .block_until_ready() 以确保计算已经实际发生。

JAX 使用 JIT 编译有两种方式:

  • 自动:在执行 JAX 函数的库调用时,默认情况下 JIT 编译会在后台进行。
  • 手动:您可以使用 jax.jit() 手动请求对自己的 Python 函数进行 JIT 编译。


JAX 使用示例


我们可以使用 pip 安装库。

pip install jax


导入需要的包,这里我们也继续使用 NumPy ,这样可以执行一些基准测试。

import jaximport jax.numpy as jnpfrom jax import randomfrom jax import grad, jitimport numpy as np
key = random.PRNGKey(0)

与 import numpy as np 类似,我们可以 import jax.numpy as jnp 并将代码中的所有 np 替换为 jnp 。如果 NumPy 代码是用函数式编程风格编写的,那么新的 JAX 代码就可以直接使用。但是,如果有可用的GPU,JAX则可以直接使用。

JAX 中随机数的生成方式与 NumPy 不同。JAX需要创建一个 jax.random.PRNGKey 。我们稍后会看到如何使用它。

我们在 Google Colab 上做一个简单的基准测试,这样我们就可以轻松访问 GPU 和 TPU。我们首先初始化一个包含 25M 元素的随机矩阵,然后将其乘以它的转置。使用针对 CPU 优化的 NumPy,矩阵乘法平均需要 1.61 秒。

# runs on CPU - numpysize = 5000x = np.random.normal(size=(size, size)).astype(np.float32)%timeit np.dot(x, x.T)# 1 loop, best of 5: 1.61 s per loop

在 CPU 上使用 JAX 执行相同的操作平均需要大约 3.49 秒。

# runs on CPU - JAXsize = 5000x = random.normal(key, (size, size), dtype=jnp.float32)%timeit jnp.dot(x, x.T).block_until_ready()# 1 loop, best of 5: 3.49 s per loop

在 CPU 上运行时,JAX 通常比 NumPy 慢,因为 NumPy 已针对CPU进行了非常多的优化。但是,当使用加速器时这种情况会发生变化,所以让我们尝试使用 GPU 进行矩阵乘法。

# runs on GPUsize = 5000x = random.normal(key, (size, size), dtype=jnp.float32)%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time jnp.dot(x_jax, x_jax.T).block_until_ready() # 2. measure JAX compilation time%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time# 1. CPU times: user 102 µs, sys: 42 µs, total: 144 µs#   Wall time: 155 µs# 2. CPU times: user 1.3 s, sys: 195 ms, total: 1.5 s#   Wall time: 2.16 s# 3. 10 loops, best of 5: 68.9 ms per loop

从示例中可以看出,要进行公平的基准比较,我们需要使用 JAX 测量不同的步骤:

设备传输时间:将矩阵传输到 GPU 所经过的时间。耗时 0.155 毫秒。编译时间:JIT 编译经过的时间。耗时 2.16 秒。运行时间:有效的代码运行时间。耗时 68.9 毫秒。

在 GPU 上使用 JAX 进行单个矩阵乘法的总耗时约为 2.23 秒,高于 NumPy 的总时间 1.61 秒。但是对于每个额外的矩阵乘法,JAX 只需要 68.9 毫秒,而 NumPy 需要 1.61 秒,快了 22 倍多!因此,如果多次执行线性代数运算,那么使用 JAX 是有意义的。

让我们测试使用 TPU 进行矩阵乘法。

# runs on TPUsize = 5000x = random.normal(key, (size, size), dtype=jnp.float32)%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time jnp.dot(x_jax, x_jax.T).block_until_ready() # 2. measure JAX compilation time%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time# 1. CPU times: user 131 µs, sys: 72 µs, total: 203 µs#   Wall time: 164 µs# 2. CPU times: user 190 ms, sys: 302 ms, total: 492 ms#   Wall time: 837 ms# 3. 100 loops, best of 5: 16.5 ms per loop

忽略设备传输时间和编译时间,每个矩阵乘法平均需要 16.5 毫秒:GPU 相比快了4倍,与 CPU 的 NumPy相比快了88倍。需要说明的是,当乘以不同大小的矩阵时,获得相同的加速效果也不同:相乘的矩阵越大,GPU可以优化操作的越多,加速也越大。

为了在 Google Colab 上复制上述基准,需要运行以下代码让 JAX 知道有可用的 TPU。

import jax.tools.colab_tpujax.tools.colab_tpu.setup_tpu()

让我们看看 XLA 编译器。

XLA


XLA 是 JAX(和其他库,例如 TensorFlow,TPU的Pytorch)使用的线性代数的编译器,它通过创建自定义优化内核来保证最快的在程序中运行线性代数运算。XLA 最大的好处是可以让我们在应用中自定义内核,该部分使用线性代数运算,以便它可以进行最多的优化。

XLA 最重要的优化是融合,即可以在同一个内核中进行多个线性代数运算,将中间输出保存到 GPU 寄存器中,而不将它们具体化到内存中。这可以显著增加我们的“计算强度”,即所做的工作量与负载和存储数量的比例。融合还可以让我们完全省略仅在内存中shuffle 的操作(例如reshape)。

下面我们看看如何使用 XLA 和 jax.jit 手动触发 JIT 编译。

使用 jax.jit 进行即时编译

这里有一些新的基准来测试 jax.jit 的性能。我们定义了两个实现 SELU(Scaled Exponential Linear Unit)的函数:一个使用 NumPy,一个使用 JAX。暂时先不考虑 jax.jitat

def selu_np(x, alpha=1.67, lmbda=1.05):return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)
def selu_jax(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

然后,我们使用 NumPy 在 1M 个元素的向量上运行它。

# runs on the CPU - numpyx = np.random.normal(size=(1000000,)).astype(np.float32)%timeit selu_np(x)# 100 loops, best of 5: 7.6 ms per loop

平均需要 7.6 毫秒。现在让我们在 CPU 上使用 JAX。

# runs on the CPU - JAXx = random.normal(key, (1000000,))%time selu_jax(x).block_until_ready() # 1. measure JAX compilation time%timeit selu_jax(x).block_until_ready() # 2. measure JAX runtime# 1. CPU times: user 124 ms, sys: 5.01 ms, total: 129 ms#   Wall time: 124 ms# 2. 100 loops, best of 5: 4.8 ms per loop

现在平均需要 4.8 毫秒,在这种情况下比 NumPy 快。下一个测试是在 GPU 上使用 JAX。

# runs on the GPUx = random.normal(key, (1000000,))%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time selu_jax(x_jax).block_until_ready() # 2. measure JAX compilation time%timeit selu_jax(x_jax).block_until_ready() # 3. measure JAX runtime# 1. CPU times: user 103 µs, sys: 0 ns, total: 103 µs#   Wall time: 109 µs# 2. CPU times: user 148 ms, sys: 9.09 ms, total: 157 ms#   Wall time: 447 ms# 3. 1000 loops, best of 5: 1.21 ms per loop

函数运行时间为1.21毫秒。下面我们用 jax.jit 测试它,触发 JIT 编译器使用 XLA 将 SELU 函数编译到优化的 GPU 内核中,同时优化函数内部的所有操作。

# runs on the GPUx = random.normal(key, (1000000,))selu_jax_jit = jit(selu_jax)%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time selu_jax_jit(x_jax).block_until_ready() # 2. measure JAX compilation time%timeit selu_jax_jit(x_jax).block_until_ready() # 3. measure JAX runtime# 1. CPU times: user 70 µs, sys: 28 µs, total: 98 µs#   Wall time: 104 µs# 2. CPU times: user 66.6 ms, sys: 1.18 ms, total: 67.8 ms#   Wall time: 122 ms# 3. 10000 loops, best of 5: 130 µs per loop

使用编译内核,函数运行时间为0.13毫秒!

让我们回顾一下不同的运行时间:

  • CPU 上的 NumPy:7.6 毫秒。
  • CPU 上的 JAX:4.8 毫秒(x1.58 加速)。
  • 没有 JIT 的 GPU 上的 JAX:1.21 毫秒(x6.28 加速)。
  • 带有 JIT 的 GPU 上的 JAX:0.13 毫秒(x58.46 加速)。

使用 JIT 编译避免从 GPU 寄存器中移动数据这样给我们带来了非常大的加速。一般来说在不同类型的内存之间移动数据与代码执行相比非常慢,因此在实际使用时应该尽量避免!

将 SELU 函数应用于不同大小的向量时,您可能会获得不同的结果。矢量越大,加速器越能优化操作,加速也越大。

除了执行 selu_jax_jit = jit(selu_jax) 之外,还可以使用 @jit 装饰器对函数进行 JIT 编译,如下所示。

@jitdef selu_jax_jit(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

JIT 编译可以加速,为什么我们不能全部都这样做呢?因为并非所有代码都可以 JIT 编译,JIT要求数组形状是静态的并且在编译时已知。另外就是引入jax.jit 也会带来一些开销。因此通常只有编译的函数比较复杂并且需要多次运行才能节省时间。但是这在机器学习中很常见,例如我们倾编译一个大而复杂的模型,然后运行它进行数百万次训练、损失函数和指标的计算。

使用 jax.grad 自动微分


另一个 JAX 转换是使用 jit.grad() 函数的自动微分。

借助 Autograd ,JAX 可以自动对原生 Python 和 NumPy 代码进行微分。并且支持 Python 的大部分特性,包括循环、if、递归和闭包。

下面看看一个带有 jit.grad() 的代码示例,我们计算一个自定义的包含 JAX 函数的Python 函数的导数。

def sum_logistic(x):return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)derivative_fn = grad(sum_logistic)print(derivative_fn(x_small))# [0.25, 0.19661197, 0.10499357]


总结


在本文中,我们了解了 JAX 是什么,并了解了它的一些基本概念:NumPy 接口、JIT 编译、XLA、优化内核、程序转换、自动微分和函数式编程。在 JAX 之上,开源社区为机器学习构建了更多高级库,例如 Flax 和 Haiku。有兴趣的可以搜索查看。

编辑:黄继彦





浏览 7
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报