微信公众号
《如何扩展你的模型》第6部分 (第5部分:训练 | 第7部分:推理)
让我们仔细研究一下如何利用上一节学到的知识,在 TPU v5p 上训练 LLaMA 3 模型。它们有多大?在不同配置下训练的成本有多高?它们是如何分片的?让我们通过一些粗略的估算,来探讨前面几节的内容如何应用到真实模型上。
本节的目标是将上一节的结论应用到一个非常实际的问题上:训练 LLaMA 3 系列模型。与前面几节不同,我们希望你能亲自动手完成大部分工作。因此,我们隐藏了每个小节的答案,以便你可以先尝试自己解答。试试拿起笔,手动计算一下吧!
LLaMA-3 模型家族
| 超参数 | 值 |
|---|---|
| 80 | |
| 8,192 | |
| 28,672 | |
| 64 | |
| 8 | |
| 128 | |
| 128,256 |
为了凸显找到这些信息有多容易,下面是配置本身及其映射关系:
为许多不同的开源 LLM 制作一个包含这些数字的大表格会很有用,这样你就可以快速比较它们在设计上所做的决策。
问题:根据这个表格,我们能计算出 LLaMA 3-70B 的参数数量吗?🤫 让我们应用第4节的内容,看看能否得到 70B 这个数字!
| 参数 | 公式 | 数量 |
|---|---|---|
| FFW 参数 | d_model * d_ff * 3 (用于 gelu + 输出投影) * n_layers | 8,192 * 8,192 * 3.5 * 3 * 80 = 56.3e9 |
| 词表参数 | 2 (输入和输出嵌入) * n_embeddings * d_model | 2 * 128,256 * 8,192 = 2.1e9 |
| 注意力参数 | n_layers * [ 2 (用于q嵌入和拼接的输出投影) * d_model * n_heads * d_qkv + 2 (用于k和v) * d_model * n_kv_heads * d_qkv] | 80 * (2 * 8,192 * 64 * 128 + 2 * 8,192 * 8 * 128) = 12e9 |
| 56.3e9 + 2.1e9 + 12e9 = 70.4e9 |
太棒了!我们得到了预期的数字。你会注意到,正如预期的那样,FFW 参数在总参数量中占绝对主导地位,尽管注意力机制的参数也不可忽略。
要点:MLP 模块中的3个大型权重矩阵比 Transformer 中的所有其他数组都要大得多,因此在推断模型内存或 FLOPs 时,我们通常几乎可以忽略所有其他参数。对于 LLaMA 3-70B,它们占了70B参数中的56B。
现在我们来看看 FLOPs!记住第4节中关于训练的一般规则。
问题:LLaMA-3 在每个训练步骤中为每个 token 执行多少 FLOPs?这有助于我们确定整个训练过程的成本。
答案:如第4节所示,每个 token 大约需要 6 * 70e9 = 4.2e11 FLOPs / token。这大约是每个 token 每步半个 TFLOP。假设我们受计算限制,且 FLOPs 利用率完美,这在单个 TPU v5p 芯片上大约需要 4.2e11 / 4.59E+14 = 1ms。
问题:LLaMA 3 训练了大约 15 万亿个 token。总共需要多少 FLOPs?
答案:这很简单,就是 4.2e11 * 15e12 = 6.3e24 FLOPs。总计 6.3 yottaFLOPs。这个数字非常大!在单个 TPU 上,这将需要 6.3e24 / 4.59E+14 = 435 年。这也太久了!
问题:假设我们想在一个包含 16x20x28 = 8960 个芯片的完整 TPU v5p pod 上进行训练。假设我们受计算限制,在使用 bfloat16 格式且 MFU 为 40% 的情况下,训练需要多长时间?
答案:我们知道每个 TPU v5p 每秒可以执行 4.59e14 FLOPs。在 40% MFU 的情况下,这大约需要 T = 6.3e24 / (8960 * 4.59e14 * 0.4) = 3.8e6 seconds。这大约是 44 天!这个时间相当合理,前提是我们真的能达到 40% 的 MFU。
问题:LLaMA 3-70B 的预训练批次大小约为 4M token。使用这个批次大小进行训练,我们最少需要多少个 TPU?你可以假设参数为 bfloat16 格式,优化器状态为 float32 格式,并且每层对梯度进行 4 次检查点操作。
答案:这个问题主要是在问内存使用情况,因为这是对可用计算资源的唯一硬性约束。在训练期间,HBM 主要有三个用途:模型参数、优化器状态和梯度检查点。如果我们假设权重为 bfloat16,优化器状态为 float32,并采用一个非常保守的梯度检查点方案(每层 4 次),我们得到:
| 参数 | 2 * 70GB | ~140GB |
| 优化器状态 | 8 * 70GB | ~560GB |
| 梯度检查点 | 2 * 8192 * 4e6 * 4 * 80 | ~20.9TB |
| 总计 | ~21.6TB |
这里的总内存约为 21.6TB。你会注意到,即使采用了非常保守的检查点方案,梯度检查点仍在内存使用中占主导地位。理论上,我们可以每层只设置 1 个检查点,或者使用微批处理,但这已经是一个合理的估算了。基于这些假设,由于每个 TPU v5p 有 96GB 的 HBM,我们需要 21.6e12 / 96e9 = 225 个 TPU。实际上这并不算多!
我们为什么不这样做呢?嗯,因为这会花费我们 44 days * 8960 / 225 = 1752 days 的时间来训练。这差不多是四年。时间太长了。尽管如此,这清楚地表明,我们使用这些大型集群并非因为受到内存限制,而是因为我们需要额外的 FLOPs。
问题:在与上一个问题相同的假设下,如果我们使用 8960 个 TPU v5p 芯片,每个芯片将使用多少内存?
答案:我们的总内存仍然是大约 21.6TB,所以每个芯片大约使用 2.4GB,这基本上可以忽略不计。如果我们采用更激进的检查点策略,例如每层 12 个检查点,每个芯片也只占用 8GB。在这种规模的训练中,我们远未达到内存瓶颈。
要点:从技术上讲,即使在非常小的拓扑结构上训练非常大的模型也是可能的,但需要注意的是,这可能会花费很长时间。能够计算出一次训练运行的总 FLOPs,使我们能够通过假设一个适度的 MFU 和已知的拓扑结构来粗略估计其训练时间。
让我们继续沿用上面的设定,假设我们想在一个包含 8960 个芯片的 TPU v5p pod 上,以 4M token 的批次大小(每批次 1024 个序列,每个序列长度为 4096)来训练 LLaMA 3-70B。让我们来讨论一下针对这个模型的最佳分片策略。
问题:在上述假设下,我们能否仅使用 FSDP 来训练我们的模型?首先,我们假设不能进行任何序列/上下文并行。这应该是你首先想到的方案,因为它很简单,并且如果可行的话,不会引入额外的通信开销。
答案:这个答案可能有点迂腐。如上所述,LLaMA 3-70B 最初是用长度为 4K 的序列进行训练的,所以 4M token 的批次大小给了我们 1024 的序列批次大小。这意味着我们最多只能在 1024 个芯片上进行纯数据并行/FSDP,因为我们只有这么多序列可以用来做数据并行。所以,从“完全数据并行且无额外通信”这个简单意义上来说,答案是否定的。下一个问题将回答一个不那么迂腐的版本。
问题:让我们放宽不进行任何序列分片的要求。如果我们允许自己同时在批次和序列维度上进行 FSDP,我们能否仅用 FSDP 在 8960 个芯片上训练 LLaMA 3-70B?
答案:现在我们允许自己也进行序列/上下文并行,这样就可以扩展到更大的规模。首先,让我们计算每个设备的批次大小。如果我们进行 8960 路 FSDP,每个 TPU 的批次大小最终为 4 * 1024 * 1024 / 8960 = 468 tokens。从上一节我们知道,当
问题:现在让我们看看混合张量并行和 FSDP。是否存在某种组合,能让我们保持受计算限制?如果存在,我们应该使用多大程度的 FSDP 和张量并行?
答案:首先,让我们检查一下这是否可行。我们知道,如果每个芯片的批次大小小于
四舍五入到一个合理的 2 的倍数,这给了我们大约 2048 路 FSDP 和 4 路模型并行。这应该能很好地工作!
要点:我们可以在一个完整的 TPU v5p pod 上,通过混合使用数据并行(1024 路)、序列并行(2 路)和张量并行(4 路)的方式,以 4M token 的批次大小训练 LLaMA-3,而不会受通信限制。如果我们尝试使用纯 FSDP 或 FSDP + 序列并行,则会受通信限制。我们在上一节中推导出的方程非常实用。
问题1 [将 LLaMA 70B 扩展到更多芯片]:假设我们想在 4 个 pod 上以相同的批次大小训练 LLaMA 3-70B。我们会使用哪种并行方案?我们会受计算限制还是通信限制?训练大约需要多长时间?请确保使用正确的Roofline边界。
问题2 [LLaMA 405B]:
(a) 使用 LLaMA 3-405B 的配置,像上面一样写一个包含所有关键超参数的表格。这个模型总共有多少参数?每个训练步骤需要多少 FLOPs?如果我们用 15T token 进行训练,总共会执行多少 FLOPs?
(b) 假设我们想在 8 个 TPU v5p pod 上进行训练。我们会使用哪种并行方案?训练需要多长时间?会受计算限制还是通信限制?