🔗 英文原文: https://jax-ml.github.io/scaling-book/applied-inference/
✍️ 翻译: 北极的树
微信二维码 微信公众号

在 TPU 上服务 LLaMA 3-70B

第 8 部分,如何扩展你的模型 (第 7 部分:推理 | 第 9 部分:性能分析)

让我们深入探讨如何在 TPU v5e 上服务 LLaMA 3-70B 模型。在Roofline模型下,服务不同模型的成本有多高?它们的 KV 缓存有多大?我们应该使用多大的批量大小?在推理过程中,参数和激活值是如何分片的?让我们通过一些粗略估算,来分析生产环境中的延迟和吞吐量。

本节将探讨服务 LLaMA-3 需要什么,以及如何高效地完成。与之前的“应用”部分一样,请在查看答案之前,尝试用纸笔自己算出答案!

LLaMA 服务详解

让我们回顾一下 LLaMA 3-70B 的结构(参考第 6 节):

超参数
nlayers (L) 80
dmodel (D) 8,192
dff (F) 28,672
nheads (N) 64
nkv heads (K) 8
dqkv (H) 128
nembeddings (V) 128,256

让我们从一个简单的问题开始:我们应该在什么硬件上进行服务? 答案基本上是,选择 FLOPs/美元 性价比最高的硬件。这并非总是正确,有时更高的 HBM 或 ICI 带宽比 FLOPs 更关键,但这不失为一个好的启发式方法。 因此,我们通常希望在 TPU v5e 上进行服务,这是我们目前专用的推理芯片(成本数据截至 2025 年 2 月,来源于 Google Cloud 定价):

TPU 类型 bfloat16 FLOPs/秒 Google Cloud 美元/小时 FLOPs / $
H100 9.9e14 $10.8 3.3e17
v5p 4.59e14 $4.2 3.9e17
v5e 1.97e14 $1.2 5.8e17

每个 TPU v5e 拥有 16GB 的 HBM,这要求我们对模型进行相当积极的分片。让我们从思考一些可能对我们很重要的基本量开始:

问题: LLaMA 3-70B 每个 token 的 KV 缓存有多大?你可以假设我们用 int8 存储它们。这决定了在给定拓扑上我们的批量大小能有多大。

想清楚后点击这里!

LLaMA 3-70B 有 8 个 KV 头,所以每个 token 的大小是 2 * K * H * L = 2 * 8 * 128 * 80 = 160kB

注意这个大小! 如果我们的序列长度是 32k 个 token(这很常见),那么每个序列会使用 162e3 * 32,768 = 5.3GB / sequence。对于 BS=240,这就是 1.3TB!由于每个 TPU v5e 只有 16GB 内存,我们大约需要 (70e9 + 1.3e12) / 16e9 = 86 个 TPU v5e 芯片才能容纳这么多内存。另外请注意,这个大小与 70GB 的模型参数相比是多么巨大。

问题: 假设我们希望以批量大小 32 和序列长度 8192 来服务 L3 70B,并且所有内容(参数和 KV)都使用 int8。这将使用多少总内存?我们可以用多小的切片来服务它?

答案

由于我们的 KV 缓存使用 int8 存储,大小为 160e3 字节,所以总 KV 内存为 160e3 * 8192 * 32 = 41.9e9 字节。我们的参数大小为 70e9 字节,因为每个参数占 1 字节。因此,总内存使用量为 41.9e9 + 70e9 = 112GB

我们能使用的最小切片将有 112e9 / 16e9 = 7 个 TPU,或者(向上取整到偶数大小)一个 TPU v5e 4x2。这会非常紧张,考虑到其他开销,我们可能无法完全装下,所以我们可能至少需要一个 4x4(或者减小批量大小)。

问题: 在这个批量大小和量化下,在一个 TPU v5e 4x2 上,我们预计每个解码步骤的延迟大约是多少?吞吐量(tokens / sec / chip)是多少?如果换成 4x4 呢?假设我们用 bfloat16 执行 FLOPs,并且所有东西都完全分片。

答案

我们可以调用上一节的公式

Theoretical Step Time (General)=Batch Size×KV Cache SizeTotal Memory BandwidthAttention (always bandwidth-bound)+max(2×Batch Size×Parameter CountTotal FLOPs/s,Parameter SizeTotal Memory Bandwidth)MLP (can be compute-bound)

这里我们的临界批量大小约为 120,因为我们的参数是 int8,但 FLOPs 是 bfloat16。我们也可以手动计算右侧的最大值,但这基本上是我们已经做过好几次的计算了。所以,我们的矩阵乘法和 FLOPs 都已深陷内存带宽限制的区域。

严格从内存带宽来看,我们的步长时间基本上是 (KV size + param size) / (8 * HBM bandwidth) = 112e9 / (8 * 8.1e11) = 17ms所以理论上我们的步长时间约为 17ms。 我们的吞吐量将是 32 / .017 = 1882 tokens / sec,或者 1882 / 8 = 235 tokens / sec / chip

这里有一个需要注意的地方,就是检查我们的矩阵乘法是否会受到 ICI 限制。我们可以在这里分配 2 个轴,所以理论上当 Y > 2 * F / 2200 = 2 * 28672 / 2200 = 26 时我们会受到 ICI 限制,所以我们是安全的!

如果我们在 4x4 上运行,ICI 方面仍然没问题,所以我们的延迟会降至 17 / 2 = 8.5ms,但每芯片的吞吐量将保持不变。

思考吞吐量

让我们花点时间纯粹地思考吞吐量。当我们优化吞吐量时,我们希望达到计算密集型,这意味着我们接近于利用所有的 TPU MXU 容量。通常,这意味着我们希望批量大小尽可能大,以便我们做尽可能多的工作。

问题: 在 TPU v5e 上,使用 bfloat16 的权重和激活值,我们的批量大小需要多大才能使矩阵乘法达到计算密集型?如果我们使用 int8 权重,但在 bfloat16 中执行 FLOPs 呢?如果使用 int8 权重和 int8 FLOPs 呢?

答案

如第 7 节所述,对于任何 B \ll D, F 的 bfloat16 矩阵乘法,我们有

Tmath>Tcomms2BDF2DFTPU bfloat16 FLOPs/sHBM bandwidth=240

当我们的权重是 int8 时,分母上会少一个因子 2,所以我们有 2BDF / DF = 2B > 240,等价于 B > 120,这是之前临界批量大小的一半。这对我们非常有帮助!当我们使用 int8 权重和 int8 FLOPs 时,我们必须使用 TPU FLOPs/s 的 int8 值,该值从 bfloat16 的 1.97e14 增加到 3.94e14,几乎翻了一番。这意味着我们又回到了起点,大约在 B > 240

int8 权重和 bfloat16 FLOPs 的情况相当普遍,因为无损地量化参数通常比进行低精度算术运算更容易。

问题: 使用 8k 上下文,在 bfloat16、int8 和 int4(KV 缓存和参数)下,我们可以用来服务 LLaMA 3-70B 的最小 TPU v5e 拓扑是什么?在这个问题中,你可以认为 KV 缓存的大小可以忽略不计。

答案

这很简单!如果我们能接受一个很小的批量大小,那么唯一的限制就是将参数内存装入 HBM,即 ceil(num_params * sizeof(dtype) / HBM per TPU,或者 ceil(70e9 * sizeof(dtype) / 16e9),向上取整到最近的合理拓扑(2 的某个倍数):

dtype 参数大小 KV 大小/token (字节) 最小 TPU v5e 数量 实际最小切片 剩余用于 KV 缓存的 HBM 8k 上下文下的 KV 缓存数量
bf16 140GB 324kB 8.75 4x4 = 16 chips 116 43
int8 70GB 162kB 4.38 4x2 = 8 chips 68 52
int4 45GB 81kB 2.81 2x2 = 4 chips 19 67

这太酷了!它告诉我们,如果愿意,我们可以把 LLaMA 70B 部署在一个 TPU v5e 2x2 上。但你会注意到 KV 缓存的数量非常少。那就是我们的批量大小!这意味着我们的 FLOPs 利用率会非常糟糕。我们会非常乐意使用更大的拓扑,以便将我们的批量大小推高到 240。

问题: 假设我们使用这些拓扑能容纳的最大批量大小,我们预计每个生成步骤的延迟是多少?

答案

这也很简单,因为我们选择的批量大小会填满我们所有的 HBM!这只是一个将整个 TPU v5e 的字节加载到 MXU 中需要多长时间的问题。这只是 v5e HBM / v5e HBM memory bandwidth = 16GB / 8.2e11 = 19ms,所以是 19ms / 步。假设我们的生成中位长度为 512 个 token,那么每个解码大约需要 9 秒。请注意,使用更小的批量大小可以获得略好的延迟,例如,如果我们只考虑 int4 的模型参数,我们的最小延迟大约是 10ms / 步,因为 HBM 不再是满的。

要点:我们总是可以通过计算将所有模型参数从 HBM 加载到 MXU 所需的时间来为解码延迟设定一个下限。当我们的 KV 缓存很小时,你可以把每一层看作是逐块加载权重然后丢弃。除非我们使用大的批量大小或大量的设备间通信,这通常是一个合理的界限(在 1.5 倍以内)。当批量大小更大时,我们还需要对 KV 缓存加载进行建模,因为它会主导参数加载。

同样,在 FLOPs 限制的场景(例如训练或大批量推理)中,我们可以使用 Total FLOPs/(NC)=2param countB/(NC) 下限,该下限假设没有通信。

问题: 对于以上每种情况,这能给我们带来多少每芯片的吞吐量(以 查询数/芯片 为单位)?你可以假设我们的解码中位长度为 512 个 token。

答案

这是一个重要的问题,因为它与 成本/token 精确相关。

根据我们对解码中位长度的假设,我们的吞吐量就是 B/(per-step latencymedian stepsN)43/(0.019512N)。这大约给了我们 (4.42/N) QPS,所以代入 N 我们得到:

dtype QPS / 芯片
bfloat16 0.27
int8 0.66
int4 1.72

请注意,这个结果相当乐观,因为它完全忽略了前向传播的工作内存(分配给激活值和注意力机制的内存)。在使用 Flash Attention 的情况下,这并非荒谬,但也并不现实。真实数字可能大约是这个的一半。为了获得绝对最大的吞吐量,我们可能需要将芯片数量增加一倍以上,并显著增加批量大小。

问题: 如果我们将上述每个示例的拓扑加倍,我们的峰值吞吐量会如何变化?

答案

如果我们在 bfloat16 中使用 4x8 切片,我们将有 186GB 剩余用于 KV 缓存,这可以让我们将批量大小增加到 161。然后,由于我们的步长时间保持不变,我们的吞吐量将是 16.54 / num_chips,或者

dtype QPS / 芯片
bfloat16 (on 4x8) 0.51
int8 (on 4x4) 1.03
int4 (on 2x4) 2.06

进一步增加会带来更大的收益!关键要点是,在所有情况下,最小的拓扑并不总是性能最高的拓扑,特别是当我们受到 KV 缓存大小限制时。

问题: 现在让我们深入探讨分片问题。假设我们想在 TPU v5e 4x8 上用 bfloat16 服务模型。在生成过程中,我们应该为模型在 TPU v5e 4x8 上使用什么样的分片策略?我们能避免受到通信限制吗?

答案

正如上一节所讨论的,在生成过程中,我们实际上只有一种分片选择:模型并行性。在达到通信瓶颈之前,我们可以做多少模型并行?正如我们在上一节中讨论的,我们的模型大约在

Y>FMY2200

时会受到通信限制。对于 LLaMA 3-70B,我们有 F = 28,672,所以如果我们进行 2 轴的模型分片,这大约会给我们 Y=286722/2200=26,所以总的来说,我们可以在不受到通信限制的情况下扩展到大约 16 个芯片,这让我们能使用 4x4 但不能使用 4x8。通常,由于我们无法完美地重叠计算,即使是这个估计也过于乐观了。

要点:我们实际上无法在 4x8 上使用纯模型并行性进行服务。 我们在这里能做的最好的是 4x2 或者也许是 4x4。

然而,正如我们所讨论的,当批量大小较小时,我们通常可以进行更多的模型并行性,而不会显著损害吞吐量,因为我们的模型受内存带宽限制,而不是 FLOPs 限制。我们之前说过,这个值大约是 Y=F / (8\cdot B),所以如果我们使用批量大小 64,理论上在受到 ICI 限制之前,我们可以达到 Y = 28,672 / (8 * 64) = 56 路模型并行性。为了验证这一点,我们可以查看单个矩阵乘法的 T_\text{ici comms}T_\text{hbm comms}T_\text{math}。我们显然有:

Tici comms=2BDWiciThbm comms=2DFYWhbmTmath=2BDFYC

对于一个 4x8,这将给我们 T_\text{ici comms} = (2 * 64 * 8192) / 9e10 = 11usT_\text{hbm comms} = (2 * 8192 * 28,672) / (32 * 8.1e11) = 18us,以及 T_\text{math} = (2 * 64 * 8192 * 28,672) / (32 * 1.97e14) = 4us,所以理论上我们仍然受 HBM 带宽限制,这太棒了!*请注意,从 4x4 扩展到 4x8 可能对吞吐量没有帮助,但它会降低我们的延迟!*

如果我们看 int8 和 int4 的配置,我们可以用纯模型并行性来做。所以我们达到了一个点,即量化实际上除了更快的 FLOPs 之外,还给了我们一个有意义的优势:它让我们可以使用更大的批量大小,才达到通信瓶颈。*所以这个故事的结局是,我们无法在 4x8 上实现峰值吞吐量,但对于 int8 和 int4 配置,我们可以使用纯模型并行性

提示:有用的模型并行性的最大数量取决于 dff 和你对模型进行分片的轴数。根据模型大小,最大值通常在 8 到 32 之间。你可以超越这个限制来提高延迟,但会牺牲一些吞吐量。

预填充(Prefill)怎么办?

我们在这里基本上忽略了预填充,因为它要简单得多。让我们把几个概念放在一起,思考一下端到端的全貌。

问题: 假设我们在预填充期间实现了 40% 的 FLOPs 利用率。在 16 个 TPU v5e 芯片上,长度为 8192 的预填充需要多长时间?

答案

在 8k token 时,我们是稳固的计算密集型,所以我们只需要考虑 FLOPs。我们知道我们的模型有 70e9 参数,所以每次前向传播使用 2 * 70e9 * B FLOPs。假设 40% 的 MFU(FLOPs 利用率),这给我们的运行时间大约是 2 * 70e9 * 8192 / (16 * 1.97e14 * 0.4) = 0.91s。与我们之前看到的数字相比,这实际上相当长!

问题: 假设我们的预填充中位长度为 8192 token,解码中位长度为 4096 token。假设我们的生成批量大小为 32。平均每个步骤有多少序列完成解码?平均每个步骤有多少 token 从我们的 KV 缓存中被驱逐?

答案

这有点直接。由于我们的解码中位长度为 4096 token,一个序列大约每 1 / 4096 token 完成一次。给定批量大小为 32,这意味着我们每个步骤有 32 / 4096 个序列被驱逐。由于我们的 KV 缓存长度大约是 8192 + 4096,这相当于每个步骤驱逐 32 * (8192 + 4096) / 4096 = 96 个 token。通用公式是 B * (P + G) / G,其中 PG 分别是预填充和生成的长度。

问题: 假设我们进行分离式服务,预填充中位长度为 8192,解码中位长度为 512。假设预填充和生成的延迟如上文 bfloat16 部分计算。你需要什么样的预填充服务器与生成服务器的比例才能使两者都保持完全饱和?

答案

这是一个有趣的问题。让 P 表示预填充服务器的数量,G 表示生成服务器的数量。所以总的来说,这是一个流水线问题,我们以 P / prefill_latency 的速率输入序列,并以 B * G / (generate_latency * median_decode_length) 的速率消耗它们。我们之前计算出,在批量大小为 43(我们称之为 32)时,每个预填充步骤为 910ms,每个解码步骤为 19ms。因此我们需要 P / 0.91 = 32 * G / (0.019 * 512),或者 P = 3G,也就是说,我们需要的预填充服务器数量大约是生成服务器的三倍!

可视化延迟与吞吐量的权衡

继续以 LLaMA 70B 为例,让我们实际看看在生成过程中不同批量大小下的延迟和吞吐量。正如我们在上一节中为 PaLM 模型所展示的,这为我们提供了一个吞吐量/延迟的帕累托前沿。我们假设使用 16 路张量并行性,因为这是在 MLP 块中保持计算密集型的合理上限。这里我们将使用 TPU v5e 4x4 拓扑。滑动条控制序列长度,这样你可以看到更大 KV 缓存的影响。

我们可以通过将成本和延迟的来源分解为参数加载时间、KV 加载时间和 FLOPs 时间来更好地理解这一点。红色区域是我们预期在 MLP 块中达到计算密集型的区域。

这揭示了一个相当有趣的故事。你可以看到,最初,参数加载占了延迟的绝大部分,直到批量大小变得足够大,FLOPs 和 KV 加载才变得更加显著。值得注意的是,在所有大于 2048 的序列长度下,我们花在 KV 缓存加载上的时间比花在 FLOPs 上的时间还要多!因此,虽然我们可以通过增加批量大小来提高硬件利用率,但在长上下文长度下,KV 加载总是主导总步长时间。

要点: 对于 LLaMA 3-70B,在几乎所有这些配置中,我们都受到 KV 缓存内存带宽的严重限制(以及 HBM 限制),这凸显了减少 KV 缓存大小对于生成吞吐量的重要性。另外请注意,这里的延迟/吞吐量权衡仍然非常显著。

实现这个的代码相当简单。

这是计算这些Roofline的代码:

import numpy as np

num_chips = 16  # we fix 16 as the amount of total model parallelism we do
param_size = 70e9  # int8 means 1 byte per param
sequence_length = 8192  # can vary this

hbm_bandwidth = 8.20E+11  # v5e
flops = 1.97E+14  # v5e

param_size = bytes_per_param * param_count

def kv_cache_size(bs):
    return 2 * bs * 128 * 8 * 80

def min_topology(bytes):
    return 2 ** np.ceil(np.log2(bytes / 16e9))

def get_max_batch_size(max_num_chips: int = 16):
  # for num_chips in topo_sizes:
  batch_sizes = np.arange(1, 1024, 4)
  kv_sizes = kv_cache_size(sequence_length * batch_sizes)
  num_chips = min_topology(kv_sizes + param_size)
  max_idx = np.where(num_chips <= max_num_chips)[0][-1]
  return max_idx

max_idx = get_max_batch_size(num_chips, sequence_length, param_size)  # get the largest batch size that can fit
batch_sizes = np.arange(1, 512, 1)[:max_idx]
kv_sizes = kv_cache_size(sequence_length * batch_sizes)

kv_comms_time = kv_sizes / (num_chips * hbm_bandwidth)

param_comms_time = param_size / (num_chips * hbm_bandwidth)
param_comms_time = np.asarray([param_comms_time] * batch_sizes.shape[0])

flops_time = 2 * param_count * batch_sizes / (num_chips * flops)  # roughly true in a 2ND sense

mlp_time = np.maximum(flops_time, param_comms_time)
attn_time = kv_comms_time  # always bandwidth-bound for generate

latency = 1000 * (mlp_time + attn_time)
throughput = batch_sizes / (latency * num_chips)

注意我们如何非常明确地将延迟分解为两个来源:KV 加载和参数加载,以及延迟如何受限于 FLOPs 或通信,取两者中较大者。

练习题

这里有几个练习题。其中一些重复了上面已经讲过的内容,但可能在教学上很有用。

问题 1: LLaMA 3-405B 的每次前向传播每个 token 使用多少 FLOPs?假设我们受 FLOPs 限制,在 TPU v5e 的 N 个芯片上单次前向传播的下限是多少?如果我们受通信限制呢?忽略模型无法装入单个芯片的事实。

问题 2: 假设我们想用 BS240 服务 LLaMA 3-8B,使用 int8 权重和 int8 KV 缓存。 (a) 模型参数 (b) KV 缓存 和 (c) 峰值工作激活值(大约)各使用多少字节?我们可以运行这个的最小拓扑是什么?

问题 3: 你将如何在 TPU v5e 上服务 LLaMA 3-405B?假设使用 int8 权重和 bfloat16 FLOPs。假设我们有 15ms / token 的严格限制,我们能实现的最高吞吐量配置是什么?理论上的最小步长时间是多少?

第 8 部分到此结束!要深入了解 XLA 和 TPU 性能分析,请点击此处进入第 9 部分。

脚注

  1. 这并非总是正确,有时更高的 HBM 或 ICI 带宽比 FLOPs 更关键,但这不失为一个好的启发式方法。[↩]

杂项

*在 Google DeepMind 完成的工作,现就职于 MatX。

引用

在学术背景下进行署名时,请按如下方式引用本作品:

    Austin et al., "How to Scale Your Model", Google DeepMind, online, 2025.

或使用 BibTeX 条目:

    @article{scaling-book,
      title = {How to Scale Your Model},
      author = {Austin, Jacob and Douglas, Sholto and Frostig, Roy and Levskaya, Anselm and Chen, Charlie and Vikram, Sharad
      and Lebron, Federico and Choy, Peter and Ramasesh, Vinay and Webson, Albert and Pope, Reiner},
      publisher = {Google DeepMind},
      howpublished = {Online},
      note = {Retrieved from https://jax-ml.github.io/scaling-book/},
      year = {2025}
    }