🔗 英文原文: https://jax-ml.github.io/scaling-book/tpus/
✍️ 翻译: 北极的树
微信二维码 微信公众号

如何理解TPU

《如何扩展你的模型》第2部分 (第1部分:Roofline模型 | 第3部分:分片)

本节将全面介绍TPU的工作原理、它们如何通过网络连接以支持多芯片训练和推理,以及这些因素如何影响我们常用算法的性能。其中也有一些内容对GPU用户同样适用!

您可能也会对关于NVIDIA GPU的新增第12节感兴趣!

什么是TPU?

TPU本质上是一个专门用于矩阵乘法(称为TensorCore)的计算核心,附带一组高速内存(称为高带宽内存或HBM) 下图是一个示意图:

图:TPU芯片的基本组件。灰色的左侧方框是TensorCore,包含矩阵乘法单元(MXU)、向量单元(VPU)和向量内存(VMEM)。

你可以将TensorCore看作一台性能极佳的矩阵乘法机器,但它也有一些其他值得注意的功能。TensorCore有三个关键单元:

TPU在矩阵乘法方面非常非常快。这基本上是它们的主要工作,而且做得很好。TPU v5p是迄今为止最强大的TPU之一,每个核心每秒可以执行 2.5e14 bfloat16 FLOPs,或者说每个芯片每秒 5e14 bfloat16 FLOPs。一个包含8960个芯片的Pod每秒可以执行4 exaflops。这非常多。这是世界上最强大的超级计算机之一。而Google拥有很多这样的机器。TPU,特别是其脉动阵列,之所以是如此强大的硬件加速器,是因为矩阵乘法是少数几个使用 O(n^3) 计算量处理 O(n^2) 字节数据的算法之一。这使得普通的ALU很容易受计算而非内存带宽的限制。

上图还包括一些其他组件,如SMEM和标量单元,它们用于处理控制流,在附录A中有简要讨论,但理解它们并非至关重要。另一方面,HBM很重要且相当简单:

通常,所有TPU操作都是流水线化和重叠的。为了执行一次矩阵乘法 X \cdot A \to Y,TPU首先需要将矩阵 AX 的块从HBM复制到VMEM,然后将它们加载到MXU中,MXU会乘以8x128(对于X)和128x128(对于A)的块,然后将结果逐块复制回HBM。为了高效地完成这个过程,矩阵乘法是流水线化的,这样与VMEM之间的数据复制就可以与MXU的工作重叠。这使得MXU可以持续工作,而不是等待内存传输,从而使矩阵乘法受计算限制,而非内存限制。

这是一个从HBM执行逐元素乘法的例子:

图:一个动画展示了在TPU上执行的逐点乘法,字节从HBM加载。注意字节是如何以块的形式从内存中流出的,以及部分结果是如何在不等待整个数组物化的情况下流水线式地回传的。

一次矩阵乘法看起来几乎完全相同,只是它会加载到MXU而不是VPU/向量单元,并且加载和存储的顺序会不同,因为同一个权重块会用于多个激活块。你可以看到数据块流向VMEM,然后进入VREGs(向量寄存器),接着进入向量单元,最后回到VMEM和HBM。我们马上会看到,如果从HBM到VMEM的加载速度慢于向量单元(或MXU)的FLOPs,我们就会变得“带宽受限”,因为VPU或MXU会因缺少数据而空闲。

核心要点:TPU非常简单。它们将权重从HBM加载到VMEM,然后从VMEM加载到一个脉动阵列中,该阵列每秒可以执行约200万亿次乘加运算。HBM \leftrightarrow VMEM和VMEM \leftrightarrow 脉动阵列的带宽对TPU能高效执行哪些计算设定了基本限制。

VMEM和计算强度:VMEM比HBM小得多,但它与MXU之间的带宽要高得多。正如我们在第1节中看到的,这意味着如果一个算法能将其所有输入/输出都放入VMEM,它就更不容易遇到通信瓶颈。当一个计算的计算强度较低时,这一点尤其有用:VMEM带宽大约是HBM带宽的22倍,这意味着一个从VMEM读取/写入的MXU操作只需要10-20的计算强度就能达到峰值FLOPs利用率。这意味着如果我们能将权重放入VMEM而不是HBM,我们的矩阵乘法可以在小得多的批量大小下达到计算限制。这也意味着那些计算强度天生较低的算法仍然可以是高效的。只是VMEM太小了,这常常是一个挑战。我们有时会谈到VMEM预取,指的是提前在VMEM中加载权重,这样我们就可以掩盖矩阵乘法的加载成本。例如,在标准的Transformer中,我们有时可以在注意力计算期间将大的前馈网络权重加载到VMEM中,如果我们的内存带宽受限,这可以隐藏权重加载的成本。这要求我们的权重足够小或分片得足够细,以便能将单层权重放入VMEM并留有余地。

一个TPU芯片通常(但并非总是)由两个TPU核心组成,它们共享内存,可以被看作一个拥有两倍FLOPs的大型加速器(称为“megacore”配置)。自TPU v4以来都是如此。较早的TPU芯片有独立的内存,被视为两个独立的加速器(TPU v3及更早版本)。像TPU v5e这样为推理优化的芯片每个芯片只有一个TPU核心。

芯片4个一组的形式排列在“托盘”上,通过PCIe网络连接到一个CPU主机。这是大多数读者会熟悉的形式,即通过Colab或单个TPU-VM暴露的4个芯片(8个核心,但通常被视为4个逻辑megacore)。对于像TPU v5e这样的推理芯片,每个主机有2个托盘,而不是1个,但每个芯片也只有一个核心,因此我们有8个芯片 = 8个核心。在Cloud TPU VM上,每个托盘都作为独立VM的一部分暴露出来,因此再次可见4个核心。

PCIe带宽是有限的:与HBM \leftrightarrow VMEM链路一样,CPU \leftrightarrow HBM的PCIe连接有特定的带宽,这限制了你从主机内存加载到HBM或反之的速度。例如,TPU v4的PCIe带宽是单向16GB/秒,所以比HBM慢了近100倍。我们可以将数据加载/卸载到主机(CPU)RAM中,但速度不是很快。

TPU网络

在一个Pod中,芯片通过ICI网络相互连接。在较早的代际(TPU v2和TPU v3)、推理芯片(例如TPU v5e)和Trilium(TPU v6e)中,ICI(“芯片间互连”)连接4个最近的邻居(带有环绕链路以形成2D环面)。TPU v4和TPU v5p连接到最近的6个邻居(形成3D环面)。请注意,这些连接通过它们的主机,而是芯片之间的直接链路。

环面结构将任意两个节点之间的最大距离从N减少到N / 2,使得通信快得多。TPU还有一种“扭曲环面”配置,它以类似莫比乌斯带的拓扑结构包裹环面,以进一步减小节点间的平均距离。

TPU pod(通过ICI连接)可以变得非常大:最大的pod大小(称为superpod)对于TPU v4是 16x16x16,对于TPU v5p是 16x20x28。这些大型pod由可重构的 4x4x4 芯片立方体组成,通过光学环绕链路连接光学开关只是一个具有相同ICI带宽的可重构连接。它只是让我们在连接立方体的同时保留环绕链路。,我们可以重新配置它们以连接非常大的拓扑。

也可以请求较小的拓扑(例如2x2x1, 2x2x2),但没有环绕连接。这是一个重要的注意事项,因为它通常会使大多数通信时间加倍。任何完整立方体的倍数(例如4x4x44x4x8)都将由光开关提供环绕连接。请注意,一个 `2x2x4` 不会有任何环绕连接,因为它们是由光开关提供的,而光开关只在完整的立方体上可用。然而,一个TPU v5e 8x16 _将_在较长的轴上有环绕连接,因为它不使用可重构的光网络。

TPU v5e和Trillium pod由单个 16x16 2D环面组成,在任何大小为16的轴上都有环绕连接(这意味着一个 8x16 在长轴上有环绕连接)。TPU v5e和v6e(Trillium)不能扩展到超过16x16的环面,但pod之间仍然可以通过标准的数据中心网络(DCN)进行通信,DCN连接TPU主机。同样,可以请求没有环绕连接的较小拓扑,其维度<16

这种最近邻连接是TPU和GPU之间的一个关键区别。GPU通过一个分层的交换机网络连接,近似于每个GPU之间的点对点连接,而不是像TPU那样使用局部连接。通常,一个节点内的GPU(H100为8个GPU,B200多达500个)是直接连接的,而更大的拓扑则需要在每个GPU之间进行O(log(N))次跳跃。一方面,这意味着GPU可以在一个节点内以单次低延迟跳跃发送任意数据。另一方面,TPU的成本要低得多(因为NVLink交换机很昂贵),布线更简单,并且可以扩展到更大的拓扑,因为每个设备的链路数和每个设备的带宽是恒定的。在此处阅读更多信息here

ICI相对于DCN非常快,但仍比HBM带宽慢。例如,一个TPU v5p具有:

这意味着当我们在多个芯片上分割模型时,需要小心避免因较慢的跨设备通信而使MXU成为瓶颈。

多slice训练:一组通过ICI连接的TPU称为一个slice。不同的slice之间可以通过DCN连接,例如连接不同pod上的slice。由于DCN是比ICI慢得多的连接,我们应该尽量限制计算等待来自DCN数据的时间。DCN是主机到主机的,因此要通过DCN将缓冲区从一个TPU传输到另一个TPU,我们首先需要通过PCIe传输到主机,然后通过网络出口,再通过目标主机的网络入口,最后通过PCIe进入HBM。

核心要点

TPU规格

以下是我们芯片的一些具体数字:

型号 Pod大小 主机大小 HBM容量/芯片 HBM带宽/芯片 (字节/秒) FLOPs/秒/芯片 (bf16) FLOPs/秒/芯片 (int8)
TPU v3 32x32 4x2 32GB 9.0e11 1.4e14 1.4e14
TPU v4p 16x16x16 2x2x1 32GB 1.2e12 2.75e14 2.75e14
TPU v5p 16x20x28 2x2x1 96GB 2.8e12 4.59e14 9.18e14
TPU v5e 16x16 4x2 16GB 8.1e11 1.97e14 3.94e14
TPU v6e 16x16 4x2 32GB 1.6e12 9.20e14 1.84e15

主机大小指的是连接到单个主机的TPU拓扑(例如,TPU v5e有一个CPU主机连接到8个呈4x2拓扑的TPU)。以下是互连数据:

型号 ICI带宽/链路 (单向, 字节/秒) ICI带宽/链路 (双向, 字节/秒)
TPU v3 1e11 2e11
TPU v4p 4.5e10 9e10
TPU v5p 9e10 1.8e11
TPU v5e 4.5e10 9e10
TPU v6e 9e10 1.8e11

我们同时提供了单向(unidirectional)带宽和双向(bidirectional)带宽,因为单向带宽更接近硬件实际情况,但双向带宽在涉及完整环形网络的方程中更常出现。我们所说的双向(bidi)带宽是指单个链路上双向可发送的总字节数,或者等同于,假设我们可以高效地使用两个链路,单个TPU沿特定轴的总出向字节数。当我们有一个功能正常的环形网络时,即当我们在特定轴上有环绕连接时,这是成立的。这在推理芯片上当我们有一个完整的16轴时发生,或者在训练芯片(v*p)上当我们有一个是4的倍数的轴时发生。我们更喜欢使用双向带宽,因为它在涉及双向通信的计算中频繁出现。

PCIe带宽通常约为每个TPU 1.6e10 字节/秒(TPU v6e为 3.2e10),而DCN带宽通常约为每个TPU 6.25e9 字节/秒(TPU v6e为 12.5e9,TPU v5e为 3.125e9)。

练习题

这些数字可能有些枯燥,但它们能让你对模型性能做出基本的Roofline估算。让我们通过几个问题来解释为什么这很有用。你将在第3部分看到更多例子。

问题1 [LLM延迟的界定]:假设你想从一个分布在32个TPU v4p上的200B参数的bf16模型中进行采样。将所有参数从HBM加载到脉动阵列需要多长时间?提示:使用上面的数字。

点击此处查看答案。

答案:我们在32个芯片上加载 sizeof(bf16) * 200e9 = 400e9 字节,即每个芯片12.5e9字节,每个芯片的HBM带宽为1.23e12。所以加载大约需要10毫秒。

这很酷,因为这是对模型采样延迟的一个合理的下界。每个采样步骤都需要从HBM加载所有参数,所以耗时不可能少于10毫秒。实际上,在小批量大小下,这几乎是可以实现的。

问题2 [TPU细节]:考虑一个完整的TPU v5e pod。总共有多少个CPU主机?多少个TPU TensorCore?整个pod的总FLOPs/s是多少?总HBM是多少?对TPU v5p pod做同样的练习。

点击此处查看答案。

答案:对于TPU v5e,每个pod是 16x16,每个主机是一个4x2的slice,所以我们有 16*16 / 8 = 32 个主机。对于TPU v5e,每个TPU只有一个核心,所以我们有256个TensorCore。总FLOPs/s是 16*16*2e14 = 5.1e16 bfloat16。每个芯片有16GB的HBM,所以总内存是 256 * 16 = 4TB

对于一个完整的TPU v5p pod,我们有 16x20x28 个芯片,每个主机是2x2x1,所以我们有 16*20*28 / 2*2 = 2,240 个主机。对于TPU v5p,每个TPU有两个TensorCore,所以我们有 8960 * 2 = 17,920 个核心。总FLOPs/s是 8960 * 4.5e14 = 4e18 bfloat16。每个芯片有96GB的HBM,所以总内存是 8960 * 96 = 860TB

问题3 [PCIe操作强度]:想象我们被迫将一个大的权重矩阵 A(类型为 \text{bfloat16}[D, F])和一批激活 x(类型为 \text{bfloat16}[B, D])存储在主机DRAM中,并想在它们上面进行矩阵乘法。这在单个主机上运行,我们使用一个与之相连的TPU v6e芯片。你可以假设 B \ll D,和 F = 4D(在未来的章节中,我们会看到为什么这些是合理的假设)。我们需要多大的最小批量大小 B 才能在PCIe上保持计算限制?假设PCIe带宽为1.5e10字节/秒。

点击此处查看答案。

答案:我们需要执行 2BDF 次浮点运算,每个芯片每秒可以执行 9.2e14 次浮点运算。这需要 2BDF / 9.2e14 秒来完成。我们需要从DRAM加载 2DF + 2BD 字节,并将 2BF 字节写回。我们受限于PCIe传输速度,所以需要 2 \cdot (BD + DF + BF) / 1.5e10 秒来与TPU进行数据传输。因为我们希望计算时间比权重加载时间长,假设我们可以将所有权重加载与计算重叠,我们希望 2BDF / 9.2e14 > 2 \cdot (BD + DF + BF) / 1.5e10。我们可以使用我们的假设 B \ll D,和 F = 4D 来简化这个不等式,得到

8BD29.2e14>8D21.5e10

或者

B>9.2e141.5e1061,000

问题4 [通用矩阵乘法延迟]:假设我们想将一个权重矩阵int8[16384, 4096]与一个大小为int8[B, 4096]的激活矩阵相乘,其中B是某个未知的批量大小。假设我们开始时在一个TPUv5e上。

  1. 这次乘法作为B的函数将耗时多久?提示:计算从HBM加载数组所需的时间和实际乘法所需的时间可能会有帮助。哪个是你的瓶颈?
  2. 如果我们想从VMEM中运行这个操作呢?作为B的函数它将耗时多久?
点击此处查看答案。

答案:(1) 我们需要执行的浮点运算次数是 2 \cdot 4096 \cdot 16384 \cdot B = 1.3e8 \cdot B。所以 T_{\text{math}} = (1.3e8 \cdot B) / 3.94e14 秒。我们需要从HBM加载 16384 \cdot 4096 + 4096 \cdot B 字节到VMEM,并将 16384 \cdot B 字节从VMEM写回HBM。这意味着 T_{\text{comms}} = (6.7e7 + 2e4\cdot B) / 8.1e11 秒。假设通信和计算尽可能重叠,整个乘法将大约耗时

max{Tmath,Tcomms}=max{6.7e7+2e4B8.1e11,1.3e8B3.94e14}

\frac{6.7e7 + 2e4\cdot B}{8.1e11} < \frac{1.3e8 \cdot B}{3.94e14} 时,我们将受限于FLOPs,或者等价地,B > 271。这个数字比我们下面推导出的240略大,因为我们考虑了 DF 的全部影响。

(2) 如果我们改为从VMEM加载,让我们将VMEM到MXU的带宽视为HBM \leftrightarrow VMEM带宽的22倍。这将我们的数据加载分母从8.1e11变为1.78e13,我们得到 B > 11。注意,实际上,我们不能将所有VMEM带宽都用于加载 W,所以实际上它会接近20。

问题5 [ICI带宽]:假设我们有一个TPU v5e 4x4 slice。我们想将一个类型为 bfloat16[8, 128, 8192] 的数组从 TPU{0,0} 发送到 TPU{3, 3}。假设TPU v5e的每跳延迟是 1\mu s

  1. 第一个字节将多快到达目的地?
  2. 总传输将耗时多久?
点击此处查看答案。

答案:在TPUv5e中,我们有2D连接。因为我们只有一个 4x4 slice(没有大小为16的轴),我们没有环绕连接。因此,我们的目标芯片可以从两个端口接收数据,同样地,我们的源芯片可以从两个端口发送数据。我们需要传输的数据量是 2 * 8 * 128 * 8192 = 1.7e7 字节。我们可以同时从两个端口传输(即一半数组向右发送,一半向下发送),所以我们每秒传输 2 * 4.5e10 = 9e10 字节,这意味着传输整个数组大约需要 1.7e7 / 9e10 = 188us(假设我们受限于带宽)。在一个 4x4 slice中,芯片 (0, 0)(3, 3) 之间有六次跳跃,因为轴上少于16个芯片时没有环绕链路。由于每跳的延迟大约是 1\mu s,第一个字节将在大约6us到达,总传输将耗时188us

问题6 [综合练习,较难]:想象你有一个大矩阵 Aint8[128 * 1024, 128 * 1024] 均匀分片在一个TPU v5e 4x4 slice上,但卸载到每个芯片的主机DRAM中。假设你想将整个数组复制到TPU{0, 0}并将其与一个向量 bf16[8, 128 * 1024] 相乘。这将耗时多久?提示:使用上面的数字。

点击此处查看答案。

答案:让我们首先概述一下我们需要执行的操作。我们的数组大约是16GB。从上表中,一个TPU v5e主机有4x2的拓扑,所以一个4x4有2个主机。因此,由于我们的数组是均匀分片的,每个主机实际上包含数组的1/2,即8GB。我们需要将这些块全部复制到TPU{0,0},这给了我们两个选择:

  1. 我们可以通过DCN复制,然后通过PCIe将整个未分片的数组加载到HBM中。
  2. 我们可以将分片的数组加载到它们对应的TPU上,然后在ICI上执行一次gather操作,最后在TPU{0,0}上执行矩阵乘法。

很明显,选项(2)更好。DCN比ICI慢,而且我们更愿意通过多个PCIe链路加载一个大数组,而不是仅仅通过少数几个(主机0上的8个)。这是系统部分示意图。如上所述,请注意TPU通过ICI连接到它们的邻居(即使跨主机),所有TPU都通过PCIe连接到它们的主机CPU,而主机之间通过DCN连接。

每个芯片实际上都有自己的PCIe链路连接到其主机,但为了清晰起见,这里只显示了一个。

现在让我们计算每个部分将耗时多久:

  1. PCIe加载:我们通过16个PCIe链路加载16GB / 2 = 8GB的块,每个链路的带宽为 1.5e10 字节/秒。因此,这将耗时约33毫秒。

  2. ICI复制:每个TPU现在拥有我们数组的16GB / 16 = 1GB。我们的ICI带宽是每个链路双向9e10字节/秒,你会从上图中注意到,在这个拓扑中,TPU v5e上4个ICI链路中只有2个被TPU{0,0}使用。由于TPU{0,0}需要沿着2个轴以 4.5e10 字节/秒/链路的速度接收总共15GB的数据,我们可以将时间下界限定为 15e9 / (4.5e10 * 2) = 167ms。实际上,这可能无法实现,因为负载非常不均匀,但可能在2倍的因子内。正如你将在第2节中看到的,执行一个完整的AllGather也大约需要 16e9 / (4.5e10 * 2),所以这接近最优。

  3. HBM \rightarrow MXU加载:为了执行我们最后的矩阵乘法,我们需要将这16e9字节加上bf16[8, 128 * 1024]数组(另外2MB,可忽略不计)通过HBM带宽加载到MXU中,这将耗时 16e9 / 8.1e11 = 19ms

  4. FLOPs:我们总共执行 2812810241281024=2.7e11 FLOPs,由于我们可以执行 1.97e14 bf16 FLOPs/s,我们得到1.3毫秒。

总时间的上限是所有这些时间的总和,但由于TPU通常可以重叠这些操作,我们可以将其视为一个由最慢部分决定的流水线问题。如果这是真的,那么答案大约是150-200毫秒。

第2部分到此结束!关于第3部分,涵盖分区和跨TPU通信,请点击这里

附录

附录A:更多关于TPU内部的细节

在这里,我们将更深入地探讨TPU的内部操作。除非另有说明,我们将提供TPU v5p的规格。

VPU

VPU是TPU的向量算术核心。VPU由一个二维SIMD向量机(VPU)组成,该向量机执行逐元素算术运算,如vadd(向量加法)或vmax(逐元素最大值),以及一组称为VREGs的向量寄存器,用于为VPU和MXU保存数据。

VREGs:每个TPU v5p核心有64个32位VREGs(TPU v4为32个),这给了我们每个核心总共约 64 * 8 * 128 * 4 = 256kB 的VREG内存(对于整个芯片来说是这个值的2倍,因为我们有两个核心)。一个TPU v5p每个周期可以从VMEM加载3个寄存器,并向VMEM写入1个寄存器。

VPU:VPU是一个形状为 (8, 128) 的二维向量算术单元,其中128维被称为lane轴,8维被称为sublane轴。在v5上,每个(lane, sublane)对包含4个相互独立的标准浮点ALU。VPU在其每个ALU中用一个周期执行大多数算术指令(如vadd或向量加法),延迟为2个周期,所以例如在v5中,你每个周期可以从VREGs中将4对f32值相加。一个典型的VPU指令可能看起来像 {v2 = vadd.8x128.f32 v0, v1},其中v0和v1是输入VREGs,v2是输出VREG。

所有lane和sublane都以纯粹的SIMD方式每个周期执行相同的程序,但每个ALU可以执行不同的操作。所以我们可以在一个周期内处理1个vadd和1个vsub,每个都作用于两个完整的VREGs并将输出写入第三个。

小测验 [计算VPU吞吐量]:使用以上信息,计算一个TPU v5p可以执行多少向量FLOPs/s。一个TPU v5p的时钟速度约为1.75GHz。

点击此处查看答案。

答案:每个周期,每个核心可以在 8 * 128 个ALU上执行4个向量指令。这给了我们整个芯片每个周期 8 * 128 * 4 * 2 FLOPs,或者 8 * 128 * 4 * 2 * 1.75e9 = 1.4e13 FLOPs/s。注意这比MXU的FLOPs/s(约2e14)小多少(大约10倍)。

归约:通常,跨sublane维度的通信或归约比跨lane维度的要容易。例如,VPU支持一个lane内shuffle操作,可以在大小为8的轴上大约一个周期内滚动。这可以用来在sublane维度上执行高效的归约(只需按4、2和1进行shuffle,并进行3对逐元素求和)。

跨lane归约要困难得多,并且涉及一个称为XLU或“跨lane单元”的独立硬件单元,它又慢又相当昂贵。

与GPU的比较:对于熟悉NVIDIA GPU的人来说,VPU中的每个ALU类似于一个CUDA核心,而一个VPU lane类似于一个“Warp Scheduler”,即通常执行SIMD算术的32个CUDA核心的集合。在lane内的归约相当容易,但如果我们需要跨lane,我们至少需要通过VMEM/XLU/SMEM,这要慢得多。更多细节请参见GPU部分

标量核心

标量核心是TPU的控制单元。它获取并分派所有指令,执行从HBM到VMEM的传输,并且可以编程来做标量元数据工作。由于标量核心是单线程的,这带来的一个副作用是TPU的每个核心每个周期只能创建一个DMA请求。

具体来说,一个标量核心控制着一个VPU(由4096个ALU组成)、4个MXU、2个XLU和多个DMA引擎。这种每个计算单元控制的高度倾斜特性是硬件效率的来源,但也限制了以任何有趣的方式进行数据依赖向量化的能力。

附录B:脉动阵列是如何工作的?

TPU MXU的核心是一个 128x128 脉动阵列(TPU v6e上为 256x256)。当完全饱和时,脉动阵列可以每8个时钟周期执行一次 bfloat16[8,128] @ bf16[128x128] -> f32[8,128]如果你不熟悉这种表示法,它的意思是:将一个`8x128`的bfloat16元素矩阵与一个`128x128`的bfloat16元素矩阵相乘,并将结果存储在一个`8x128`的float32元素矩阵中。 乘法。

这是一个将一组权重(蓝色)与一组激活(绿色)相乘的简化动画。你会注意到权重(RHS)首先被部分加载,呈对角线状,然后激活也被对角线地送入。在下面的每一帧中,我们将所有重叠的绿色和蓝色单元相乘,将结果与从上方传入的任何残差相加,然后将结果依次向下一个单元传递。

这是一个更通用的版本动画,展示了输出如何从计算中流出:

这是一个展示了如何在多个RHS和LHS数组之间进行流水线操作的图表:

在加载权重(RHS)和激活(LHS)时存在一个初始的流水线气泡。在那个初始气泡之后,可以加载新的输入和权重而无需额外的气泡。

这是一个bf16[2, 3] x bf16[3, 3]矩阵乘法的拙劣动画,你可以把它想象成一个2x3的权重矩阵与一个批量为1、大小为3的输入激活的矩阵乘法。这个动画与前面的幻灯片相比是旋转的,输入向右流出而不是向下,但你大致可以看到结构。

我们可以高效地对此进行流水线操作,以乘法大型矩阵而不会产生过大的流水线气泡。话虽如此,重要的是我们的矩阵形状要大于MXU的边长,通常是128x128。一些TPU(自TPU v3起)有多个MXU,TPU v3有2个,TPU v4/5有4个,所以我们需要确保分块维度大于128 * MXU的数量。这里有一个很好的动画。

Trillium(TPU v6e)有一个 256x256 脉动阵列,这意味着它每周期可以执行4倍的FLOPs。这也意味着你的张量维度需要是原来的两倍大才能充分利用MXU。

这篇博客文章有另一个关于固定权重矩阵的脉动阵列乘法的极佳动画。

脚注

  1. TPU v6e (Trillium) 有一个256x256的MXU,而之前所有代际都使用128x128[↩]
  2. TPU,特别是其脉动阵列,之所以是如此强大的硬件加速器,是因为矩阵乘法是少数几个使用 O(n^3) 计算量处理 O(n^2) 字节数据的算法之一。这使得普通的ALU很容易受计算而非内存带宽的限制。[↩]
  3. 我们有时会谈到VMEM预取,指的是提前在VMEM中加载权重,这样我们就可以掩盖矩阵乘法的加载成本。例如,在标准的Transformer中,我们有时可以在注意力计算期间将大的前馈网络权重加载到VMEM中,如果我们的内存带宽受限,这可以隐藏权重加载的成本。这要求我们的权重足够小或分片得足够细,以便能将单层权重放入VMEM并留有余地。[↩]
  4. 在Cloud TPU VM上,每个托盘都作为独立VM的一部分暴露出来,因此再次可见4个核心。[↩]
  5. 光学开关只是一个具有相同ICI带宽的可重构连接。它只是让我们在连接立方体的同时保留环绕链路。[↩]
  6. 请注意,一个 `2x2x4` 不会有任何环绕连接,因为它们是由光开关提供的,而光开关只在完整的立方体上可用。然而,一个TPU v5e 8x16 _将_在较长的轴上有环绕连接,因为它不使用可重构的光网络。[↩]
  7. 上述页面列出的带宽为100 GB/s,与此处列出的略有不同。TPU ICI链路的带宽根据执行的操作略有不同。你通常可以放心使用本文档中的数字。[↩]
  8. TPU v6e为12.5e9字节/秒,v5e为3.125e9字节/秒。[↩]
  9. 我们所说的双向(bidi)带宽是指单个链路上双向可发送的总字节数,或者等同于,假设我们可以高效地使用两个链路,单个TPU沿特定轴的总出向字节数。当我们有一个功能正常的环形网络时,即当我们在特定轴上有环绕连接时,这是成立的。这在推理芯片上当我们有一个完整的16轴时发生,或者在训练芯片(v*p)上当我们有一个是4的倍数的轴时发生。我们更喜欢使用双向带宽,因为它在涉及双向通信的计算中频繁出现。[↩]
  10. 如果你不熟悉这种表示法,它的意思是:将一个`8x128`的bfloat16元素矩阵与一个`128x128`的bfloat16元素矩阵相乘,并将结果存储在一个`8x128`的float32元素矩阵中。[↩]

参考文献

  1. TPU v4:一种用于机器学习的光学可重构超级计算机,具有对嵌入的硬件支持
    Jouppi, N.P., Kurian, G., Li, S., Ma, P., Nagarajan, R., Nai, L., Patil, N., Subramanian, S., Swing, A., Towles, B., Young, C., Zhou, X., Zhou, Z. and Patterson, D., 2023. arXiv [cs.AR].

其他

*工作于Google DeepMind期间完成,现就职于MatX。

引用

在学术场合引用,请按以下格式引用此作品:

    Austin et al., "How to Scale Your Model", Google DeepMind, online, 2025.

或使用BibTeX条目:

    @article{scaling-book,
      title = {How to Scale Your Model},
      author = {Austin, Jacob and Douglas, Sholto and Frostig, Roy and Levskaya, Anselm and Chen, Charlie and Vikram, Sharad
      and Lebron, Federico and Choy, Peter and Ramasesh, Vinay and Webson, Albert and Pope, Reiner},
      publisher = {Google DeepMind},
      howpublished = {Online},
      note = {Retrieved from https://jax-ml.github.io/scaling-book/},
      year = {2025}
    }