🔗 英文原文: https://jax-ml.github.io/scaling-book/profiling/
✍️ 翻译: 北极的树
微信二维码 微信公众号

如何分析 TPU 程序

《如何扩展你的模型》第9部分 (第8部分:服务 LLaMA | 第10部分:JAX)

到目前为止,本系列文章完全是理论性的:基于硬件Roofline模型的粗略计算。这种理解能让你走得很远,但许多优化最终都归结于实践细节:XLA 编译器如何工作,以及如何使用像 JAX/Tensorboard Profiler 这样的分析工具来找出失败时的对策。本文将讨论这些内容。

TPU 软件栈的宏观视角

谷歌提供了多种用于 TPU 编程的 API,从高层的 JAX 代码到低层的 Pallas 或 HLO。大多数程序员只编写 JAX 代码,它允许你编写抽象的 NumPy 风格的线性代数程序,这些程序会被自动编译以在 TPU 上高效运行。

这里有一个简单的例子,一个将两个矩阵相乘的 JAX 程序:

import jax
import jax.numpy as jnp

def multiply(x, y):
  return jnp.einsum('bf,fd->db', x, y)

y = jax.jit(multiply)(jnp.ones((128, 256)), jnp.ones((256, 16), dtype=jnp.bfloat16))

通过调用 jax.jit,我们告诉 JAX 跟踪这个函数并生成一个名为 StableHLO 的底层中间表示(IR),这是一种用于机器学习计算的平台无关 IR,它又被 XLA 编译器降级为 HLO。编译器会运行多个 pass 来确定融合、布局和其他因素,从而产生在 JAX 性能分析中可观察到的 HLO。这个 HLO 以 LLVM 风格的图视图表示了 JAX 代码中所有的核心线性代数运算(矩阵乘法、逐点运算、卷积等)。例如,下面是上述程序的 HLO 删节版要获取此 HLO,你可以运行 `jax.jit(f).lower(*args, **kwargs).compile().as_text()`。

ENTRY %main.5 (Arg_0.1: f32[128,256], Arg_1.2: bf16[256,16]) -> f32[16,128] {
  %Arg_1.2 = bf16[256,16]{1,0} parameter(1), metadata={op_name="y"}
  %convert.3 = f32[256,16]{1,0} convert(bf16[256,16]{1,0} %Arg_1.2),
  %Arg_0.1 = f32[128,256]{1,0} parameter(0), metadata={op_name="x"}
  ROOT %dot.4 = f32[16,128]{1,0} dot(f32[256,16]{1,0} %convert.3, f32[128,256]{1,0} %Arg_0.1), lhs_contracting_dims={0}, rhs_contracting_dims={1},
}

我们稍后会解释 HLO 的语法,但现在请注意,它实际上与上面的 JAX 代码相当匹配。例如,

ROOT %dot.4 = f32[16,128]{1,0} dot(f32[256,16]{1,0} %convert.3, f32[128,256]{1,0} %Arg_0.1), lhs_contracting_dims={0}, rhs_contracting_dims={1}

是上面实际的矩阵乘法,它分别沿着维度 0 和 1 将两个 f32 矩阵相乘。

为了将这个 HLO 转换为可以在 TPU 上执行的代码,XLA 编译器首先将其降级为 LLO(低级优化器)IR。LLO 直接对 TPU 进行编程,调度内存之间的拷贝、将数组推送到脉动阵列等。LLO 代码包含将缓冲区推入脉动阵列、取回结果以及调度在 TPU 不同内存部分之间通信的 DMA 的原语。一旦降级为 LLO,它就会被编译成机器码,加载到 TPU IMEM 中并执行。

当程序运行速度低于预期时,我们主要在 JAX 层面进行性能优化。然而,这样做通常需要我们理解 HLO 的一些语义以及代码在 TPU 上的实际运行方式。当在更底层出现问题时,我们会采取另一种应急方案,即在 Pallas 中编写自定义内核。要查看程序的 HLO 及其运行时统计信息,我们使用 JAX 分析器。

JAX Profiler:一个多功能的 TPU 分析器

JAX 提供了一个多功能的 TPU 分析器,其中包含许多有用的工具,可以帮助理解程序运行时在 TPU 上发生的情况。你可以使用 jax.profiler 模块来跟踪正在运行的程序,并记录从每个子组件的持续时间、每个程序的 HLO、内存使用情况等所有信息。例如,这段代码会将跟踪信息转储到 /tmp/tensorboard 目录下的一个文件中,该文件可以在 TensorBoard 中查看(这里是一个分步指南)。

import jax
with jax.profiler.trace("/tmp/tensorboard"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (1024, 1024))
  y = x @ x
  y.block_until_ready()

# Now you can load TensorBoard in a Google Colab with
#
# !pip install tensorboard tensorboard-plugin-profile
# %load_ext tensorboard
# %tensorboard --logdir=/tmp/tensorboard
#
# or externally with
#
# > tensorboard --logdir=/tmp/tensorboard
#

以下是你在分析器中可以做的事情的概述:

进入 TensorBoard 后,分析器有几个关键选项卡可以帮助你理解你的程序:

  1. Trace Viewer 以时间线的形式显示了 TPU 上实际发生情况的详细时间线。
  2. Graph Viewer 显示 HLO 图,让你能看到程序的哪些部分相互馈送,以及数据是如何分片的。
  3. Memory Profile 和 Memory Viewer: 这些显示了你的程序正在使用多少内存。

虽然分享性能分析文件有些困难,但这里有一个 Perfetto 链接,其中至少包含了一个简单 Transformer 的 Trace Viewer 组件。这个 Colab 让你能够生成完整的 JAX/TensorBoard 跟踪并进行实验。

Trace Viewer

Trace Viewer 可能是分析器中最有用的部分。 下面的例子展示了一个带有注释的简单 Transformer。名称来自代码中提供的标签。

Trace Viewer 显示了每个 TPU 核心上所有操作的时间顺序。这里我们只看 TPU:0,因为通常所有 TPU 都执行相同的指令。几个关键点:

  1. 顶行 (XLA Ops) 显示了实际的 TPU 操作(名称是 HLO 名称)。其他所有内容都是基于 jax.named_scopejax.named_call 和 Python 堆栈跟踪的近似跟踪。
  2. 注意到重复的块,我们可以在这里隔离出单个层。我们还可以(通过查看代码/理解 Transformer 的工作原理)看到哪些部分是注意力,哪些部分是 MLP。
  3. 通过点击一个 XLA 操作,我们可以查看它在代码中的来源(这对于理解跟踪很有用),并看到指向 Graph viewer 的链接。

提示: 你可以使用“视频游戏”风格的控制方式来导航 Trace Viewer,用 A/D 键左右平移,用 W/S 键放大和缩小。这些控制使得导航变得容易得多。

如何解读 XLA 操作

HLO 实际上并不难读,它对于理解上面跟踪中特定部分对应的内容非常有帮助。这里有一个名为 fusion.3 的操作示例。

%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)} fusion(bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32), kind=kCustom, calls=%all-reduce-scatter.3

让我们把它分解成几个部分。

让我们更深入地理解这个表示法。让我们以这个简单的例子为例:

f32[3,5]{1,0:T(2,2)}

这同样告诉我们,这个操作返回一个形状为 [3, 5] 的 float32 数组,并带有一个特定的分块(tiling){1,0:T(2,2)}。虽然分块不是特别重要,但简而言之,分块告诉我们一个 N 维数组在内存中是如何顺序布局的。这里有一个图表展示了这个数组的布局方式:

{1,0:T(2,2)} 中,1,0 部分告诉我们数组维度在物理内存中的顺序,从最次要到最主要。你可以从右到左阅读这部分,并从 f32[3,5] 中找出相应的维度,以确定数组的物理布局。在这个例子中,物理布局是 [3,5],与逻辑形状相同。之后,T(2,2) 告诉我们数组以 (2, 2) 的块进行分块,在每个块内,数组首先是行(行主序),然后是列,即 (0, 0) 后面是 (0, 1),然后是 (1, 0)(1,1)。由于 T(2, 2) 分块,数组被填充到 [4, 6],使其内存使用量增加了约 1.6 倍。对于上面给出的大的 bf16 数组 bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)},我们有 T(8,128)(2,1),这告诉我们数组有两级分块,一个外部的 (8, 128) 分块和一个内部的 (2, 1) 分块(用于 bf16,以便我们的加载总是 4 字节的倍数)。例如,这里是 bf16[4,8]{1,0,T(2,4)(2,1)}(颜色是 (2,4) 块,红色框是 (2,1) 块):

分块会影响张量块加载到 VMEM 的效率,XLA 有时会在程序内部引入拷贝操作来“重新分块”或“重新布局”张量,这有时会带来不可忽略的开销。JAX 提供了一个实验性功能来解决这个问题,它允许 XLA 计算程序输入的“首选”布局。当你使用 `jax.jit` “即时”编译一个程序时,你通常会传入“模拟”输入,告诉 JAX 期望的形状和数据类型。这些输入通常也带有分块信息,但这可能不是最优的。相反,你可以将输入布局指定为 AUTO,`jax.jit` 将返回 JIT 编译后的程序偏好的布局。然后,你可以显式地以该布局加载张量,以避免在程序内引发拷贝。

Graph Viewer

虽然上面的一些融合操作看起来可能很复杂,但 XLA Graph Viewer 使它们更容易解析。例如,这是一个相当复杂的融合操作的视图:

盯着一堆 HLO 图,并尝试将 HLO 操作映射到你正在分析的代码上,这非常有帮助。将鼠标悬停在框上,你通常会看到定义该函数的代码行。

分析一个真实(近似)的性能分析示例

这个 Colab 有一个伪 Transformer 的性能分析示例。这里有一个 Perfetto 链接,至少可以查看 Trace Viewer 如果你赶时间的话。我比平时花了更多精力用 jax.named_scope 调用来注释跟踪,以便你能识别正在发生的事情。

看一下性能分析文件,试着真正理解每个部分在做什么。让我们把它分解一下,从 FFW 块开始:

这里我们放大了 FFW 块。你会看到上投影操作是一个融合(矩阵乘法),输入为 bf16[8, 1024, 8192]bf16[8192, 16384],输出为 bf16[32, 1024, 16384]。我知道(因为我写了这段代码)这是一个 4 路数据并行(DP)、2 路模型并行(MP)分片矩阵乘法的局部视图,所以我们实际上在做

X: bf16[32, 1024, 8192] * Win: bf16[8192, 32768] -> Tmp: bf16[32, 1024, 32768]

我们预计这需要多长时间? 首先,每个数据并行分片的批量大小是 8 * 1024 = 8192,所以我们应该完全受计算限制。这是在 8 个 TPUv2 核心上运行的(在 Google Colab 上免费提供),所以我们预计它大约需要 2 * 32 * 1024 * 8192 * 32768 / (23e12 * 8) = 95.6ms,这几乎与实际花费的时间(96ms)完全一样。太棒了!这意味着我们获得了非常好的 FLOPs 利用率!

通信方面呢? 你会注意到在第二个矩阵乘法末尾隐藏着一个小小的融合操作。如果我们点击它,你会看到

%fusion.1 = bf16[8,1024,4096]{2,1,0:T(8,128)(2,1)} fusion(bf16[8,1024,8192]{2,1,0:T(8,128)(2,1)} %fusion.31), kind=kCustom, calls=%all-reduce-scatter.1

这基本上是一个小型的 ReduceScatter(这是 GraphViewer);

我们预计这需要多长时间?嗯,我们正在一个 TPUv2 4x2 上执行 ReduceScatter,这在 1.2e11 的双向带宽上应该只需要一次跳跃。数组大小为 2*32*1024*8192,批量轴被分片为 4 路,所以每个分片是 2*8*1024*8192=134MB。因此,这大约需要 1.1ms。实际需要多长时间? 性能分析报告为 1.13ms。所以我们非常接近Roofline!

我们再来看看注意力! 这是注意力组件的性能分析:

我点击了 Q 投影操作,它使用了一个矩阵 WQ 形状为 [dmodel = 8192, nheads = 32, dqkv = 256]。我们正在沿着头维度进行 Megatron 分片。试着做同样的练习,计算这些操作应该花费多长时间。

内存分析

内存分析可以轻松地查看程序内存随时间的变化。这对于调试内存溢出(OOM)很有帮助。你可以在这里看到大约 7.5GB 分配给了模型参数,还有大约 10GB 的空闲内存。所以我们可以在内存中容纳更多的东西。

练习题

问题1:看一下这个 Colab/性能分析文件,找出可疑之处以及这里发生了什么。你能准确地告诉我正在进行什么计算,每个操作在做什么吗?涉及的每个矩阵的真实形状是什么,它们是如何分片的?试着先看性能分析文件,不要阅读代码。

点击这里查看答案。

这是两次矩阵乘法,具体来说是这样:

def matmul(w1, w2, x):
  return jnp.einsum('wf,bf->bw', w2, jnp.einsum('fw,bw->bf', w1, x))

你可以看到一个 reduce、两个大的融合操作和一个 all-reduce。第一个大的融合操作是:

%fusion.1 = bf16[4096]{0:T(1024)(128)(2,1)} fusion(bf16[4096,8192]{1,0:T(8,128)(2,1)} %param.1, bf16[8192]{0:T(1024)(128)(2,1)} %reduce.6), kind=kLoop, calls=%fused_computation.1

这告诉我们每个分片的形状是 bf16[8192] * bf16[4096, 8192] -> bf16[4096](在 8192 维度上)。通过观察最后的 AllReduce,其 replica_groups={{0,16,32,48,64,80,96,112}, ...},我们可以判断我们正在进行 8 路模型并行,所以真实的形状是 [8, 8192] * bf16[32,768, 8192] -> bf16[8, 32,768]

问题2: 前面的 Transformer Colab 实现了一个简单的模拟 Transformer。按照 Colab 中的说明,对使用 GSPMD 分区的朴素 Transformer 进行基准测试。每个部分需要多长时间?应该需要多长时间?正在使用哪种分片策略?尝试修复分片!提示:使用 jax.lax.with_sharding_constraints 来约束行为。修复后,你能获得的最佳 MXU 是多少?

作为参考,初始版本大约是 184ms/层,优化后的性能分析是 67ms/层。完成之后,试着盯着性能分析文件,看看你是否能仅凭性能分析文件回答这些问题:

注意: 自从写下这个问题以来,XLA 编译器已经变得更好了。初始版本现在大约是 90ms/层,优化后的性能分析只快了大约 10ms/层(80ms/层)。尽管如此,还是值得尝试一下,看看你是否能做得更好。

第9部分到此结束。要深入了解 JAX 并行性,请点击这里进入第10部分。

脚注

  1. 要获取此 HLO,你可以运行 `jax.jit(f).lower(*args, **kwargs).compile().as_text()`。[↩]
  2. JAX 提供了一个实验性功能来解决这个问题,它允许 XLA 计算程序输入的“首选”布局。当你使用 `jax.jit` “即时”编译一个程序时,你通常会传入“模拟”输入,告诉 JAX 期望的形状和数据类型。这些输入通常也带有分块信息,但这可能不是最优的。相反,你可以将输入布局指定为 AUTO,`jax.jit` 将返回 JIT 编译后的程序偏好的布局。然后,你可以显式地以该布局加载张量,以避免在程序内引发拷贝。[↩]

其他

*工作于 Google DeepMind 期间完成,现就职于 MatX。

引用

在学术背景下引用,请按如下方式引用本作品:

    Austin et al., "How to Scale Your Model", Google DeepMind, online, 2025.

或作为 BibTeX 条目:

    @article{scaling-book,
      title = {How to Scale Your Model},
      author = {Austin, Jacob and Douglas, Sholto and Frostig, Roy and Levskaya, Anselm and Chen, Charlie and Vikram, Sharad
      and Lebron, Federico and Choy, Peter and Ramasesh, Vinay and Webson, Albert and Pope, Reiner},
      publisher = {Google DeepMind},
      howpublished = {Online},
      note = {Retrieved from https://jax-ml.github.io/scaling-book/},
      year = {2025}
    }