🔗 英文原文: https://jax-ml.github.io/scaling-book/applied-training/
✍️ 翻译: 北极的树
微信二维码 微信公众号

在 TPU 上训练 LLaMA 3

《如何扩展你的模型》第6部分 (第5部分:训练 | 第7部分:推理)

让我们仔细研究一下如何利用上一节学到的知识,在 TPU v5p 上训练 LLaMA 3 模型。它们有多大?在不同配置下训练的成本有多高?它们是如何分片的?让我们通过一些粗略的估算,来探讨前面几节的内容如何应用到真实模型上。

本节的目标是将上一节的结论应用到一个非常实际的问题上:训练 LLaMA 3 系列模型。与前面几节不同,我们希望你能亲自动手完成大部分工作。因此,我们隐藏了每个小节的答案,以便你可以先尝试自己解答。试试拿起笔,手动计算一下吧!

LLaMA 3 的架构是怎样的?

LLaMA-3 模型家族包含3个主要模型:LLaMA 3 8B、70B 和 405B。我们将主要关注 70B 模型,并将 8B 和 405B 模型留到最后的练习题部分供你探索。以下是 LLaMA 3-70B 的架构,摘自 LLaMA 的 HuggingFace 页面

超参数
nlayers (L) 80
dmodel (D) 8,192
dff (F) 28,672
nheads (N) 64
nkv_heads (K) 8
dqkv (H) 128
nembeddings (V) 128,256

为了凸显找到这些信息有多容易,下面是配置本身及其映射关系:

为许多不同的开源 LLM 制作一个包含这些数字的大表格会很有用,这样你就可以快速比较它们在设计上所做的决策。

计算参数量和 FLOPs

问题:根据这个表格,我们能计算出 LLaMA 3-70B 的参数数量吗?🤫 让我们应用第4节的内容,看看能否得到 70B 这个数字!

参数 公式 数量
FFW 参数 d_model * d_ff * 3 (用于 gelu + 输出投影) * n_layers 8,192 * 8,192 * 3.5 * 3 * 80 = 56.3e9
词表参数 2 (输入和输出嵌入) * n_embeddings * d_model 2 * 128,256 * 8,192 = 2.1e9
注意力参数 n_layers * [ 2 (用于q嵌入和拼接的输出投影) * d_model * n_heads * d_qkv + 2 (用于k和v) * d_model * n_kv_heads * d_qkv] 80 * (2 * 8,192 * 64 * 128 + 2 * 8,192 * 8 * 128) = 12e9
56.3e9 + 2.1e9 + 12e9 = 70.4e9

太棒了!我们得到了预期的数字。你会注意到,正如预期的那样,FFW 参数在总参数量中占绝对主导地位,尽管注意力机制的参数也不可忽略。

要点:MLP 模块中的3个大型权重矩阵比 Transformer 中的所有其他数组都要大得多,因此在推断模型内存或 FLOPs 时,我们通常几乎可以忽略所有其他参数。对于 LLaMA 3-70B,它们占了70B参数中的56B。

现在我们来看看 FLOPs!记住第4节中关于训练的一般规则。

问题:LLaMA-3 在每个训练步骤中为每个 token 执行多少 FLOPs?这有助于我们确定整个训练过程的成本。

思考后,点击此处查看答案!

答案:如第4节所示,每个 token 大约需要 6param count FLOPs,所以这里大约是 6 * 70e9 = 4.2e11 FLOPs / token。这大约是每个 token 每步半个 TFLOP。假设我们受计算限制,且 FLOPs 利用率完美,这在单个 TPU v5p 芯片上大约需要 4.2e11 / 4.59E+14 = 1ms

问题:LLaMA 3 训练了大约 15 万亿个 token。总共需要多少 FLOPs?

思考后,点击此处查看答案!

答案:这很简单,就是 4.2e11 * 15e12 = 6.3e24 FLOPs。总计 6.3 yottaFLOPs。这个数字非常大!在单个 TPU 上,这将需要 6.3e24 / 4.59E+14 = 435 年。这也太久了!

问题:假设我们想在一个包含 16x20x28 = 8960 个芯片的完整 TPU v5p pod 上进行训练。假设我们受计算限制,在使用 bfloat16 格式且 MFU 为 40% 的情况下,训练需要多长时间?

思考后,点击此处查看答案!

答案:我们知道每个 TPU v5p 每秒可以执行 4.59e14 FLOPs。在 40% MFU 的情况下,这大约需要 T = 6.3e24 / (8960 * 4.59e14 * 0.4) = 3.8e6 seconds这大约是 44 天!这个时间相当合理,前提是我们真的能达到 40% 的 MFU。

问题:LLaMA 3-70B 的预训练批次大小约为 4M token。使用这个批次大小进行训练,我们最少需要多少个 TPU?你可以假设参数为 bfloat16 格式,优化器状态为 float32 格式,并且每层对梯度进行 4 次检查点操作。

思考后,点击此处查看答案!

答案:这个问题主要是在问内存使用情况,因为这是对可用计算资源的唯一硬性约束。在训练期间,HBM 主要有三个用途:模型参数、优化器状态和梯度检查点。如果我们假设权重为 bfloat16,优化器状态为 float32,并采用一个非常保守的梯度检查点方案(每层 4 次),我们得到:

参数 2 * 70GB ~140GB
优化器状态 8 * 70GB ~560GB
梯度检查点 2 * 8192 * 4e6 * 4 * 80 ~20.9TB
总计 ~21.6TB

这里的总内存约为 21.6TB。你会注意到,即使采用了非常保守的检查点方案,梯度检查点仍在内存使用中占主导地位。理论上,我们可以每层只设置 1 个检查点,或者使用微批处理,但这已经是一个合理的估算了。基于这些假设,由于每个 TPU v5p 有 96GB 的 HBM,我们需要 21.6e12 / 96e9 = 225 个 TPU。实际上这并不算多!

我们为什么不这样做呢?嗯,因为这会花费我们 44 days * 8960 / 225 = 1752 days 的时间来训练。这差不多是四年。时间太长了。尽管如此,这清楚地表明,我们使用这些大型集群并非因为受到内存限制,而是因为我们需要额外的 FLOPs。

问题:在与上一个问题相同的假设下,如果我们使用 8960 个 TPU v5p 芯片,每个芯片将使用多少内存?

思考后,点击此处查看答案!

答案:我们的总内存仍然是大约 21.6TB,所以每个芯片大约使用 2.4GB,这基本上可以忽略不计。如果我们采用更激进的检查点策略,例如每层 12 个检查点,每个芯片也只占用 8GB。在这种规模的训练中,我们远未达到内存瓶颈。

要点:从技术上讲,即使在非常小的拓扑结构上训练非常大的模型也是可能的,但需要注意的是,这可能会花费很长时间。能够计算出一次训练运行的总 FLOPs,使我们能够通过假设一个适度的 MFU 和已知的拓扑结构来粗略估计其训练时间。

如何为 LLaMA 3-70B 的训练进行分片

让我们继续沿用上面的设定,假设我们想在一个包含 8960 个芯片的 TPU v5p pod 上,以 4M token 的批次大小(每批次 1024 个序列,每个序列长度为 4096)来训练 LLaMA 3-70B。让我们来讨论一下针对这个模型的最佳分片策略。

问题:在上述假设下,我们能否仅使用 FSDP 来训练我们的模型?首先,我们假设不能进行任何序列/上下文并行。这应该是你首先想到的方案,因为它很简单,并且如果可行的话,不会引入额外的通信开销。

思考后,点击此处查看答案!

答案:这个答案可能有点迂腐。如上所述,LLaMA 3-70B 最初是用长度为 4K 的序列进行训练的,所以 4M token 的批次大小给了我们 1024 的序列批次大小。这意味着我们最多只能在 1024 个芯片上进行纯数据并行/FSDP,因为我们只有这么多序列可以用来做数据并行。所以,从“完全数据并行且无额外通信”这个简单意义上来说,答案是否定的。下一个问题将回答一个不那么迂腐的版本。

问题:让我们放宽不进行任何序列分片的要求。如果我们允许自己同时在批次序列维度上进行 FSDP,我们能否仅用 FSDP 在 8960 个芯片上训练 LLaMA 3-70B?

思考后,点击此处查看答案!

答案:现在我们允许自己也进行序列/上下文并行,这样就可以扩展到更大的规模。首先,让我们计算每个设备的批次大小。如果我们进行 8960 路 FSDP,每个 TPU 的批次大小最终为 4 * 1024 * 1024 / 8960 = 468 tokens。从上一节我们知道,当 per device batch size<2550/MX 时,FSDP 会受到 ICI 的限制。由于我们可以在一个完整的 3D pod 中使用 3 个轴,这将给我们一个 850 的下限,而我们远低于这个值。所以答案是否定的,即使有 3 个轴也不行。我们将完全受限于通信。

问题:现在让我们看看混合张量并行和 FSDP。是否存在某种组合,能让我们保持受计算限制?如果存在,我们应该使用多大程度的 FSDP 和张量并行?

思考后,点击此处查看答案!

答案:首先,让我们检查一下这是否可行。我们知道,如果每个芯片的批次大小小于 2550^2 / 2F = 113,我们就会受通信限制。正如我们上面看到的,我们的值略高于此。这太好了!现在,为了选择最佳的 FSDP 数量,我们可以使用公式

Xopt=2BNF=24.19e6896028672=1618

四舍五入到一个合理的 2 的倍数,这给了我们大约 2048 路 FSDP 和 4 路模型并行。这应该能很好地工作!

要点:我们可以在一个完整的 TPU v5p pod 上,通过混合使用数据并行(1024 路)、序列并行(2 路)和张量并行(4 路)的方式,以 4M token 的批次大小训练 LLaMA-3,而不会受通信限制。如果我们尝试使用纯 FSDP 或 FSDP + 序列并行,则会受通信限制。我们在上一节中推导出的方程非常实用。

练习题

问题1 [将 LLaMA 70B 扩展到更多芯片]:假设我们想在 4 个 pod 上以相同的批次大小训练 LLaMA 3-70B。我们会使用哪种并行方案?我们会受计算限制还是通信限制?训练大约需要多长时间?请确保使用正确的Roofline边界。

问题2 [LLaMA 405B]:

(a) 使用 LLaMA 3-405B 的配置,像上面一样写一个包含所有关键超参数的表格。这个模型总共有多少参数?每个训练步骤需要多少 FLOPs?如果我们用 15T token 进行训练,总共会执行多少 FLOPs?

(b) 假设我们想在 8 个 TPU v5p pod 上进行训练。我们会使用哪种并行方案?训练需要多长时间?会受计算限制还是通信限制?

第6节到此结束。关于第7节 Transformer 推理的内容,请点击这里

脚注

    参考文献

    1. The Llama 3 herd of models
      Grattafiori, A., Dubey, A., Jauhri, A., Pandey, A., Kadian, A., Al-Dahle, A., Letman, A., Mathur, A., Schelten, A., Vaughan, A., Yang, A., Fan, A., Goyal, A., Hartshorn, A., Yang, A., Mitra, A., Sravankumar, A., Korenev, A., Hinsvark, A., Rao, A., Zhang, A., Rodriguez, A., Gregerson, A., Spataru, A., Roziere, B., Biron, B., Tang, B., Chern, B., Caucheteux, C., Nayak, C., Bi, C., Marra, C., McConnell, C., Keller, C., Touret, C., Wu, C., Wong, C., Ferrer, C.C., Nikolaidis, C., Allonsius, D., Song, D., Pintz, D., Livshits, D., Wyatt, D., Esiobu, D., Choudhary, D., Mahajan, D., Garcia-Olano, D., Perino, D., Hupkes, D., Lakomkin, E., AlBadawy, E., Lobanova, E., Dinan, E., Smith, E.M., Radenovic, F., Guzmán, F., Zhang, F., Synnaeve, G., Lee, G., Anderson, G.L., Thattai, G., Nail, G., Mialon, G., Pang, G., Cucurell, G., Nguyen, H., Korevaar, H., Xu, H., Touvron, H., Zarov, I., Ibarra, I.A., Kloumann, I., Misra, I., Evtimov, I., Zhang, J., Copet, J., Lee, J., Geffert, J., Vranes, J., Park, J., Mahadeokar, J., Shah, J., van der Linde, J., Billock, J., Hong, J., Lee, J., Fu, J., Chi, J., Huang, J., Liu, J., Wang, J., Yu, J., Bitton, J., Spisak, J., Park, J., Rocca, J., Johnstun, J., Saxe, J., Jia, J., Alwala, K.V., Prasad, K., Upasani, K., Plawiak, K., Li, K., Heafield, K., Stone, K., El-Arini, K., Iyer, K., Malik, K., Chiu, K., Bhalla, K., Lakhotia, K., Rantala-Yeary, L., van der Maaten, L., Chen, L., Tan, L., Jenkins, L., Martin, L., Madaan, L., Malo, L., Blecher, L., Landzaat, L., de Oliveira, L., Muzzi, M., Pasupuleti, M., Singh, M., Paluri, M., Kardas, M., Tsimpoukelli, M., Oldham, M., Rita, M., Pavlova, M., Kambadur, M., Lewis, M., Si, M., Singh, M.K., Hassan, M., Goyal, N., Torabi, N., Bashlykov, N., Bogoychev, N., Chatterji, N., Zhang, N., Duchenne, O., Çelebi, O., Alrassy, P., Zhang, P., Li, P., Vasic, P., Weng, P., Bhargava, P., Dubal, P., Krishnan, P., Koura, P.S., Xu, P., He, Q., Dong, Q., Srinivasan, R., Ganapathy, R., Calderer, R., Cabral, R.S., Stojnic, R., Raileanu, R., Maheswari, R., Girdhar, R., Patel, R., Sauvestre, R., Polidoro, R., Sumbaly, R., Taylor, R., Silva, R., Hou, R., Wang, R., Hosseini, S., Chennabasappa, S., Singh, S., Bell, S., Kim, S.S., Edunov, S., Nie, S., Narang, S., Raparthy, S., Shen, S., Wan, S., Bhosale, S., Zhang, S., Vandenhende, S., Batra, S., Whitman, S., Sootla, S., Collot, S., Gururangan, S., Borodinsky, S., Herman, T., Fowler, T., Sheasha, T., Georgiou, T., Scialom, T., Speckbacher, T., Mihaylov, T., Xiao, T., Karn, U., Goswami, V., Gupta, V., Ramanathan, V., Kerkez, V., Gonguet, V., Do, V., Vogeti, V., Albiero, V., Petrovic, V., Chu, W., Xiong, W., Fu, W., Meers, W., Martinet, X., Wang, X., Wang, X., Tan, X.E., Xia, X., Xie, X., Jia, X., Wang, X., Goldschlag, Y., Gaur, Y., Babaei, Y., Wen, Y., Song, Y., Zhang, Y., Li, Y., Mao, Y., Coudert, Z.D., Yan, Z., Chen, Z., Papakipos, Z., Singh, A., Srivastava, A., Jain, A., Kelsey, A., Shajnfeld, A., Gangidi, A., Victoria, A., Goldstand, A., Menon, A., Sharma, A., Boesenberg, A., Baevski, A., Feinstein, A., Kallet, A., Sangani, A., Teo, A., Yunus, A., Lupu, A., Alvarado, A., Caples, A., Gu, A., Ho, A., Poulton, A., Ryan, A., Ramchandani, A., Dong, A., Franco, A., Goyal, A., Saraf, A., Chowdhury, A., Gabriel, A., Bharambe, A., Eisenman, A., Yazdan, A., James, B., Maurer, B., Leonhardi, B., Huang, B., Loyd, B., De Paola, B., Paranjape, B., Liu, B., Wu, B., Ni, B., Hancock, B., Wasti, B., Spence, B., Stojkovic, B., Gamido, B., Montalvo, B., Parker, C., Burton, C., Mejia, C., Liu, C., Wang, C., Kim, C., Zhou, C., Hu, C., Chu, C., Cai, C., Tindal, C., Feichtenhofer, C., Gao, C., Civin, D., Beaty, D., Kreymer, D., Li, D., Adkins, D., Xu, D., Testuggine, D., David, D., Parikh, D., Liskovich, D., Foss, D., Wang, D., Le, D., Holland, D., Dowling, E., Jamil, E., Montgomery, E., Presani, E., Hahn, E., Wood, E., Le, E., Brinkman, E., Arcaute, E., Dunbar, E., Smothers, E., Sun, F., Kreuk, F., Tian, F., Kokkinos, F., Ozgenel, F., Caggioni, F., Kanayet, F., Seide, F., Florez, G.M., Schwarz, G., Badeer, G., Swee, G., Halpern, G., Herman, G., Sizov, G., {Guangyi},, {Zhang},, Lakshminarayanan, G., Inan, H., Shojanazeri, H., Zou, H., Wang, H., Zha, H., Habeeb, H., Rudolph, H., Suk, H., Aspegren, H., Goldman, H., Zhan, H., Damlaj, I., Molybog, I., Tufanov, I., Leontiadis, I., Veliche, I., Gat, I., Weissman, J., Geboski, J., Kohli, J., Lam, J., Asher, J., Gaya, J., Marcus, J., Tang, J., Chan, J., Zhen, J., Reizenstein, J., Teboul, J., Zhong, J., Jin, J., Yang, J., Cummings, J., Carvill, J., Shepard, J., McPhie, J., Torres, J., Ginsburg, J., Wang, J., Wu, K., U, K.H., Saxena, K., Khandelwal, K., Zand, K., Matosich, K., Veeraraghavan, K., Michelena, K., Li, K., Jagadeesh, K., Huang, K., Chawla, K., Huang, K., Chen, L., Garg, L., A, L., Silva, L., Bell, L., Zhang, L., Guo, L., Yu, L., Moshkovich, L., Wehrstedt, L., Khabsa, M., Avalani, M., Bhatt, M., Mankus, M., Hasson, M., Lennie, M., Reso, M., Groshev, M., Naumov, M., Lathi, M., Keneally, M., Liu, M., Seltzer, M.L., Valko, M., Restrepo, M., Patel, M., Vyatskov, M., Samvelyan, M., Clark, M., Macey, M., Wang, M., Hermoso, M.J., Metanat, M., Rastegari, M., Bansal, M., Santhanam, N., Parks, N., White, N., Bawa, N., Singhal, N., Egebo, N., Usunier, N., Mehta, N., Laptev, N.P., Dong, N., Cheng, N., Chernoguz, O., Hart, O., Salpekar, O., Kalinli, O., Kent, P., Parekh, P., Saab, P., Balaji, P., Rittner, P., Bontrager, P., Roux, P., Dollar, P., Zvyagina, P., Ratanchandani, P., Yuvraj, P., Liang, Q., Alao, R., Rodriguez, R., Ayub, R., Murthy, R., Nayani, R., Mitra, R., Parthasarathy, R., Li, R., Hogan, R., Battey, R., Wang, R., Howes, R., Rinott, R., Mehta, S., Siby, S., Bondu, S.J., Datta, S., Chugh, S., Hunt, S., Dhillon, S., Sidorov, S., Pan, S., Mahajan, S., Verma, S., Yamamoto, S., Ramaswamy, S., Lindsay, S., Lindsay, S., Feng, S., Lin, S., Zha, S.C., Patil, S., Shankar, S., Zhang, S., Zhang, S., Wang, S., Agarwal, S., Sajuyigbe, S., Chintala, S., Max, S., Chen, S., Kehoe, S., Satterfield, S., Govindaprasad, S., Gupta, S., Deng, S., Cho, S., Virk, S., Subramanian, S., Choudhury, S., Goldman, S., Remez, T., Glaser, T., Best, T., Koehler, T., Robinson, T., Li, T., Zhang, T., Matthews, T., Chou, T., Shaked, T., Vontimitta, V., Ajayi, V., Montanez, V., Mohan, V., Kumar, V.S., Mangla, V., Ionescu, V., Poenaru, V., Mihailescu, V.T., Ivanov, V., Li, W., Wang, W., Jiang, W., Bouaziz, W., Constable, W., Tang, X., Wu, X., Wang, X., Wu, X., Gao, X., Kleinman, Y., Chen, Y., Hu, Y., Jia, Y., Qi, Y., Li, Y., Zhang, Y., Zhang, Y., Adi, Y., Nam, Y., {Yu},, {Wang},, Zhao, Y., Hao, Y., Qian, Y., Li, Y., He, Y., Rait, Z., DeVito, Z., Rosnbrick, Z., Wen, Z., Yang, Z., Zhao, Z. and Ma, Z., 2024. arXiv [cs.AI].

    杂项

    *在 Google DeepMind 完成的工作,现就职于 MatX。

    引用

    在学术场合引用,请按如下格式:

        Austin et al., "How to Scale Your Model", Google DeepMind, online, 2025.
    

    或使用以下 BibTeX 条目:

        @article{scaling-book,
          title = {How to Scale Your Model},
          author = {Austin, Jacob and Douglas, Sholto and Frostig, Roy and Levskaya, Anselm and Chen, Charlie and Vikram, Sharad
          and Lebron, Federico and Choy, Peter and Ramasesh, Vinay and Webson, Albert and Pope, Reiner},
          publisher = {Google DeepMind},
          howpublished = {Online},
          note = {Retrieved from https://jax-ml.github.io/scaling-book/},
          year = {2025}
        }