微信公众号
《如何扩展你的模型》第4部分 (第3部分:分片 | 第5部分:训练)
在这里,我们将快速回顾Transformer架构,特别是如何计算FLOPs、字节数和其他感兴趣的量。
让我们从向量\(x\)、\(y\)和矩阵\(A\)、\(B\)开始,它们的形状如下:
\[\def \red#1{\textcolor{red}{#1}} \def \green#1{\textcolor{green}{#1}} \def \blue#1{\textcolor{blue}{#1}} \def \purple#1{\textcolor{purple}{#1}} \def \orange#1{\textcolor{orange}{#1}} \def \gray#1{\textcolor{gray}{#1}} \begin{array}{cc} \textrm{数组} & \textrm{形状} \\ \hline x & \textrm{[P]} \\ y & \textrm{[P]} \\ A & \textrm{[N P]} \\ B & \textrm{[P M]} \\ \hline \end {array}\]请注意,对于矩阵-矩阵乘法,计算量以立方级(\(O(N^3)\))增长,而数据传输仅以平方级(\(O(N^2)\))增长——这意味着随着我们扩大矩阵乘法的规模,达到计算饱和的极限变得更容易。这是极不寻常的,并且在很大程度上解释了为什么我们使用以矩阵乘法为主的架构——它们易于扩展!
在训练期间,我们并不特别关心给定矩阵乘法的结果;我们真正关心的是它的导数。这意味着我们在反向传播期间会进行更多的FLOPs。
如果我们假设B只是一个更大网络中的一个矩阵,而A是我们的输入激活,其中C = A B,则损失L相对于B的导数由链式法则给出:
\[\frac{\partial L}{\partial B} = \frac{\partial L}{\partial C}\frac{\partial C}{\partial B} = A^T \left(\frac{\partial L}{\partial C}\right)\]这是一个外积,需要
同样是
将这些加起来,我们看到在训练期间,我们总共有6NPM FLOPs,而在推理期间是2NPM:前向传播2NPM,反向传播4NPM。由于PM是矩阵中的参数数量,这是著名的\(6 * \text{参数数量} * \text{token数量}\)近似Transformer训练期间FLOPs的最简单形式:每个token需要\(6 * \text{参数数量}\) FLOPs。我们将在下面展示一个更正确的推导。
Transformer是未来。嗯,至少它们是现在。也许几年前,它们是众多架构之一。但今天,几乎值得了解该架构的每一个细节。我们不会重新介绍该架构,但这篇博客和原始Transformer论文可能会是很有帮助的参考资料。
以下是Transformer解码器架构的基本示意图:
注意 [门控einsum]:上图使用了一种“门控einsum”
注意2 [MHA注意力]:对于自注意力,T和S是相同的,但对于交叉注意力,它们可能不同。对于普通的多头注意力(MHA),N和K是相同的,而对于多查询注意力(MQA)
为了避免到处都写上L因子,下面我们将计算每层的FLOPs。
Transformer的MLP通常由2个输入矩阵乘法(其结果逐元素组合)和1个输出矩阵乘法组成:
\[\begin{array}{ccc} \textrm{操作} & \textrm{训练 FLOPs} & \textrm{参数} \\ \hline \\ A[B,T,\red{D}] \cdot W_{in1}[\red{D}, F] & 6BTDF & DF \\[10pt] A[B,T,\red{D}] \cdot W_{in2}[\red{D}, F] & 6BTDF & DF \\[10pt] \sigma\left(A_{in1}\right)[B,T, F] * A_{in2}[B,T, F] & \gray{O(BTF)} \\[10pt] A[B,T,\red{F}] \cdot W_{out}[\red{F}, D] & 6BTDF & DF \\[10pt] \hline \\ & \approx 18BTDF & 3DF \end{array}\]对于具有不同Q和KV头数的通用分组查询注意力情况,我们假设Q、K、V投影的头维度H相等,并估计QKVO矩阵乘法的成本:
\[\begin{array}{ccc} \textrm{操作} & \textrm{训练 FLOPs} & \textrm{参数} \\ \hline \\ A[B,T,\red{D}] \cdot W_{Q}[\red{D}, N, H] & 6BTDNH & DNH \\[10pt] A[B,T,\red{D}] \cdot W_{K}[\red{D}, K, H] & 6BTDKH & DKH \\[10pt] A[B,T,\red{D}] \cdot W_{V}[\red{D}, K, H] & 6BTDKH & DKH \\[10pt] A[B,T,\red{N}, \red{H}] \cdot W_{O}[\red{N}, \red{H}, D] & 6BTDNH & DNH \\[10pt] \hline \\ & 12BTD(N+K)H & 2D(N+K)H \end{array}\]点积注意力操作更为精细,实际上是在\(B\)、\(K\)维度上批处理的\(TH \cdot HS\)矩阵乘法,一个softmax,以及再次在\(B\)、\(K\)维度上批处理的\(TS \cdot SH\)矩阵乘法。我们用蓝色突出显示批处理维度:
\[\begin{array}{cc} \textrm{操作} & \textrm{训练 FLOPs} \\ \hline \\[3pt] Q[\blue{B}, T, \blue{K}, G, \red{H}] \cdot K[\blue{B}, S, \blue{K}, \red{H}] & 6BTSKGH = 6BTSNH \\[3pt] \textrm{softmax}_S \;\; L[B, T, S, K, G] & \gray{O(BTSKG) = O(BTSN)} \\[3pt] S[\blue{B}, T, \red{S}, \blue{K}, G] \cdot V[\blue{B}, \red{S}, \blue{K}, H] & 6BTSKGH = 6BTSNH \\[3pt] \hline \\ & \approx 12BTSNH = 12BT^2NH \\ \end{array}\]在Transformer中还有其他几种操作。层归一化(Layernorm)相对便宜,在一阶成本估算中可以忽略。还有一个最终的巨大(但不是每层都有)的unembedding矩阵乘法。
\[\begin{array}{ccc} \textsf{操作} & \textsf{训练 FLOPs} & \textsf{参数} \\ \hline \\ \textrm{layernorm}_D \;\; A[B,T,\red{D}] & \gray{O\left(BTD\right)} & \gray{D} \\[10pt] A[B,T,\red{D}] \cdot W_{unembed}[\red{D}, V] & 6BTDV & DV \\ \end{array}\]如果我们忽略短上下文训练中点积注意力的成本,那么所有层的总FLOPs为
\[\begin{align*} (18BTDF + 12BTD(N+K)H)L = 6 *BT * (3DF + 2D(N+K)H)L \\ = 6 * \textrm{token数量} * \textrm{参数数量} \end{align*}\]这就得出了一个著名的经验法则,用于估算密集Transformer的FLOPs数量,忽略了注意力的FLOPs。(Unembedding是另一个简单的矩阵乘法,有
如果我们确实考虑了上面的点积注意力,并假设\(F=4D\),\(D=NH\)(通常如此)和\(N=K\):
\[\small{\frac{\textrm{注意力 FLOPs}}{\textrm{矩阵乘法 FLOPs}} = \frac{12BT^2NH}{18BTDF + 24BTDNH} = \frac{12BT^2D}{4*18 BTD^2 + 24 BTD^2} = \frac{12BT^2D}{96 BTD^2} = \frac{T}{8D}}\]因此,结论是点积注意力的FLOPs仅在训练期间T>8D时才占主导地位。对于D约等于8k,这将是约64K个token。这在某种程度上是合理的,因为它意味着随着MLP大小的增加,注意力的FLOPs变得不那么关键。对于大型模型,注意力的二次成本实际上并不是长上下文训练的巨大障碍。然而,对于较小的模型,例如Gemma-27B,D=4608,这意味着注意力在序列长度约32k时开始占主导地位。Flash Attention也有助于减轻长上下文的成本,我们将在附录A中简要讨论。
我们不能不简要讨论一下专家混合(MoE)模型
与密集模型相比,MoE引入了新的通信,主要是两个AllToAll操作(一个在MoE块之前,一个在之后),用于将token路由到正确的专家,并将它们带回其宿主设备。
反向传播作为一种算法,是用内存换取计算。反向传播不再需要\(O(n_\text{layers}^2)\)的FLOPs,而是需要\(O(n_\text{layers})\)的内存,保存前向传播期间生成的所有中间激活。虽然这比二次计算要好,但在内存方面却非常昂贵:一个拥有\(B * T=4M\)(每批次总共4M个token)、L=64和D=8192的模型,如果避免所有不必要的反向传播计算,将不得不以bfloat16格式保存大约\(2 * 20 * B * T * D * L = 84TB\)的激活。这里的20(大致)来自于计算上图Transformer图中的每个中间节点,因为例如
\[f(x) = \exp(g(x))\] \[\frac{df}{dx} = \exp(g(x)) \cdot \frac{dg}{dx}\]因此为了避免重新计算,我们需要保存前向传播中的\(g(x)\)和\(\exp(g(x))\)。为了避免保存这么多内存,我们可以选择只保存一部分中间激活。以下是我们使用的几种策略。
这绝不是全面的。在使用JAX时,这些通常由jax.remat/jax.checkpoint控制(你可以在这里阅读更多)。
正如我们将在第7节中看到的,LLM推理有两个关键部分,预填充(prefill)和生成(generation)。
每个KV缓存实际上是一个大小为
| 组件 | 每层参数 | 每层训练FLOPs |
|---|---|---|
| MLP | 3DF | 18BTDF |
| 注意力 | 4DNH | 24BTDNH + 12BT2NH |
| 其他 | D | BTD |
| 词汇表 | DV(总计,非每层) | 12BTDV |
问题1: 一个模型,其
512kB / token。问题2: 在{'X': 4, 'Y': 8, 'Z': 4}上执行 A[BX, DY] *D W[DY, F] 需要多少总FLOPs?每个TPU执行多少FLOPs?
该操作的总“理论”FLOPs是 \(2 \cdot B \cdot D \cdot F\)。然而,因为计算没有在Z维度上分片,我们实际上多做了Z倍的FLOPs,意味着总FLOPs为 \(2 \cdot B \cdot D \cdot F \cdot Z\)。由于计算在其他维度上是分片的,每个设备的总FLOPs大约是 \(2 \cdot B \cdot D \cdot F / (X \cdot Y)\)。
问题3: 执行
遵循上述规则,我们有I和J作为收缩维度,K、L、M、N和O作为非收缩维度。我们没有“批处理维度”,所以这只是 \(2 \cdot I \cdot J \cdot K \cdot L \cdot M \cdot N \cdot O\),即所有轴的乘积之和。如果我们有一个共享轴,它只会被计算一次。
问题4: 自注意力的算术强度是多少(忽略Q/K/V/O投影)?以Q和KV长度T和S的函数形式给出答案。 在什么上下文长度下,注意力是FLOPs受限的?给定我们TPU的HBM带宽,绘制随着上下文长度增长,注意力与FFW块的有效相对成本图。
自注意力需要加载\(Q\)、\(K\)和\(V\)激活,然后计算\(\text{softmax}(Q \cdot K) \cdot V\),然后将结果写回HBM。这将使用Flash Attention完成,所以这个数学计算有一些注意事项,但基本上在bf16中,自注意力执行
\[\text{Q[B,T,N,H]} \rightarrow_\text{reshape} \text{Q[B, T, K, G, H]} \cdot \text{K[B, S, K, H]} \rightarrow \text{O[B, T, S, K, G]}\] \[U=\text{softmax}_S(\text{O[B, T, S, K, G]}) \] \[\text{U[B, T, S, K, G]} \cdot \text{V[B, S, K, H]} \rightarrow \text{X[B, T, K, G, H]}\]所以我们的总字节数是 \(2 * \text{sizeof}(Q) + 2 * \text{sizeof(K or V)} = 4BTNH + 4BSKH = 4BHK * (TG + S)\),总FLOPs是 \(4BTSNH + O(BTSN)\),算术强度是 \(4BTSKGH / (4BHK * (TG + S))\)。
所以基本上,在预填充期间,我们有\(S=T\),所以我们的算术强度是\(4BT^2KGH / 4BHKT \cdot (G+1) = TG/(G + 1) = O(T)\)。在生成期间,\(T=1\),所以我们有\(4BSKGH / (4BHK \cdot (G + S)) = SG / (G + S) \rightarrow G\),假设\(S\)非常大。根据你如何解释这个问题,在预填充或训练期间,假设没有序列分片,自注意力在S=240时是计算受限的。在生成期间,我们永远不会是计算受限的,因为\(G\)很小。然而,你可以看到,增加\(G\)会使我们更接近计算受限。
问题5: 在什么序列长度下,自注意力的FLOPs等于QKVO投影的FLOPs?
这纯粹是关于何时\(24BTDNH == 12BT^2NH\)的问题。简化后我们得到\(2D = T\),例如对于\(D=4096\),这是\(8192\)。这告诉我们,对于大多数合理的上下文长度,矩阵乘法的FLOPs更大。
问题6: 假设我们在前向传播过程中只保存Transformer层中7个主要矩阵乘法(Q, K, V, O + 三个FFW矩阵)的输出。在反向传播过程中,我们需要“重新物化”多少额外的FLOPs?
只保存七个矩阵乘法的输出(Q, K, V, O, W₁, W₂, W₃)意味着反向传播必须重新计算两个注意力矩阵乘法
\[QK^{\top} \quad\text{和}\quad \operatorname{softmax}(QK^{\top})V.\]每个都是在\(B\)个序列和\(N\)个头上批处理的\(T \times T\)矩阵乘法,所以额外的FLOPs是
\[4 \; B \, T^{2} \, N \, H.\]所有其他重新计算的操作都只是\(O(BTD)\)。
问题7: DeepSeek v3声称它在14.8T个token上训练了279万H800小时(来源)。鉴于它有37B个激活参数,他们大致达到了什么样的硬件利用率?提示:注意他们使用了没有结构化稀疏性的FP8 FLOPs。
从规格表这里,我们发现带稀疏性的FP8性能为3,026 TFLOPs/s,或者通常不带稀疏性时是这个值的一半(1.513e15 FLOPs/s)。279万H800小时意味着2.79e6 * 1.513e15 * 60 * 60 = 1.52e25总FLOPs。鉴于激活参数数量为37B,这次训练运行应该使用了大约6 * 37e9 * 14.8e12 = 3.3e24 FLOPs。这意味着FLOPs利用率大约是3.3e24 / 1.52e25 = 21.7%。
问题8: 专家混合(MoE)模型有\(E\)个标准密集MLP块的副本,每个token激活其中\(k\)个专家。在TPU v5e上,对于权重为int8的MoE,需要多大的批大小(以token为单位)才能达到计算受限?对于有256个(路由)专家且\(k=8\)的DeepSeek,这个数字是多少?
因为我们有\(E\)个每个专家的副本,在int8中,我们需要加载\(E \cdot D \cdot F\)字节。因为每个token激活\(k\)个专家,我们有\(2\cdot k \cdot B \cdot D \cdot F\) FLOPs。要使用bfloat16 FLOPs达到计算受限,我们需要算术强度超过240,这发生在\((2\cdot k \cdot BDF) / EDF > 240\)或\(k \cdot B / E > 120\)时。
因此,我们需要\(B > 120 \cdot E / k\)才能达到计算受限。对于DeepSeek,这给我们\(B > 120 \cdot 256 / 8 = 3840\)。这在生成时是一个非常大的批大小。
将Transformer扩展到非常长的上下文的传统反对意见是,注意力的FLOPs和内存使用量随上下文长度呈二次方增长。虽然注意力QK乘积的形状确实是
这第二个观察首先由Rabe等人于2021年提出,后来在Flash Attention论文(Dao等人,2022年)中再次提出。基本思想是将K/V分块计算注意力,我们计算局部的softmax和一些辅助统计数据,然后将它们传递给下一个块,后者将其与自己的局部分块结合起来。具体来说,我们计算
有了这些,我们只需恒定的内存量就可以计算出新的最大值、新的运行总和和新的输出。为了粗略地描述这是如何工作的,注意力大致是这个操作:
\[\text{Attn}(Q, K, V) = \sum_i \frac{\exp(Q \cdot K_i - \max_j Q \cdot K_j) V_i}{\sum_l \exp(Q \cdot K_l - \max_j Q \cdot K_j)}\]减去最大值是为了数值稳定性,并且可以在不影响结果的情况下添加,因为\(\sum_i \exp(a_i + b) = \exp(b) \sum \exp(a)\)。只看上面的分母,如果我们想象有两个连续的键向量块,\(K^1\)和\(K^2\),并且我们为每个块计算局部softmax和\(L^1\)和\(L^2\)
\[L^1 = \sum_i \exp(Q \cdot K_i^1 - \max_j Q \cdot K_j^1)\] \[L^2 = \sum_i \exp(Q \cdot K_i^2 - \max_j Q \cdot K_j^2)\]然后我们可以通过使用以下公式将它们组合成这两个块的完整softmax和
\[L^\text{combined} = \exp(M^1 - \max(M^1, M^2)) \cdot L^1 + \exp(M^2 - \max(M^1, M^2)) \cdot L^2\]其中
\[M^1 = \max_j Q \cdot K_j^1 \text{ and } M^2 = \max_j Q \cdot K_j^2\]这也可以对整个softmax进行,为我们提供了一种累积任意大softmax和的方法。以下是Flash Attention论文中的完整算法。
从硬件的角度来看,这使我们能够将Q的块放入VMEM(上述算法称之为片上SRAM),因此我们只需在每次迭代时加载KV块,从而降低了算术强度。我们也可以将运行统计信息保存在VMEM中。
最后一个值得强调的微妙点是注意力softmax的一个属性,它被用来使Flash VJP(反向模式导数)计算在训练中变得可行。如果我们定义一个中间softmax数组为:
\[S_{ij} = \frac{e^{\tau q_i \cdot k_j}}{\sum_k e^{\tau q_i \cdot k_j}}\]在注意力中,我们从反向模式的dO和V数组中获得dS:
\[dS_{ij} = dO_{id} \cdot_d V_{jd} = \sum_d dO_{id} V_{jd}\]在将此梯度反向传播到Q和K的过程中
\[d(q_i \cdot k_j) = (dS_{ij} - S_{ij} \cdot_j dS_{ij}) S_{ij}\]我们利用一个恒等式,它允许我们将沿大键长度维度的收缩与沿特征深度维度的局部收缩进行交换。
\[\begin{align*} S_{ij} \cdot_j dS_{ij} &= \sum_j \frac{e^{\tau q_i \cdot k_j}}{\sum_k e^{\tau q_i \cdot k_k}} \sum_d dO_{id} V_{jd} \\ &= \sum_d dO_{id} \sum_j \frac{e^{\tau q_i \cdot k_j}}{\sum_k e^{\tau q_i \cdot k_k}} V_{jd} \\ &= \sum_d dO_{id} O_{id} \\ &= dO_{id} \cdot_d O_{id} \end{align*}\]这种替换对于能够为VJP实现序列块局部计算至关重要,并使得更智能的分片方案(如环形注意力)成为可能。