🔗 英文原文: https://jax-ml.github.io/scaling-book/jax-stuff/
✍️ 翻译: 北极的树
微信二维码 微信公众号

在 JAX 中为 TPU 编程

如何扩展你的模型》第10部分 (第9部分:性能分析 | 第11部分:结论)

如何使用 JAX 高效地为 TPU 编程!本节大部分内容摘自此处。您可以在 Google Colab 上使用免费的 TPU 运行本节中的代码示例。

JAX 中的并行性是如何工作的?

JAX 支持三种多设备编程的思想流派:

  1. 编译器,你来掌舵! 让 XLA 编译器自动对数组进行分区,并决定添加何种通信来支持给定的程序。这使你能够将一个在单设备上运行的程序,无需任何修改,就能自动地在数千个设备上运行。
  2. JAX,你来掌舵! 自动并行性很棒,但有时编译器会做出一些疯狂的举动。显式分片允许你像往常一样编写单设备代码,但由 JAX 来处理分片传播(而不是编译器)。这意味着当 JAX 不清楚你的意图时,它可以请求你进行澄清。
  3. 该死,就让我写我想写的! 虽然编译器很好用,但它们有时会做错事,添加一些你并不打算使用的通信。有时我们希望明确指定要运行的确切通信。
模式 视图? 显式分片? 显式集合操作?
自动 全局
显式 全局
手动 每设备

相应地,JAX 为每种模式都提供了 API:

  1. jax.jit (使用 Auto 网格轴) 允许你使用任何现有的 JAX 函数,并用分片的输入来调用它。然后,JAX 会使用 XLA 的 Shardy 编译器来自动并行化程序。当需要支持现有操作时,XLA 会为你添加通信操作(AllGathers、ReduceScatters、AllReduces 等)。虽然它并不完美,但通常能在无需修改代码的情况下,很好地将你的程序自动扩展到任意数量的芯片上。
  2. jax.jit 使用 Explicit 网格轴看起来与(1)类似,但它让 JAX 而不是 XLA 来处理分片传播。这意味着数组的分片实际上是 JAX 类型系统的一部分,当 JAX 检测到模糊的通信时会报错,并让用户来解决。
  3. jax.shard_map 是更手动的对应方案。你获得的是程序的设备本地视图,并且必须显式地编写任何你想要的通信。有一个分片数组,并希望每个设备上都有完整的数据?添加一个 jax.lax.all_gather。想在所有设备上对一个数组求和?添加一个 jax.lax.psum (一个 AllReduce)。编程更难,但做错事的可能性要小得多。

自动分片模式

jax.jit 在 JAX 内部扮演两个角色。顾名思义,它会“即时”将一个函数从 Python 编译成字节码(通过 XLA/HLO/LLO),使其运行得更快。但如果输入是分片的,或者用户指定了 in_shardingout_sharding,它还会让 XLA 将计算分布到多个设备上,并根据需要添加通信。例如,以下是如何使用 jax.jit 编写一个分片矩阵乘法:

import jax
import jax.numpy as jnp

# 在一个 TPU v5e 4x2 上运行。这为硬件的两个物理轴分配了名称。
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))

# 这告诉 JAX 对所有操作都使用这个网格,所以你只需指定 PartitionSpec P 即可。
jax.set_mesh(mesh)

# 我们创建一个矩阵 W 和输入激活 In,它们被分片到我们的设备上。
In = jnp.zeros((8, 2048), dtype=jnp.bfloat16, device=jax.NamedSharding(mesh, jax.P('X', 'Y')))
W = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=jax.NamedSharding(mesh, jax.P('Y', None)))

def matmul_square(In, W):
  return jnp.einsum('bd,df->bf', jnp.square(In), W)

# 我们可以在这里显式编译分片的矩阵乘法函数。这会添加所有
# 必要的通信(例如,在矩阵乘法之后进行 AllReduce)。
jit_matmul = jax.jit(matmul_square, out_shardings=jax.P('X', None)).lower(In, W).compile()

out = jit_matmul(In, W)

这将在任何分片策略下自动运行,并将计算分区到我们的设备上。但是在硬件层面到底发生了什么呢?

  1. 首先,我们在设备上创建分片的 In 和 W注意我们是如何做到这一点的。这是一种创建具有特定分片数组的方法(即通过向创建函数添加 device 参数)。另一种方法是使用 `jnp.array(....)` 正常创建一个数组,然后执行例如 `jax.device_put(..., P('x', 'y'))`。还有一种方法是编写一个函数来创建你想要的数组,然后使用你想要的 `out_shardings` 对其进行 jit 编译。。W 沿着收缩维度进行 2 路分片,而 In 则进行 4 路分片(同时沿着收缩维度和输出维度)。这对应于 W[DX, F] 和 In[BX, DY] 的分片,也就是一种模型和数据并行性。
  2. 如果我们在本地(即单个设备上)运行,matmul_square 会简单地对输入进行平方运算并执行一个简单的矩阵乘法。但因为我们将 out_shardings 指定为 P('X', None),输出将沿着批次维度分片,但在模型维度上是复制的,并且需要一个 AllReduce 来计算。

使用我们前面章节的表示法,这可能会执行类似以下的操作

  1. Out[BX, F] { UY } = In[BX, DY] *D W[DY, F]
  2. Out[BX, F] = AllReduce(Out[BX, F] { UY })

jax.jit 会自动为我们添加这个!我们实际上可以用 jit_matmul.as_text() 打印出 HLO,并看到以下 HLO(经过大幅缩写):

# 这个融合操作是分片输入和矩阵的实际矩阵乘法
%fusion = bf16[2,8192]{1,0:T(4,128)(2,1)S(1)} fusion(bf16[2,1024]{1,0:T(4,128)(2,1)} %param, bf16[8192,1024]{1,0:T(8,128)(2,1)S(1)} %copy-done)

# 我们在设备间对部分求和的结果进行归约
ROOT %AllReduce = bf16[2,8192]{1,0:T(4,128)(2,1)} AllReduce(bf16[2,8192]{1,0:T(4,128)(2,1)S(1)} %fusion)

我们可以看到上面的矩阵乘法(融合操作)和 AllReduce。请特别注意形状。bf16[2, 1024] 是激活的本地视图,因为我们的 batch_size=8 被分割到 4 个设备上,而我们的 d_model=2048 同样被分割成 2 路。

这简直太神奇了! 无论我们的程序多么复杂,Shardy 和 jit 都会尝试为所有中间激活找到分片方式,并根据需要添加通信。话虽如此,Shardy 也有其缺陷。它可能会犯错。有时你会查看性能剖析,发现出了问题。一个巨大的 AllGather 占用了 80% 的时间,而这本是不必要的。当这种情况发生时,我们可以尝试通过使用 jax.lax.with_sharding_constraint 显式地注释中间张量来纠正编译器。例如,对于两个矩阵乘法,我可以用以下方式强制中间激活沿着 y 维度进行分片(但这不一定是个好主意):

import jax
import jax.numpy as jnp

mesh = jax.make_mesh((4, 2), ('X', 'Y'))

def matmul(x, Win, Wout):
  hidden = jnp.einsum('bd,df->bf', x, Win)
  hidden = jax.lax.with_sharding_constraint(hidden, jax.P('x', 'y'))
  return jnp.einsum('bf,df->bd', hidden, Wout)

在自动分区世界里,通过 jax.lax.with_sharding_constraint 控制中间分片构成了大约 60% 的 JAX 并行编程。但“挑逗编译器”是出了名的不好玩的编程模型。你可能注释了每个中间变量,但仍然不知道是否会得到正确的结果。那么,如果 JAX 本身能够处理和控制分片传播呢?

显式分片模式

显式分片(或“类型中的分片”)看起来很像自动分片,但分片传播发生在 JAX 层面!每个 JAX 操作都有一个分片规则,它接收操作参数的分片方式,并为操作结果生成一个分片方式。你可以使用 jax.typeof 查看结果的分片:

import jax
import jax.numpy as jnp
import jax.sharding as shd

# 在一个 TPU v5e 2x2 上运行。这为硬件的两个物理轴分配了名称。
mesh = jax.make_mesh(axis_shapes=(2, 2), axis_names=('X', 'Y'),
                                       axis_types=(shd.AxisType.Explicit, shd.AxisType.Explicit))

# 这告诉 JAX 对所有操作都使用这个网格,所以你只需指定 PartitionSpec P 即可。
jax.set_mesh(mesh)

x = jax.device_put(np.arange(16).reshape(8, 2), P('X', 'Y'))

@jax.jit
def f(x):
  print(jax.typeof(x))  # bfloat16[8@X,2@Y]
  out = x * 2
  print(jax.typeof(out))  # bfloat16[8@X,2@Y]
  return out

f(x)

如你所见,JAX 将分片从输入 (x) 传播到了输出 (x),这可以在追踪时通过 jax.typeof 进行检查。对于大多数操作,这些规则简单明了,因为只有一个合理的选择(例如,逐元素操作保持相同的分片)。但对于某些操作,如何对结果进行分片是模糊的,在这种情况下,JAX 会抛出一个追踪时错误,并要求程序员显式地提供一个 out_sharding 参数(例如 jnp.einsum、jnp.reshape 等)。让我们看另一个有冲突的例子:

# 我们创建一个矩阵 W 和输入激活 In,它们被分片到我们的设备上。
In = jnp.zeros((8, 2048), dtype=jnp.bfloat16, out_sharding=jax.P('X', 'Y'))
W = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, out_sharding=jax.P('Y', None))

@jax.jit
def matmul_square(In, W):
  print(jax.typeof(In))  # bfloat16[8@X, 2048@Y]
  print(jax.typeof(W))  # bfloat16[2048@Y, 8192]
  return jnp.einsum('bd,df->bf', jnp.square(In), W)

matmul_square(In, W)  # 这会报错

这段代码会报错 Contracting dimensions are sharded and it is ambiguous how the output should be sharded. Please specify the output sharding via the out_sharding parameter. Got lhs_contracting_spec=('Y',) and rhs_contracting_spec=('Y',)

这太棒了,因为 einsum 的输出应该如何分片是模糊的。输出分片可以是:

与自动模式不同,显式模式在检测到模糊的通信时会报错,并要求用户解决它。所以在这里你可以这样做:

@jax.jit
def matmul_square(In, W):
  return jnp.einsum('bd,df->bf', jnp.square(In), W, out_sharding=P('X', 'Y'))

out = matmul_square(In, W)
print(jax.typeof(out))  # bfloat16[8@X,8192@Y]

自动模式和显式模式可以通过 jax.sharding.auto_axesjax.sharding.explicit_axes API 组合使用。想了解更多信息,可以阅读这篇很棒的文档

shard_map: 对程序的显式并行性控制

如果说 Shardy 是“编译器掌舵”模式,那么 jax shard_map 则将一切都交到你手中。你像在 jax.jit 中一样指定输入的分片,但之后你需要显式地编写所有通信。jax.jit 给你留下的是程序的全局跨设备视图,而 shard_map 给你的是一个本地的每设备视图。

这里有一个例子。试着推断一下这个函数是做什么的:如果你想在 colab 中通过模拟网格来自己尝试,你可以使用以下单元格 `import jax; jax.config.update('jax_num_cpu_devices', 8)`

import jax
import jax.numpy as jnp
import jax.sharding as shd

mesh = jax.make_mesh((2, 4), ('x', 'y'), (shd.AxisType.Explicit, shd.AxisType.Explicit))
jax.set_mesh(mesh)

x = jnp.arange(0, 512, dtype=jnp.int32, out_sharding=P(('x', 'y')))

# 这个函数将对数组的 1/8 进行操作。
@jax.shard_map(in_specs=P(('x', 'y')), out_specs=P())
def slice_and_average(x):
  assert x.shape == (512 // 8,)
  return jax.lax.pmean(x[:4], axis_name=('x', 'y'))

out = slice_and_average(x)
assert out.shape == (4,)

这是做什么的? slice_and_average 在每个 TPU 上运行,处理数组的 1/8,我们从中切片前 4 个元素,并在整个网格上对它们求平均值。这意味着我们实际上在做 mean(x[:4], x[64:68], x[128:132], …)。这非常酷,因为在 JAX 中用其他方式表达这个操作并不容易。

为什么要这样做而不是用 jax.jit? 如果我们使用 jax.jitslice_and_average 将会看到数组的全局视图(完整的 [512,] 数组)。我们将不得不切出这个非均匀的切片,然后执行一个平均操作,而 XLA 必须正确地解释它。XLA 可能会添加错误的通信或感到困惑。在这里,我们看到的是本地视图,并且只编写我们需要的通信。

示例 [集合矩阵乘法]: 举一个更现实的例子,假设我们要实现模型并行性,其中激活最初是按模型分片的,即 A[BX, DY] * W[D, FY] -> Out[BX, FY]。天真地,我们会先对 A 进行 AllGather,然后进行一个本地矩阵乘法:

  1. A[BX, D] = AllGatherY(A[BX, DY])
  2. Out[BX, FY] = A[BX, D] *D W[D, FY]

可惜,这样做不好,因为它不允许我们将通信与计算重叠。如 Wang et al. 2023 中所述,可以通过“集合矩阵乘法”来实现重叠。算法基本如下:

我们可以用 shard_map 相当容易地实现这一点:

import functools

import jax
import jax.numpy as jnp
import jax.sharding as shd
import numpy as np

mesh = jax.make_mesh(axis_shapes=(2, 4), axis_names=('X', 'Y'),
                                       axis_types=(shd.AxisType.Explicit, shd.AxisType.Explicit))
jax.set_mesh(mesh)

B, D, F = 1024, 2048, 8192
A = jnp.arange(np.prod((B, D))).reshape((B, D))
W = jnp.arange(np.prod((D, F))).reshape((D, F))

A = jax.device_put(A, jax.P('X', 'Y'))
W = jax.device_put(W, jax.P(None, 'Y'))

@functools.partial(jax.jit, out_shardings=jax.P('X', 'Y'))
def matmul(lhs, rhs):
  return lhs @ rhs

def collective_matmul_allgather_lhs_contracting(lhs, rhs):
  # lhs 是循环操作数;rhs 是本地操作数
  axis_size = jax.lax.axis_size('Y')  # 在这个例子中 axis_size = 4
  idx = jax.lax.axis_index('Y')

  chunk_size = lhs.shape[1]
  assert rhs.shape[0] % chunk_size == 0

  def f(i, carrys):
    accum, lhs = carrys
    rhs_chunk = jax.lax.dynamic_slice_in_dim(rhs, (idx + i) % axis_size * chunk_size, chunk_size)
    # 对一个块进行矩阵乘法
    update = lhs @ rhs_chunk
    # 向左循环移位
    lhs = jax.lax.ppermute(
        lhs,
        axis_name='Y',
        perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
    )
    return accum + update, lhs

  accum = jnp.zeros((lhs.shape[0], rhs.shape[1]), dtype=lhs.dtype)
  accum = jax.lax.pvary(accum, ('X', 'Y'))
  accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs), unroll=True)

  # 在最后一次置换后计算最后一个块,以使 lhs 恢复到我们找到它时的状态
  i = axis_size - 1
  rhs_chunk = jax.lax.dynamic_slice_in_dim(rhs, (idx + i) % axis_size * chunk_size, chunk_size)
  update = lhs @ rhs_chunk
  return accum + update

jit_sharded_f = jax.jit(jax.shard_map(
  collective_matmul_allgather_lhs_contracting,
  in_specs=(jax.P('X', 'Y'), jax.P(None, 'Y')), out_specs=jax.P('X', 'Y')))

shmapped_out = jit_sharded_f(A, W)
expected_out = matmul(A, W)

np.testing.assert_array_equal(shmapped_out, expected_out)

这非常巧妙!我们可以对此进行基准测试,发现它也快得多!这里是默认 jit 矩阵乘法的性能剖析,它耗时 311us,并且在开始时有一个大的阻塞式 AllGather:

这里是上面那个版本的剖析,耗时 244 us。你可以看到剖析中没有 AllGather。全都是有效的工作!我们的 FLOPs 利用率也高得多。

还值得注意的是,在收缩维度上没有分片时的矩阵乘法时间是 224us,所以我们非常接近未分片的基线。这是一个很好的例子,说明了你可能需要进行何种性能工程来提高 TPU 的利用率。要了解更多 shard_map 示例,这篇笔记很棒

现在这里有几个有用的实践问题,可以尝试用 jax.jitshard_map 来实现!

实践问题

这里有一些随机的 JAX 相关问题。我稍后会添加更多。对于所有这些问题,你都需要在 Colab 中有一定数量的 TPU。你可以使用带有 TPUv2-8 的公共 Colab。从现在开始,我们假设你有 N 个可用设备。

问题1:A 是一个激活数组,形状为 float32[SX, DY],其中 X * Y = N。请完成以下操作:

  1. 在 JAX 中编写一个函数,计算每个 (X, Y) 分片内的平均值,即返回一个大小为 [X, Y] 的数组,其中 arr[i, j] 是分片 (i, j) 上的平均值。分别用 jax.jitshard_map 实现。对每个实现进行性能分析,看看它们耗时多久。是否添加了任何通信?提示:不应该有,但有时 XLA 还是会添加。

  2. 在 JAX 中编写一个函数,对于每个分片 X 内部的某个位移,返回 roll(x, shift, axis=0) - x。我还没那么自虐,不会让你用 jax.jit 来做这个,所以只用 shard_map 实现即可。

点击此处查看答案。

第1部分:这是第1部分的解答。请注意,对于 jax.jit 的解决方案,我们必须进行相当复杂的重塑操作。

import numpy as np

import jax
import jax.numpy as jnp

P = jax.sharding.PartitionSpec

mesh = jax.make_mesh((4, 2), ('X','Y'))

average_shmap = jax.shard_map(
    lambda x: x.mean(keepdims=True),
    mesh=mesh,
    in_specs=P('X','Y'), out_specs=P('X','Y')
)

def average(x):
  X, Y = mesh.axis_sizes
  return x.reshape(X, x.shape[0] // X, Y, x.shape[1] // Y).mean(axis=(1, 3))

average_jit = jax.jit(average, out_shardings=jax.NamedSharding(mesh, P('X','Y')))

x = jnp.arange(8 * 64 * 8, dtype=jnp.int32).reshape(8 * 64, 8)
x = jax.device_put(x, jax.NamedSharding(mesh, P('X','Y')))

y1 = average_shmap(x)
y2 = average_jit(x)

np.testing.assert_array_equal(y1, y2)

第2部分:这是第2部分的类似解答。

import numpy as np

import jax
import jax.numpy as jnp

import functools

P = jax.sharding.PartitionSpec

mesh = jax.make_mesh((4, 2), ('X','Y'))

def shift_shmap(x, shift: int):
  shmapped = jax.shard_map(
      lambda x: jnp.roll(x, shift, axis=0),
      mesh=mesh,
      in_specs=P('X','Y'), out_specs=P('X','Y')
  )
  return shmapped(x)

@functools.partial(jax.jit, static_argnames=['shift'], out_shardings=jax.NamedSharding(mesh, P('X','Y')))
def shift_jit(x, shift: int):
  X, Y = mesh.axis_sizes
  reshaped = x.reshape(X, x.shape[0] // X, -1)
  return jnp.roll(reshaped, shift, axis=1).reshape(x.shape[0], x.shape[1])

x = jnp.arange(8 * 64 * 8, dtype=jnp.int32).reshape(8 * 64, 8)
x = jax.device_put(x, jax.NamedSharding(mesh, P('X','Y')))

y1 = shift_shmap(x, 5)
y2 = shift_jit(x, 5)

np.testing.assert_array_equal(y1, y2)

问题2:在这里,我们将一起构建一个基本的“专家混合”(MoE)模型。设 W: float32[EX, D, FY] 是一组 E 个“专家”矩阵。设 A: float32[SX, DY](我们的激活)并且设 B 是一组“路由分配”,其中 B[i] 是范围 [0, E) 内的一个整数,告诉我们希望用哪个矩阵来处理该激活。我们想在 JAX 中编写一个函数,返回 Out[i] = W[B[i]] @ A[i]

  1. 让我们先完全忽略分片。将所有这些张量做得足够小,以便它们能放入一个设备中。编写这个函数的本地实现。确保你不要物化一个形状为 [S, D, F] 的数组!提示:尝试将令牌排序到一个形状为 [E, S, D] 的新缓冲区中,并注意掩码(为什么我们需要第二个维度的大小为 S?)。

  2. 如果你只是对上述方法使用 jax.jit,会发生一些事情。对此进行性能分析,看看它决定进行何种通信。它需要多长时间?

  3. 你会注意到的一个问题是,上述方法很可能会在本地收集完整的激活集 A,即 AllGatherX([SX, DY])。这不仅在通信方面成本高昂,而且如果我们无法在本地容纳完整的激活集,在内存方面也是极其昂贵的。使用 shard_map 和显式通信来实现上述功能。

    1. 作为第一步,最简单的方法可能是使用一个 jax.lax.all_gather 并像(a)中那样重新排序。

    2. 作为第二步,尝试避免物化任何大小为 [E, S, D] 的数组,即尝试在一个 jax.lax.while_loop 内部使用一个 jax.lax.all_to_all 以不规则的方式执行计算。这样,你可以避免物化完整的激活并浪费计算在填充上。这比你最初的实现快多少?

  4. 大多数 MoE 模型会将输入路由到多个(k 个)专家,然后对结果进行平均。重构上述代码以实现这一点。在这种情况下,设 B: int32[S, k] 用于路由到 k 个专家。

问题3:上面那个集合矩阵乘法的例子实际上与真实的 LLM 非常相关。让我们调整这个例子来完成整个 Transformer 栈。

  1. 作为一个练习,让我们从实现一个 AllReduce 集合矩阵乘法开始,即 A[BX, DY] *D W[DY, F] -> Out[BX, F]。注意输出不是复制的。上面讨论了朴素算法,基本上就是一个本地矩阵乘法后跟一个 AllReduce。尝试制作一个通信重叠的“集合”版本的此操作。提示:在输出维度上进行分块,并可以随意使用 jax.lax.psum(即 AllReduce)。 注意:由于 XLA 处理此问题的方式,它实际上可能不会比基线更快。

  2. 上面 AllReduce 集合矩阵乘法的补充是 ReduceScatter 集合矩阵乘法,如 Tmp[BX, FY] *F W2[FY, D] -> Out[BX, DY]。这发生在 Transformer 中的下投影矩阵中。在 JAX 中实现一个集合的、重叠的版本。注意只传递所需的最少量数据。提示:尝试在累加结果时对其进行置换。

  3. 将这两者组合成一个端到端的 Transformer 块,该块执行 In[BX, DY] *D Win[D, FY] *F Wout[FY, D] -> Out[BX, DY] 并带有重叠的通信。和之前一样,由于我们在此省略了一个非线性操作,我们不能先计算 W_{in} \cdot W_{out} 这比 jax.jit 实现快多少?

问题4:上面实现的所有集合矩阵乘法都是单向的:它们只在一个方向上进行置换。重写集合 AllReduce 矩阵乘法和集合 ReduceScatter 矩阵乘法,以使用双向通信。它们快了多少?

第10部分到此结束。基本上就是这样了!要查看最终结论和进一步阅读,请点击此处

脚注

  1. 注意我们是如何做到这一点的。这是一种创建具有特定分片数组的方法(即通过向创建函数添加 device 参数)。另一种方法是使用 `jnp.array(....)` 正常创建一个数组,然后执行例如 `jax.device_put(..., P('x', 'y'))`。还有一种方法是编写一个函数来创建你想要的数组,然后使用你想要的 `out_shardings` 对其进行 jit 编译。[↩]
  2. 如果你想在 colab 中通过模拟网格来自己尝试,你可以使用以下单元格 `import jax; jax.config.update('jax_num_cpu_devices', 8)`[↩]
  3. 和之前一样,由于我们在此省略了一个非线性操作,我们不能先计算 W_{in} \cdot W_{out}[↩]

杂项

*工作于 Google DeepMind 完成,现就职于 MatX。

引用

在学术场合中引用,请按以下格式引用此作品:

    Austin et al., "How to Scale Your Model", Google DeepMind, online, 2025.

或使用 BibTeX 条目:

    @article{scaling-book,
      title = {How to Scale Your Model},
      author = {Austin, Jacob and Douglas, Sholto and Frostig, Roy and Levskaya, Anselm and Chen, Charlie and Vikram, Sharad
      and Lebron, Federico and Choy, Peter and Ramasesh, Vinay and Webson, Albert and Pope, Reiner},
      publisher = {Google DeepMind},
      howpublished = {Online},
      note = {Retrieved from https://jax-ml.github.io/scaling-book/},
      year = {2025}
    }