微信公众号
《如何扩展你的模型》第10部分 (第9部分:性能分析 | 第11部分:结论)
如何使用 JAX 高效地为 TPU 编程!本节大部分内容摘自此处。您可以在 Google Colab 上使用免费的 TPU 运行本节中的代码示例。
JAX 支持三种多设备编程的思想流派:
| 模式 | 视图? | 显式分片? | 显式集合操作? |
|---|---|---|---|
| 自动 | 全局 | ❌ | ❌ |
| 显式 | 全局 | ✅ | ❌ |
| 手动 | 每设备 | ✅ | ✅ |
相应地,JAX 为每种模式都提供了 API:
jax.jit (使用 Auto 网格轴) 允许你使用任何现有的 JAX 函数,并用分片的输入来调用它。然后,JAX 会使用 XLA 的 Shardy 编译器来自动并行化程序。当需要支持现有操作时,XLA 会为你添加通信操作(AllGathers、ReduceScatters、AllReduces 等)。虽然它并不完美,但通常能在无需修改代码的情况下,很好地将你的程序自动扩展到任意数量的芯片上。jax.jit 使用 Explicit 网格轴看起来与(1)类似,但它让 JAX 而不是 XLA 来处理分片传播。这意味着数组的分片实际上是 JAX 类型系统的一部分,当 JAX 检测到模糊的通信时会报错,并让用户来解决。jax.shard_map 是更手动的对应方案。你获得的是程序的设备本地视图,并且必须显式地编写任何你想要的通信。有一个分片数组,并希望每个设备上都有完整的数据?添加一个 jax.lax.all_gather。想在所有设备上对一个数组求和?添加一个 jax.lax.psum (一个 AllReduce)。编程更难,但做错事的可能性要小得多。jax.jit 在 JAX 内部扮演两个角色。顾名思义,它会“即时”将一个函数从 Python 编译成字节码(通过 XLA/HLO/LLO),使其运行得更快。但如果输入是分片的,或者用户指定了 in_sharding 或 out_sharding,它还会让 XLA 将计算分布到多个设备上,并根据需要添加通信。例如,以下是如何使用 jax.jit 编写一个分片矩阵乘法:
import jax
import jax.numpy as jnp
# 在一个 TPU v5e 4x2 上运行。这为硬件的两个物理轴分配了名称。
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))
# 这告诉 JAX 对所有操作都使用这个网格,所以你只需指定 PartitionSpec P 即可。
jax.set_mesh(mesh)
# 我们创建一个矩阵 W 和输入激活 In,它们被分片到我们的设备上。
In = jnp.zeros((8, 2048), dtype=jnp.bfloat16, device=jax.NamedSharding(mesh, jax.P('X', 'Y')))
W = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=jax.NamedSharding(mesh, jax.P('Y', None)))
def matmul_square(In, W):
return jnp.einsum('bd,df->bf', jnp.square(In), W)
# 我们可以在这里显式编译分片的矩阵乘法函数。这会添加所有
# 必要的通信(例如,在矩阵乘法之后进行 AllReduce)。
jit_matmul = jax.jit(matmul_square, out_shardings=jax.P('X', None)).lower(In, W).compile()
out = jit_matmul(In, W)
这将在任何分片策略下自动运行,并将计算分区到我们的设备上。但是在硬件层面到底发生了什么呢?
matmul_square 会简单地对输入进行平方运算并执行一个简单的矩阵乘法。但因为我们将 out_shardings 指定为 P('X', None),输出将沿着批次维度分片,但在模型维度上是复制的,并且需要一个 AllReduce 来计算。使用我们前面章节的表示法,这可能会执行类似以下的操作
jax.jit 会自动为我们添加这个!我们实际上可以用 jit_matmul.as_text() 打印出 HLO,并看到以下 HLO(经过大幅缩写):
# 这个融合操作是分片输入和矩阵的实际矩阵乘法
%fusion = bf16[2,8192]{1,0:T(4,128)(2,1)S(1)} fusion(bf16[2,1024]{1,0:T(4,128)(2,1)} %param, bf16[8192,1024]{1,0:T(8,128)(2,1)S(1)} %copy-done)
# 我们在设备间对部分求和的结果进行归约
ROOT %AllReduce = bf16[2,8192]{1,0:T(4,128)(2,1)} AllReduce(bf16[2,8192]{1,0:T(4,128)(2,1)S(1)} %fusion)
我们可以看到上面的矩阵乘法(融合操作)和 AllReduce。请特别注意形状。bf16[2, 1024] 是激活的本地视图,因为我们的 batch_size=8 被分割到 4 个设备上,而我们的 d_model=2048 同样被分割成 2 路。
这简直太神奇了! 无论我们的程序多么复杂,Shardy 和 jit 都会尝试为所有中间激活找到分片方式,并根据需要添加通信。话虽如此,Shardy 也有其缺陷。它可能会犯错。有时你会查看性能剖析,发现出了问题。一个巨大的 AllGather 占用了 80% 的时间,而这本是不必要的。当这种情况发生时,我们可以尝试通过使用 jax.lax.with_sharding_constraint 显式地注释中间张量来纠正编译器。例如,对于两个矩阵乘法,我可以用以下方式强制中间激活沿着 y 维度进行分片(但这不一定是个好主意):
import jax
import jax.numpy as jnp
mesh = jax.make_mesh((4, 2), ('X', 'Y'))
def matmul(x, Win, Wout):
hidden = jnp.einsum('bd,df->bf', x, Win)
hidden = jax.lax.with_sharding_constraint(hidden, jax.P('x', 'y'))
return jnp.einsum('bf,df->bd', hidden, Wout)
在自动分区世界里,通过 jax.lax.with_sharding_constraint 控制中间分片构成了大约 60% 的 JAX 并行编程。但“挑逗编译器”是出了名的不好玩的编程模型。你可能注释了每个中间变量,但仍然不知道是否会得到正确的结果。那么,如果 JAX 本身能够处理和控制分片传播呢?
显式分片(或“类型中的分片”)看起来很像自动分片,但分片传播发生在 JAX 层面!每个 JAX 操作都有一个分片规则,它接收操作参数的分片方式,并为操作结果生成一个分片方式。你可以使用 jax.typeof 查看结果的分片:
import jax
import jax.numpy as jnp
import jax.sharding as shd
# 在一个 TPU v5e 2x2 上运行。这为硬件的两个物理轴分配了名称。
mesh = jax.make_mesh(axis_shapes=(2, 2), axis_names=('X', 'Y'),
axis_types=(shd.AxisType.Explicit, shd.AxisType.Explicit))
# 这告诉 JAX 对所有操作都使用这个网格,所以你只需指定 PartitionSpec P 即可。
jax.set_mesh(mesh)
x = jax.device_put(np.arange(16).reshape(8, 2), P('X', 'Y'))
@jax.jit
def f(x):
print(jax.typeof(x)) # bfloat16[8@X,2@Y]
out = x * 2
print(jax.typeof(out)) # bfloat16[8@X,2@Y]
return out
f(x)
如你所见,JAX 将分片从输入 (x) 传播到了输出 (x),这可以在追踪时通过 jax.typeof 进行检查。对于大多数操作,这些规则简单明了,因为只有一个合理的选择(例如,逐元素操作保持相同的分片)。但对于某些操作,如何对结果进行分片是模糊的,在这种情况下,JAX 会抛出一个追踪时错误,并要求程序员显式地提供一个 out_sharding 参数(例如 jnp.einsum、jnp.reshape 等)。让我们看另一个有冲突的例子:
# 我们创建一个矩阵 W 和输入激活 In,它们被分片到我们的设备上。
In = jnp.zeros((8, 2048), dtype=jnp.bfloat16, out_sharding=jax.P('X', 'Y'))
W = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, out_sharding=jax.P('Y', None))
@jax.jit
def matmul_square(In, W):
print(jax.typeof(In)) # bfloat16[8@X, 2048@Y]
print(jax.typeof(W)) # bfloat16[2048@Y, 8192]
return jnp.einsum('bd,df->bf', jnp.square(In), W)
matmul_square(In, W) # 这会报错
这段代码会报错 Contracting dimensions are sharded and it is ambiguous how the output should be sharded. Please specify the output sharding via the out_sharding parameter. Got lhs_contracting_spec=('Y',) and rhs_contracting_spec=('Y',)
这太棒了,因为 einsum 的输出应该如何分片是模糊的。输出分片可以是:
与自动模式不同,显式模式在检测到模糊的通信时会报错,并要求用户解决它。所以在这里你可以这样做:
@jax.jit
def matmul_square(In, W):
return jnp.einsum('bd,df->bf', jnp.square(In), W, out_sharding=P('X', 'Y'))
out = matmul_square(In, W)
print(jax.typeof(out)) # bfloat16[8@X,8192@Y]
自动模式和显式模式可以通过 jax.sharding.auto_axes 和 jax.sharding.explicit_axes API 组合使用。想了解更多信息,可以阅读这篇很棒的文档。
如果说 Shardy 是“编译器掌舵”模式,那么 jax shard_map 则将一切都交到你手中。你像在 jax.jit 中一样指定输入的分片,但之后你需要显式地编写所有通信。jax.jit 给你留下的是程序的全局跨设备视图,而 shard_map 给你的是一个本地的每设备视图。
这里有一个例子。试着推断一下这个函数是做什么的:
import jax
import jax.numpy as jnp
import jax.sharding as shd
mesh = jax.make_mesh((2, 4), ('x', 'y'), (shd.AxisType.Explicit, shd.AxisType.Explicit))
jax.set_mesh(mesh)
x = jnp.arange(0, 512, dtype=jnp.int32, out_sharding=P(('x', 'y')))
# 这个函数将对数组的 1/8 进行操作。
@jax.shard_map(in_specs=P(('x', 'y')), out_specs=P())
def slice_and_average(x):
assert x.shape == (512 // 8,)
return jax.lax.pmean(x[:4], axis_name=('x', 'y'))
out = slice_and_average(x)
assert out.shape == (4,)
这是做什么的? slice_and_average 在每个 TPU 上运行,处理数组的 1/8,我们从中切片前 4 个元素,并在整个网格上对它们求平均值。这意味着我们实际上在做 mean(x[:4], x[64:68], x[128:132], …)。这非常酷,因为在 JAX 中用其他方式表达这个操作并不容易。
为什么要这样做而不是用 jax.jit? 如果我们使用 jax.jit,slice_and_average 将会看到数组的全局视图(完整的 [512,] 数组)。我们将不得不切出这个非均匀的切片,然后执行一个平均操作,而 XLA 必须正确地解释它。XLA 可能会添加错误的通信或感到困惑。在这里,我们看到的是本地视图,并且只编写我们需要的通信。
示例 [集合矩阵乘法]: 举一个更现实的例子,假设我们要实现模型并行性,其中激活最初是按模型分片的,即 A[BX, DY] * W[D, FY] -> Out[BX, FY]。天真地,我们会先对 A 进行 AllGather,然后进行一个本地矩阵乘法:
可惜,这样做不好,因为它不允许我们将通信与计算重叠。如 Wang et al. 2023 中所述,可以通过“集合矩阵乘法”来实现重叠。算法基本如下:
[B / X, F / Y] 的结果。同时,对 A 进行置换,以便在本地获得下一个块,执行矩阵乘法,并将结果相加。我们可以用 shard_map 相当容易地实现这一点:
import functools
import jax
import jax.numpy as jnp
import jax.sharding as shd
import numpy as np
mesh = jax.make_mesh(axis_shapes=(2, 4), axis_names=('X', 'Y'),
axis_types=(shd.AxisType.Explicit, shd.AxisType.Explicit))
jax.set_mesh(mesh)
B, D, F = 1024, 2048, 8192
A = jnp.arange(np.prod((B, D))).reshape((B, D))
W = jnp.arange(np.prod((D, F))).reshape((D, F))
A = jax.device_put(A, jax.P('X', 'Y'))
W = jax.device_put(W, jax.P(None, 'Y'))
@functools.partial(jax.jit, out_shardings=jax.P('X', 'Y'))
def matmul(lhs, rhs):
return lhs @ rhs
def collective_matmul_allgather_lhs_contracting(lhs, rhs):
# lhs 是循环操作数;rhs 是本地操作数
axis_size = jax.lax.axis_size('Y') # 在这个例子中 axis_size = 4
idx = jax.lax.axis_index('Y')
chunk_size = lhs.shape[1]
assert rhs.shape[0] % chunk_size == 0
def f(i, carrys):
accum, lhs = carrys
rhs_chunk = jax.lax.dynamic_slice_in_dim(rhs, (idx + i) % axis_size * chunk_size, chunk_size)
# 对一个块进行矩阵乘法
update = lhs @ rhs_chunk
# 向左循环移位
lhs = jax.lax.ppermute(
lhs,
axis_name='Y',
perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
)
return accum + update, lhs
accum = jnp.zeros((lhs.shape[0], rhs.shape[1]), dtype=lhs.dtype)
accum = jax.lax.pvary(accum, ('X', 'Y'))
accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs), unroll=True)
# 在最后一次置换后计算最后一个块,以使 lhs 恢复到我们找到它时的状态
i = axis_size - 1
rhs_chunk = jax.lax.dynamic_slice_in_dim(rhs, (idx + i) % axis_size * chunk_size, chunk_size)
update = lhs @ rhs_chunk
return accum + update
jit_sharded_f = jax.jit(jax.shard_map(
collective_matmul_allgather_lhs_contracting,
in_specs=(jax.P('X', 'Y'), jax.P(None, 'Y')), out_specs=jax.P('X', 'Y')))
shmapped_out = jit_sharded_f(A, W)
expected_out = matmul(A, W)
np.testing.assert_array_equal(shmapped_out, expected_out)
这非常巧妙!我们可以对此进行基准测试,发现它也快得多!这里是默认 jit 矩阵乘法的性能剖析,它耗时 311us,并且在开始时有一个大的阻塞式 AllGather:
这里是上面那个版本的剖析,耗时 244 us。你可以看到剖析中没有 AllGather。全都是有效的工作!我们的 FLOPs 利用率也高得多。
还值得注意的是,在收缩维度上没有分片时的矩阵乘法时间是 224us,所以我们非常接近未分片的基线。这是一个很好的例子,说明了你可能需要进行何种性能工程来提高 TPU 的利用率。要了解更多 shard_map 示例,这篇笔记很棒。
现在这里有几个有用的实践问题,可以尝试用 jax.jit 或 shard_map 来实现!
这里有一些随机的 JAX 相关问题。我稍后会添加更多。对于所有这些问题,你都需要在 Colab 中有一定数量的 TPU。你可以使用带有 TPUv2-8 的公共 Colab。从现在开始,我们假设你有 N 个可用设备。
问题1:设 A 是一个激活数组,形状为 float32[SX, DY],其中 X * Y = N。请完成以下操作:
在 JAX 中编写一个函数,计算每个 (X, Y) 分片内的平均值,即返回一个大小为 [X, Y] 的数组,其中 arr[i, j] 是分片 (i, j) 上的平均值。分别用 jax.jit 和 shard_map 实现。对每个实现进行性能分析,看看它们耗时多久。是否添加了任何通信?提示:不应该有,但有时 XLA 还是会添加。
在 JAX 中编写一个函数,对于每个分片 X 内部的某个位移,返回 roll(x, shift, axis=0) - x。我还没那么自虐,不会让你用 jax.jit 来做这个,所以只用 shard_map 实现即可。
第1部分:这是第1部分的解答。请注意,对于 jax.jit 的解决方案,我们必须进行相当复杂的重塑操作。
import numpy as np
import jax
import jax.numpy as jnp
P = jax.sharding.PartitionSpec
mesh = jax.make_mesh((4, 2), ('X','Y'))
average_shmap = jax.shard_map(
lambda x: x.mean(keepdims=True),
mesh=mesh,
in_specs=P('X','Y'), out_specs=P('X','Y')
)
def average(x):
X, Y = mesh.axis_sizes
return x.reshape(X, x.shape[0] // X, Y, x.shape[1] // Y).mean(axis=(1, 3))
average_jit = jax.jit(average, out_shardings=jax.NamedSharding(mesh, P('X','Y')))
x = jnp.arange(8 * 64 * 8, dtype=jnp.int32).reshape(8 * 64, 8)
x = jax.device_put(x, jax.NamedSharding(mesh, P('X','Y')))
y1 = average_shmap(x)
y2 = average_jit(x)
np.testing.assert_array_equal(y1, y2)
第2部分:这是第2部分的类似解答。
import numpy as np
import jax
import jax.numpy as jnp
import functools
P = jax.sharding.PartitionSpec
mesh = jax.make_mesh((4, 2), ('X','Y'))
def shift_shmap(x, shift: int):
shmapped = jax.shard_map(
lambda x: jnp.roll(x, shift, axis=0),
mesh=mesh,
in_specs=P('X','Y'), out_specs=P('X','Y')
)
return shmapped(x)
@functools.partial(jax.jit, static_argnames=['shift'], out_shardings=jax.NamedSharding(mesh, P('X','Y')))
def shift_jit(x, shift: int):
X, Y = mesh.axis_sizes
reshaped = x.reshape(X, x.shape[0] // X, -1)
return jnp.roll(reshaped, shift, axis=1).reshape(x.shape[0], x.shape[1])
x = jnp.arange(8 * 64 * 8, dtype=jnp.int32).reshape(8 * 64, 8)
x = jax.device_put(x, jax.NamedSharding(mesh, P('X','Y')))
y1 = shift_shmap(x, 5)
y2 = shift_jit(x, 5)
np.testing.assert_array_equal(y1, y2)
问题2:在这里,我们将一起构建一个基本的“专家混合”(MoE)模型。设 W: float32[EX, D, FY] 是一组 E 个“专家”矩阵。设 A: float32[SX, DY](我们的激活)并且设 B 是一组“路由分配”,其中 B[i] 是范围 [0, E) 内的一个整数,告诉我们希望用哪个矩阵来处理该激活。我们想在 JAX 中编写一个函数,返回 Out[i] = W[B[i]] @ A[i]。
让我们先完全忽略分片。将所有这些张量做得足够小,以便它们能放入一个设备中。编写这个函数的本地实现。确保你不要物化一个形状为 [S, D, F] 的数组!提示:尝试将令牌排序到一个形状为 [E, S, D] 的新缓冲区中,并注意掩码(为什么我们需要第二个维度的大小为 S?)。
如果你只是对上述方法使用 jax.jit,会发生一些事情。对此进行性能分析,看看它决定进行何种通信。它需要多长时间?
你会注意到的一个问题是,上述方法很可能会在本地收集完整的激活集 A,即 AllGatherX([SX, DY])。这不仅在通信方面成本高昂,而且如果我们无法在本地容纳完整的激活集,在内存方面也是极其昂贵的。使用 shard_map 和显式通信来实现上述功能。
作为第一步,最简单的方法可能是使用一个 jax.lax.all_gather 并像(a)中那样重新排序。
作为第二步,尝试避免物化任何大小为 [E, S, D] 的数组,即尝试在一个 jax.lax.while_loop 内部使用一个 jax.lax.all_to_all 以不规则的方式执行计算。这样,你可以避免物化完整的激活并浪费计算在填充上。这比你最初的实现快多少?
大多数 MoE 模型会将输入路由到多个(k 个)专家,然后对结果进行平均。重构上述代码以实现这一点。在这种情况下,设 B: int32[S, k] 用于路由到 k 个专家。
问题3:上面那个集合矩阵乘法的例子实际上与真实的 LLM 非常相关。让我们调整这个例子来完成整个 Transformer 栈。
作为一个练习,让我们从实现一个 AllReduce 集合矩阵乘法开始,即 A[BX, DY] *D W[DY, F] -> Out[BX, F]。注意输出不是复制的。上面讨论了朴素算法,基本上就是一个本地矩阵乘法后跟一个 AllReduce。尝试制作一个通信重叠的“集合”版本的此操作。提示:在输出维度上进行分块,并可以随意使用 jax.lax.psum(即 AllReduce)。 注意:由于 XLA 处理此问题的方式,它实际上可能不会比基线更快。
上面 AllReduce 集合矩阵乘法的补充是 ReduceScatter 集合矩阵乘法,如 Tmp[BX, FY] *F W2[FY, D] -> Out[BX, DY]。这发生在 Transformer 中的下投影矩阵中。在 JAX 中实现一个集合的、重叠的版本。注意只传递所需的最少量数据。提示:尝试在累加结果时对其进行置换。
将这两者组合成一个端到端的 Transformer 块,该块执行 In[BX, DY] *D Win[D, FY] *F Wout[FY, D] -> Out[BX, DY] 并带有重叠的通信。jax.jit 实现快多少?
问题4:上面实现的所有集合矩阵乘法都是单向的:它们只在一个方向上进行置换。重写集合 AllReduce 矩阵乘法和集合 ReduceScatter 矩阵乘法,以使用双向通信。它们快了多少?