JAX介绍和快速入门示例
来源:DeepHub IMBA 本文约3300字,建议阅读10+分钟
本文中,我们了解了 JAX 是什么,并了解了它的一些基本概念。
它可以被视为 GPU 和 TPU 上运行的NumPy , jax.numpy提供了与numpy非常相似API接口。 它与 NumPy API 非常相似,几乎任何可以用 numpy 完成的事情都可以用 jax.numpy 完成。 由于使用XLA(一种加速线性代数计算的编译器)将Python和JAX代码JIT编译成优化的内核,可以在不同设备(例如gpu和tpu)上运行。而优化的内核是为高吞吐量设备(例如gpu和tpu)进行编译,它与主程序分离但可以被主程序调用。JIT编译可以用jax.jit()触发。 它对自动微分有很好的支持,对机器学习研究很有用。可以使用 jax.grad() 触发自动区分。 JAX 鼓励函数式编程,因为它是面向函数的。与 NumPy 数组不同,JAX 数组始终是不可变的。 JAX提供了一些在编写数字处理时非常有用的程序转换,例如JIT . JAX()用于JIT编译和加速代码,JIT .grad()用于求导,以及JIT .vmap()用于自动向量化或批处理。 JAX 可以进行异步调度。所以需要调用 .block_until_ready() 以确保计算已经实际发生。
自动:在执行 JAX 函数的库调用时,默认情况下 JIT 编译会在后台进行。 手动:您可以使用 jax.jit() 手动请求对自己的 Python 函数进行 JIT 编译。
JAX 使用示例
pip install jax
import jax
import jax.numpy as jnp
from jax import random
from jax import grad, jit
import numpy as np
key = random.PRNGKey(0)
# runs on CPU - numpy
size = 5000
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)
# 1 loop, best of 5: 1.61 s per loop
# runs on CPU - JAX
size = 5000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
# 1 loop, best of 5: 3.49 s per loop
# runs on GPU
size = 5000
x = random.normal(key, (size, size), dtype=jnp.float32)
%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time
%time jnp.dot(x_jax, x_jax.T).block_until_ready() # 2. measure JAX compilation time
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time
# 1. CPU times: user 102 µs, sys: 42 µs, total: 144 µs
# Wall time: 155 µs
# 2. CPU times: user 1.3 s, sys: 195 ms, total: 1.5 s
# Wall time: 2.16 s
# 3. 10 loops, best of 5: 68.9 ms per loop
设备传输时间:将矩阵传输到 GPU 所经过的时间。耗时 0.155 毫秒。
编译时间:JIT 编译经过的时间。耗时 2.16 秒。
运行时间:有效的代码运行时间。耗时 68.9 毫秒。
# runs on TPU
size = 5000
x = random.normal(key, (size, size), dtype=jnp.float32)
%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time
%time jnp.dot(x_jax, x_jax.T).block_until_ready() # 2. measure JAX compilation time
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time
# 1. CPU times: user 131 µs, sys: 72 µs, total: 203 µs
# Wall time: 164 µs
# 2. CPU times: user 190 ms, sys: 302 ms, total: 492 ms
# Wall time: 837 ms
# 3. 100 loops, best of 5: 16.5 ms per loop
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
XLA
def selu_np(x, alpha=1.67, lmbda=1.05):
return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)
def selu_jax(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
# runs on the CPU - numpy
x = np.random.normal(size=(1000000,)).astype(np.float32)
%timeit selu_np(x)
# 100 loops, best of 5: 7.6 ms per loop
# runs on the CPU - JAX
x = random.normal(key, (1000000,))
%time selu_jax(x).block_until_ready() # 1. measure JAX compilation time
%timeit selu_jax(x).block_until_ready() # 2. measure JAX runtime
# 1. CPU times: user 124 ms, sys: 5.01 ms, total: 129 ms
# Wall time: 124 ms
# 2. 100 loops, best of 5: 4.8 ms per loop
# runs on the GPU
x = random.normal(key, (1000000,))
%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time
%time selu_jax(x_jax).block_until_ready() # 2. measure JAX compilation time
%timeit selu_jax(x_jax).block_until_ready() # 3. measure JAX runtime
# 1. CPU times: user 103 µs, sys: 0 ns, total: 103 µs
# Wall time: 109 µs
# 2. CPU times: user 148 ms, sys: 9.09 ms, total: 157 ms
# Wall time: 447 ms
# 3. 1000 loops, best of 5: 1.21 ms per loop
# runs on the GPU
x = random.normal(key, (1000000,))
selu_jax_jit = jit(selu_jax)
%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time
%time selu_jax_jit(x_jax).block_until_ready() # 2. measure JAX compilation time
%timeit selu_jax_jit(x_jax).block_until_ready() # 3. measure JAX runtime
# 1. CPU times: user 70 µs, sys: 28 µs, total: 98 µs
# Wall time: 104 µs
# 2. CPU times: user 66.6 ms, sys: 1.18 ms, total: 67.8 ms
# Wall time: 122 ms
# 3. 10000 loops, best of 5: 130 µs per loop
CPU 上的 NumPy:7.6 毫秒。 CPU 上的 JAX:4.8 毫秒(x1.58 加速)。 没有 JIT 的 GPU 上的 JAX:1.21 毫秒(x6.28 加速)。 带有 JIT 的 GPU 上的 JAX:0.13 毫秒(x58.46 加速)。
@jit
def selu_jax_jit(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
使用 jax.grad 自动微分
def sum_logistic(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
# [0.25, 0.19661197, 0.10499357]
总结
编辑:黄继彦
评论