微信公众号
TPU上LLM的系统视图 (第0部分:简介 | 第1部分:Roofline模型)
训练LLM通常感觉像炼金术,但理解和优化你的模型性能并非如此。本书旨在揭开语言模型扩展科学的神秘面纱:TPU(和GPU)是如何工作的,它们之间如何通信,LLM如何在真实硬件上运行,以及如何在训练和推理过程中并行化你的模型,使其能在大规模下高效运行。如果你曾想过‘训练这个LLM应该有多昂贵’或‘我需要多少内存来自己部署这个模型’或‘什么是AllGather’,我们希望这本书对你有所帮助。
深度学习的许多方面仍然可以归结为一种黑魔法,但优化你的模型性能并非如此——即使是在巨大规模下!相对简单的原则无处不在——从处理单个加速器到数万个加速器——理解这些原则可以让你做许多有用的事情:
预期背景:我们假设你对LLM和Transformer架构有基本的了解,但不一定了解它们如何大规模运行。你应该了解LLM训练的基础知识,最好对JAX有一些基本熟悉。一些有用的背景阅读可能包括关于Transformer架构的这篇博客文章和原始的Transformer论文。还可以查看这个列表以获取更多有用的同步和未来阅读材料。
目标与反馈:读完本书后,你应该能够轻松地为给定硬件平台上的Transformer模型估算出最佳的并行性方案,以及大致的训练和推理时间。如果你做不到,请给我们发邮件或留言!我们很想知道如何能把内容讲得更清楚。
你可能也会喜欢阅读关于NVIDIA GPU的新的第12节!
三四年前,我认为大多数机器学习研究人员不需要理解这本书中的任何内容。但如今,即使是“小”模型也运行得非常接近硬件极限,以至于进行创新性研究需要你考虑大规模下的效率。
“模型扩展”的目标是能够增加用于训练或推理的芯片数量,同时实现吞吐量的成比例线性增长。这被称为“强扩展”。尽管增加额外的芯片(“并行性”)通常会减少计算时间,但它也会带来芯片间通信增加的成本。当通信时间超过计算时间时,我们就会变得“受通信限制”,无法实现强扩展。
我们在这本书中的目标是解释TPU(和GPU)硬件是如何工作的,以及Transformer架构是如何演变以在当前硬件上表现良好的。我们希望这对于设计新架构的研究人员和致力于让当前一代LLM快速运行的工程师都有用。
本书的总体结构如下:
第1节解释了Roofline分析以及哪些因素会限制我们的扩展能力(通信、计算和内存)。第2节和第3节详细讨论了TPU的工作原理,既包括作为单个芯片,也包括——至关重要的——作为一个具有有限带宽和延迟的互连芯片链接的系统。我们将回答以下问题:
五年前,机器学习的架构景观丰富多彩——ConvNets、LSTMs、MLPs、Transformers——但现在我们主要只有Transformer
第5节:训练和第7节:推理是本文的核心,我们在这里讨论一个根本问题:给定一个特定大小的模型和一定数量的芯片,我该如何并行化我的模型以保持在“强扩展”状态?这是一个简单的问题,但答案却出人意料地复杂。从高层次来看,有4种主要的并行性技术用于在多个芯片上拆分模型(数据、张量、流水线和专家),以及一些其他技术来减少内存需求(重计算、优化器/模型分片(又名ZeRO)、主机卸载、梯度累积)。我们在这里讨论其中的许多技术。
我们希望在这些章节结束时,你应该能够为新的架构或设置自己选择合适的并行方案。第6节和第8节是将这些概念应用于流行的开源模型LLaMA-3的实践教程。
最后,第9节和第10节探讨了如何在JAX中实现其中一些想法,以及在出现问题时如何分析和调试代码。第12节是一个新章节,也深入探讨了GPU。
在整个过程中,我们尝试给你提供一些问题让你自己解决。请不要有压力去阅读所有章节或按顺序阅读。也请留下反馈。目前,这是一个草稿,并将继续修订。谢谢!
我们想感谢James Bradbury和Blake Hechtman,他们推导出了本文档中的许多想法。
这个系列可能比它需要的要长,但我们希望这不会让你望而却步。前三章是预备知识,如果熟悉可以跳过,尽管它们介绍了后面使用的符号。最后三个部分可能是最实用的,因为它们解释了如何处理真实模型。
第1部分:预备知识
第1章:Roofline分析简介。算法受到三件事的限制:计算、通信和内存。我们可以用这些来近似我们的算法将运行多快。
第2章:如何理解TPU。TPU是如何工作的?这如何影响我们可以训练和服务的模型?
第3章:分片矩阵及其乘法。在这里,我们通过我们最喜欢的操作:(分片)矩阵乘法来解释模型分片和多TPU并行性。
第2部分:Transformer
第4章:你需要知道的所有Transformer数学知识。一个Transformer在其前向和后向传播中使用了多少FLOPs?你能计算出参数数量吗?它的KV缓存大小?我们在这里详细讲解这些数学计算。
第5章:如何为训练并行化Transformer。FSDP。Megatron分片。流水线并行性。给定一定数量的芯片,我如何以尽可能高效的方式训练一个给定大小和给定批量大小的模型?
第6章:在TPU上训练LLaMA 3。我们如何在TPU上训练LLaMA 3?需要多长时间?成本是多少?
第7章:关于Transformer推理的一切。一旦我们训练好一个模型,我们就必须部署它。推理增加了一个新的考虑因素——延迟——并改变了内存的格局。我们将讨论解耦服务是如何工作的以及如何考虑KV缓存。
第8章:在TPU上部署LLaMA 3。在TPU v5e上部署LLaMA 3的成本是多少?延迟/吞吐量的权衡是什么?
第3部分:实践教程
第9章:如何分析TPU代码。真实的LLM从不像上面的理论那么简单。在这里,我们解释JAX + XLA技术栈以及如何使用JAX/TensorBoard分析器来调试和修复实际问题。
第10章:用JAX编程TPU。JAX提供了一系列神奇的API用于并行化计算,但你需要知道如何使用它们。包含有趣的例子和已解决的问题。
第4部分:结论和附加内容
第11章:结论和进一步阅读。关于TPU和LLM的总结思考和进一步阅读材料。
第12章:如何理解GPU。一个关于GPU的附加章节,介绍它们如何工作,如何联网,以及它们的Roofline与TPU有何不同。