Part 3 of How To Scale Your Model 中文版 (Part 2: TPUs | Part 4: Transformer Math)
当我们训练大型机器学习模型时, 我们必须将其参数或输入拆分 (或“分片”) 到许多加速器上. 由于 LLM 主要由矩阵乘法组成, 理解这一点归结为理解当矩阵分布在设备上时如何进行乘法. 我们基于 TPU 通信原语的成本, 发展了一个简单的分片矩阵乘法理论.
当我们在成千上万的 TPU 或 GPU 上训练一个 LLM 时, 我们抽象地做的计算与在单个设备上训练时是相同的. 不同之处在于我们的数组无法容纳在单个 TPU/GPU 的 HBM 中, 所以我们必须将它们拆分.
这是一个在 4 个 TPU 上分片的 2D 数组 A 的例子:
请注意, 分片数组仍然具有与未分片数组相同的全局或逻辑形状, 例如 (4, 128), 但它也有一个设备本地形状, 例如 (2, 64), 这给了我们每个 TPU 持有的实际字节大小 (在上图中, 每个 TPU 持有总数组的 ¼). 现在我们将 इसको推广到任意数组.
我们使用命名轴表示法的一种变体来描述张量如何以块的形式在设备上分片: 我们假设存在一个 2D 或 3D 的设备网格, 称为设备网格, 其中每个轴都被赋予了网格轴名称, 例如 X, Y 和 Z. 然后, 我们可以通过描述数组的每个命名维度如何跨物理网格轴进行分区来指定矩阵数据在设备网格上的布局. 我们称这个分配为分片.
示例 (上图): 对于上图, 我们有:
Mesh(devices=((0, 1), (2, 3)), axis_names=(‘X', ‘Y')), 这告诉我们我们有 4 个 TPU, 排列成一个 2x2 的网格, 轴名称为 $X$ 和 $Y$.综上所述, 我们知道数组的本地形状 (单个设备持有的分片的大小) 是 $(\lvert I\rvert / 2, \lvert J\rvert / 2)$, 其中 \(\lvert I\rvert\) 是 A 的第一个维度的大小, \(\lvert J\rvert\) 是 A 的第二个维度的大小.
小测验 [沿 1 个轴的 2D 分片]: 考虑一个数组 fp32[1024, 4096], 分片为 $A[I_{XY}, J]$, 网格为 {'X': 8, 'Y': 2}. 每个设备持有多少数据? 在 H100s 上从 HBM 加载这个数组需要多长时间 (假设每个芯片的内存带宽为 3.4e12)?
$A[I_{XY}, J]$ 将第一个维度 (I) 沿 X 和 Y 硬件轴进行分片. 在这个例子中, 本地形状是 $(\lvert I\rvert /(\lvert X\rvert \cdot \lvert Y\rvert), \lvert J\rvert)$. 对于给定的例子, 全局形状是 fp32[1024, 4096], 所以本地形状是 fp32[64, 4096].
由于每个 GPU 有 4 * 64 * 4096 = 1MiB 字节, 这大约需要 1e6 / 3.4e12 = 294ns, 尽管由于各种开销, 实际时间可能要长得多, 因为这个数组很小.
可视化这些分片: 让我们尝试通过查看一个分布在 4 个设备上的 2D 数据数组来可视化这些分片:
我们将矩阵的完全复制形式简单地写为 $A[I, J]$, 没有分片分配. 这意味着每个设备都包含整个矩阵的完整副本.
我们可以用一个下标网格轴来表示其中一个维度已经跨一个网格轴进行了分区. 例如, $A[I_X, J]$ 意味着 I 逻辑轴已经跨 X 网格维度进行了分区, 但 J 维度没有分区, 并且这些块在 Y 网格轴上保持部分复制.
$A[I_X, J_Y]$ 意味着 I 逻辑轴已经跨 X 网格轴进行了分区, 并且 J 维度已经跨 Y 网格轴进行了分区.
我们在下图中说明了其他可能性:
这里 $A[I_{XY}, J]$ 意味着我们将 X 和 Y 网格轴视为一个更大的扁平化维度, 并将 I 命名轴跨所有设备进行分区. 多个网格轴下标的顺序很重要, 因为它指定了分区跨网格的遍历顺序.
最后, 请注意, 我们不能将多个命名轴沿相同的网格维度进行分片. 例如, $A[I_X, J_X]$ 是一个无意义的, 禁止的分片. 一旦一个网格维度被用于分片数组的一个维度, 它在某种意义上就被“用掉了”.
小测验: 设 A 是一个形状为 int8[128, 2048] 的数组, 分片为 $A[I_{XY}, J]$, 网格为 Mesh({‘X': 2, ‘Y': 8, ‘Z': 2}) (总共 32 个设备). A 每个设备使用多少内存? A 在所有设备上总共使用多少内存?
答案: 我们的数组 A 在 X 和 Y 上分片, 在 Z 上复制, 因此每个设备的形状为 int8[128 / (2 * 8), 2048] = int8[8, 2048], 大小为 8 * 2048 = 16,384 字节. 因为它在 Z 上复制, 而在 Z 平面内它在 X 和 Y 上完全分片, 所以每个 Z 平面有一个副本, 并且有 2 个这样的平面, 所以总大小 (在所有设备上) 是 128 * 2048 * 2 = 512 KiB.
到目前为止, 我们一直避免谈论代码, 但现在是时候先睹为快了. JAX 使用一种与我们上面描述的抽象语法非常匹配的命名分片语法. 我们将在第 10 节中更多地讨论这一点, 但这里有一个快速预览. 你可以在 Google Colab 这里中玩这个, 并分析结果以查看 JAX 如何处理不同的分片. 这个代码片段做了 3 件事:
import jax
import jax.numpy as jnp
# 创建我们的网格! 我们正在一个 TPU v2-8 4x2 切片上运行, 名称为 'X' 和 'Y'.
assert len(jax.devices()) == 8
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))
# 一个帮助定义我们分片的小工具函数. PartitionSpec 是我们的
# 分片 (从轴到名称的映射).
def P(*args):
return jax.NamedSharding(mesh, jax.sharding.PartitionSpec(*args))
# 我们将 A 和 B 都在非收缩维度上进行分片, 并将 A 在收缩维度上进行分片.
A = jnp.zeros((8, 2048), dtype=jnp.bfloat16, device=P('X', 'Y'))
B = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=P(None, 'Y'))
# 我们可以对这些分片数组进行矩阵乘法! out_shardings 告诉我们我们希望
# 输出如何分片. JAX/XLA 为我们处理其余的分片.
y = jax.jit(lambda A, B: jnp.einsum('BD,DF->BF', A, B), out_shardings=P('X', 'Y'))(A, B)
JAX 的酷之处在于这些数组的行为就像它们没有被分片一样! B.shape 会告诉我们全局或逻辑形状 (2048, 8192). 我们必须实际查看 B.addressable_shards 才能看到它是如何本地分片的. 我们可以对这些数组执行操作, JAX 会尝试找出如何广播或重塑它们以执行操作. 例如, 在上面的例子中, A 的本地形状是 [2, 1024], B 的本地形状是 [2048, 4096]. JAX/XLA 会自动在这些数组之间添加必要的通信以执行最终的乘法.
如果你有一个分布在许多设备上的数据数组, 并希望对其执行数学运算, 那么对数据和计算进行分片会带来哪些开销?
显然, 这取决于所涉及的计算.
本节的其余部分将讨论如何乘以分片矩阵. 初步来看, 这涉及到移动矩阵的块, 以便你可以完全乘以或求和每个块. 每个分片都会涉及不同的通信. 例如, $A[I_X, J] \cdot B[J, K_Y] \to C[I_X, K_Y]$ 可以在没有任何通信的情况下进行乘法, 因为收缩维度 (J, 我们实际求和的维度) 是未分片的. 然而, 如果我们希望输出是未分片的 (即 $A[I_X, J] \cdot B[J, K_Y] \to C[I, K]$), 我们就需要将 $A$ 或 $C$ 复制到每个设备 (使用 AllGather). 这两种选择有不同的通信成本, 所以我们需要计算这个成本并选择最低的一个.
为了理解这一点, 回忆一下“块矩阵”的概念可能会有所帮助, 即一个嵌套的矩阵的矩阵:
\[\begin{equation} \begin{pmatrix} a_{00} & a_{01} & a_{02} & a_{03} \\ a_{10} & a_{11} & a_{12} & a_{13} \\ a_{20} & a_{21} & a_{22} & a_{23} \\ a_{30} & a_{31} & a_{32} & a_{33} \end{pmatrix} = \left( \begin{matrix} \begin{bmatrix} a_{00} & a_{01} \\ a_{10} & a_{11} \end{bmatrix} \\ \begin{bmatrix} a_{20} & a_{21} \\ a_{30} & a_{31} \end{bmatrix} \end{matrix} \begin{matrix} \begin{bmatrix} a_{02} & a_{03} \\ a_{12} & a_{13} \end{bmatrix} \\ \begin{bmatrix} a_{22} & a_{23} \\ a_{32} & a_{33} \end{bmatrix} \end{matrix} \right) = \begin{pmatrix} \mathbf{A_{00}} & \mathbf{A_{01}} \\ \mathbf{A_{10}} & \mathbf{A_{11}} \end{pmatrix} \end{equation}\]矩阵乘法有一个很好的性质, 即当矩阵乘数用块来表示时, 乘积可以用块矩阵乘法来表示, 遵循标准规则:
\[\begin{equation} \begin{pmatrix} A_{00} & A_{01} \\ A_{10} & A_{11} \end{pmatrix} \cdot \begin{pmatrix} B_{00} & B_{01} \\ B_{10} & B_{11} \end{pmatrix} = \begin{pmatrix} A_{00}B_{00} + A_{01}B_{10} & A_{00}B_{01} + A_{01}B_{11} \\ A_{10}B_{00} + A_{11}B_{10} & A_{10}B_{01} + A_{11}B_{11} \end{pmatrix} \end{equation}\]这意味着实现分布式矩阵乘法归结为在网络上移动这些分片块, 对这些块执行本地矩阵乘法, 并对它们的结果求和. 那么问题是添加什么通信, 以及它的成本是多少.
方便的是, 我们可以将所有可能的分片归结为大约 4 种需要考虑的情况, 每种情况都有一个规则, 说明我们需要添加什么通信
你可以将这些视为只需要遵循的规则, 但理解这些规则为什么成立以及它们的成本是多少也很有价值. 我们现在将详细介绍每一个.
引理: 当乘以分片矩阵时, 计算是有效的, 并且输出遵循输入的分片, 除非收缩维度被分片或两个矩阵都沿同一轴分片. 例如, 这可以正常工作
\[\begin{equation*} \mathbf{A}[I_X, J] \cdot \mathbf{B}[J, K_Y] \rightarrow \mathbf{C}[I_X, K_Y] \end{equation*}\]没有任何通信, 并且得到一个跨 X 和 Y 硬件维度分片的张量. 试着想想为什么会这样. 基本上, 计算与分片无关, 因为每个批处理条目都有一些本地的收缩轴块, 它可以乘以和归约. 任何这些情况都可以正常工作, 并遵循这个规则:
\[\begin{align*} \mathbf{A}[I, J] \cdot \mathbf{B}[J, K] \rightarrow &\ \mathbf{C}[I, K] \\ \mathbf{A}[I_X, J] \cdot \mathbf{B}[J, K] \rightarrow &\ \mathbf{C}[I_X, K]\\ \mathbf{A}[I, J] \cdot \mathbf{B}[J, K_Y] \rightarrow &\ \mathbf{C}[I, K_Y]\\ \mathbf{A}[I_X, J] \cdot \mathbf{B}[J, K_Y] \rightarrow &\ \mathbf{C}[I_X, K_Y] \end{align*}\]因为 A 和 B 都没有分片的收缩维度 J, 我们可以简单地执行输入的本地块矩阵乘法, 结果将已经根据所需的输出分片进行了分片. 当两个乘数都有沿同一轴分片的非收缩维度时, 这不再成立 (有关详细信息, 请参见无效分片部分).
让我们考虑当一个输入 A 沿收缩 J 维度分片, 而 B 完全复制时该怎么做:
\[\mathbf{A}[I, J_X] \cdot \mathbf{B}[J, K] \rightarrow \mathbf{C}[I, K]\]我们不能简单地将 A 和 B 的本地块相乘, 因为我们需要对 A 的完整收缩维度求和, 该维度分布在 X 轴上. 通常, 我们首先“AllGather” A 的分片, 以便每个设备都有一个完整的副本, 然后才与 B 相乘:
\[\textbf{AllGather}_X[I, J_X] \rightarrow \mathbf{A}[I, J]\] \[\mathbf{A}[I, J] \cdot \mathbf{B}[J, K] \rightarrow \mathbf{C}[I, K]\]这样, 实际的乘法就可以在每个设备上完全完成.
要点: 当乘以其中一个矩阵沿收缩维度分片的矩阵时, 我们通常先对其进行 AllGather, 以便收缩不再分片, 然后进行本地矩阵乘法.
请注意, 当 B 也没有沿 X 分片时, 我们也可以进行本地部分矩阵乘法, 然后对分片的部分和求和 (或 AllReduce), 在某些情况下这可能更快. 请参见问题 4 下面.
什么是 AllGather? AllGather 是我们将要讨论的第一个核心 MPI 通信原语. AllGather 移除沿一个轴的分片, 并将分布在设备上的分片重新组装到该轴上的每个设备上. 使用上面的符号, AllGather 从一组轴中移除一个下标, 例如
\[\textbf{AllGather}_{XY}(A[I_{XY}, J]) \rightarrow A[I, J]\]我们不必移除给定维度的所有下标, 例如 \(A[I_{XY}, J] \rightarrow A[I_Y, J]\) 也是一个 AllGather, 只是只在一个轴上进行. 另请注意, 我们也可能希望使用 AllGather 来移除非收缩维度分片, 例如在矩阵乘法中:
\[A[I_X, J] \cdot B[J, K] \rightarrow C[I, K]\]我们可以先对 A 进行 AllGather 以移除输入分片, 或者我们可以进行分片矩阵乘法, 然后对结果 C 进行 AllGather.
AllGather 实际上是如何执行的? 为了在一个 TPU 轴 (一个环) 周围执行一维 AllGather, 我们基本上让每个 TPU 将其分片在一个环周围传递, 直到每个设备都有一个副本.
我们可以单向或双向进行 AllGather (上图显示了双向). 如果我们单向进行, 每个 TPU 会在环周围发送大小为 $\text{bytes} / N$ 的块, 共 $N - 1$ 次跳跃. 如果我们双向进行, 我们有 $\lceil \frac{N}{2} \rceil$ 次跳跃, 大小为 $2 \cdot \text{bytes} / N$.
这需要多长时间? 让我们以双向 AllGather 为例, 计算它需要多长时间. 设 \(V\) 为数组中的字节数, $X$ 为收缩维度上的分片数. 那么从上图中, 每个跳跃在每个方向上发送 $V / \lvert X\rvert$ 字节, 所以每个跳跃需要
\[T_{hop} = \frac{2 \cdot V}{X \cdot W_\text{ici}}\]其中 $W_\text{ici}$ 是双向 ICI 带宽.
请注意, 这不依赖于 $X$! 这有点惊人, 因为这意味着即使我们的 TPU 只是本地连接的, 连接的局部性也不重要. 我们只是受每个链接速度的瓶颈.
要点: 当在受吞吐量限制的情况下执行 AllGather (或 ReduceScatter 或 AllReduce) 时, 实际的通信时间仅取决于数组的大小和可用带宽, 而不取决于我们的数组分片的设备数量!
关于 ICI 延迟的说明: 无论数据量大小, 每个 ICI 链接上的每次跳跃都有一些固有的开销. 这通常在 1us 左右. 这意味着当我们的数组 \(A\) 非常小并且每次跳跃的时间少于 1us 时, 我们可以进入一个“受延迟限制”的状态, 此时计算确实依赖于 $X$.
设 \(T_\text{min}\) 为单次跳跃的最小时间. 那么
\[T_{hop} = \max \left[ T_{min}, \frac{2 \cdot V}{X \cdot W_\text{ici}} \right]\] \[T_{total} = \max \left[ \frac{T_{min} \cdot X}{2}, \frac{V}{W_\text{ici}} \right]\]因为我们执行 $X / 2$ 次跳跃. 对于大型归约或收集, 我们完全受带宽限制. 我们发送的数据量如此之大, 以至于每次跳跃的开销基本上可以忽略不计. 但对于小数组 (例如, 从模型中采样时), 这不可忽略, ICI 带宽也不相关. 我们完全受延迟限制. 另一种说法是, 给定一个特定的 TPU, 例如具有 4.5e10 单向 ICI 带宽的 TPU v5e, 发送任何小于 4.5e10 * 1e-6 = 45kB 的缓冲区都将受延迟限制.
这是一个在 TPU v5e 8x16 切片上 AllGather 带宽的实证测量. 数组在 16 轴上分片, 因此它有一个完整的双向环.
请注意, 我们只达到了声称峰值带宽 (4.5e10) 的约 95%, 并且我们在约 10MB 时达到了这个峰值, 当 16 路分片时, 每个设备约 500kB (旁注: 这比 GPU 好得多).
当我们跨多个轴进行 AllGather 时会发生什么? 当我们跨多个轴进行收集时, 我们有多个 ICI 维度可以进行收集. 例如, AllGatherXY([B, DXY]) 在两个硬件网格轴上操作. 这将可用带宽增加了 $N_\text{axes}$ 倍.
一般来说, 我们有
\[T_{total} = \max \left[ \frac{T_{min} \cdot \sum_{i} |X_i|}{2}, \frac{V}{W_\text{ici} \cdot N_\text{axes}} \right]\]其中 \(\sum_i \lvert X_i \rvert / 2\) 是 TPU 网格中最长路径的长度.
小测验 2 [AllGather 时间]: 使用第 2 部分中的数字, 在具有 2D 网格 {'X': 8, 'Y': 4} 的 TPUv5e 上执行 AllGatherY([EY, F]) → [E, F] 需要多长时间, 其中 \(E = 2048\), \(F = 8192\) (bfloat16)? 如果 \(E=256, F=256\) 呢?
答案: 让我们从计算一些基本量开始:
1) TPU v5e 的每个轴都有 4.5e10 字节/秒的单向 ICI 带宽. 2) 在 bfloat16 中, 对于 (a), 我们有 $A[E_Y, F]$, 所以每个设备持有一个形状为 bfloat16[512, 8192] 的数组, 大小为 512 * 8192 * 2 = 8.4MB. 总数组大小为 2048 * 8192 * 2 = 34MB.
对于第 (1) 部分, 我们可以使用上面的公式. 由于我们是在一个轴上执行 AllGather, 我们有 $T_{\text{comms}} = \text{34e6} / \text{9e10} = \text{377us}$. 为了检查我们是否不受延迟限制, 我们知道在一个大小为 4 的轴上, 我们最多有 3 次跳跃, 所以我们的延迟限制大约是 3us, 所以我们离得不近. 然而, TPU v5e 只有在一个轴的大小为 16 时才有环绕连接, 所以在这里我们实际上无法进行完全双向的 AllGather. 我们需要 3 次跳跃才能让数据从边缘到达另一边, 所以理论上我们更像是 $T_{\text{comms}} = 3 * \text{8.4e6} / \text{4.5e10} = 560\mu s$. 这里 是来自这个 Colab 的实际配置文件, 显示为 $680 \mu s$, 这是合理的, 因为我们可能没有获得 100% 的理论带宽! 对于第 (2) 部分, 每个分片的大小为 64 * 256 * 2 = 32kB. 32e3 / 4.5e10 = 0.7us, 所以我们受延迟限制. 由于我们有 3 次跳跃, 这大约需要 3 * 1us = 3us. 在实践中, 它更接近 8us.
第三种基本情况是当两个乘数都在它们的收缩维度上分片, 并且沿同一网格轴:
\[\textbf{A}[I, J_X] \cdot \textbf{B}[J_X, K] \rightarrow C[I, K]\]在这种情况下, 本地分片块矩阵乘法至少是可能执行的, 因为它们将共享相同的收缩索引集. 但是每个乘积只代表最终期望乘积的部分和, 并且沿 X 维度的每个设备将剩下这个最终期望乘积的不同部分和. 这种情况非常普遍, 以至于我们扩展了我们的符号来明确标记这种情况:
\[\textbf{A}[I, J_X] \cdot_\text{LOCAL} \textbf{B}[J_X, K] \rightarrow C[I, K] \{\ U_X \}\]符号 { UX } 读作“沿 X 网格轴未归约”, 指的是该操作在某种意义上是“未完成”的状态, 因为它只有在最终求和后才算完成. $\cdot_\text{LOCAL}$ 语法意味着我们执行本地求和, 但将结果保持未归约状态.
这可以看作是关于矩阵乘法和外积的以下结果:
\[A \cdot B = \sum_{i=1}^{P} \underbrace{A_{:,i} \otimes B_{i,:}}_{\in \mathbb{R}^{n \times m}}\]其中 ⊗ 是外积. 因此, 如果轴 X 上的 TPU i 具有 A 的第 i 列和 B 的第 i 行, 我们可以进行本地矩阵乘法以获得 \(A_{:,i} \otimes B_{i,:} \in \mathbb{R}_{n\times m}\). 这个矩阵的每个条目都包含 A • B 在该条目处的和的第 i 项. 我们仍然需要对 P 进行求和, 我们在网格轴 X 上对其进行了分片, 以获得完整的 A • B. 如果我们按块 (即分片) 写出 A 和 B, 然后对结果的每个分片求和, 这种方式同样有效.
我们可以使用跨 X 轴的完整 AllReduce 来解决这个问题:
\[\begin{align*} A[I, J_X] \cdot_\text{LOCAL} B[J_X, K] \rightarrow &\ C[I, K] \{ U_X \} \\ \textbf{AllReduce}_X C[I, K] \{ U_X \} \rightarrow &\ C[I, K] \end{align*}\]AllReduce 移除部分和, 导致沿该轴的每个设备都具有相同的完全求和的值. AllReduce 是我们将在本节中讨论的几个关键通信中的第二个, 第一个是 AllGather, 其他是 ReduceScatter 和 AllToAll. AllReduce 接受一个具有未归约 (部分求和) 轴的数组, 并通过在该未归约轴周围传递这些分片并累积结果来执行求和. 签名为
\[\textbf{AllReduce}_Y A[I_X, J] \{U_Y\} \rightarrow A[I_X, J]\]这意味着它只是移除了 $\{U_Y\}$ 后缀, 但在其他方面保持结果不变.
AllReduce 的成本是多少? 一个关于 AllReduce 如何执行的心智模型是, 每个设备都将其分片发送给其邻居, 并将接收到的所有分片相加. 显然, 这比 AllGather 更昂贵, 因为每个“分片”都与完整数组具有相同的形状. 通常, AllReduce 的成本是 AllGather 的两倍. 一种看待这个问题的方式是注意到 AllReduce 可以表示为另外两个原语的组合: 一个 ReduceScatter 和一个 AllGather. 与 AllReduce 一样, ReduceScatter 解析数组上的部分和, 但结果是沿给定维度“分散”或分区的输出. AllGather 收集所有这些片段, 并沿该物理轴“取消分区/取消分片/复制”逻辑轴.
\[\begin{align*} \textbf{ReduceScatter}_{Y,J} : A[I_X,J] \{U_Y\} \rightarrow &\ A[I_X, J_Y] \\ \textbf{AllGather}_Y : A[I_X, J_Y] \rightarrow &\ A[I_X, J] \end{align*}\]那么 ReduceScatter 呢? 正如 AllReduce 移除一个下标 ($F_Y \to F$ 上面), ReduceScatter 对一个未归约/部分求和的数组求和, 然后将另一个逻辑轴沿同一网格轴分散 (分片). $[F]\{U_Y\} \to [F_Y]$. 动画显示了这是如何完成的: 请注意, 它与 AllGather 非常相似, 但我们不是保留每个分片, 而是将它们相加. 因此, 它的延迟大致相同, 不包括执行归约所需的时间.
每个跳跃的通信时间就是每个分片的字节数 $V / Y$ 除以带宽 $W_\text{ici}$, 就像 AllGather 一样, 所以我们有
\[T_{\text{comms per AllGather or ReduceScatter}} = \frac{V}{W_\text{ici}}\] \[T_{\text{comms per AllReduce}} = 2 \cdot \frac{V}{W_\text{ici}}\]其中 \(W_\text{ici}\) 是双向带宽, 只要我们有一个完整的环可以进行归约.
每个网格维度在对张量进行分片时最多只能出现一次. 执行上述规则有时会导致违反此规则的情况, 例如:
\[A[I_X, J] \cdot B[J, K_X] \rightarrow C[I_X, K_X]\]这是无效的, 因为沿维度 X 的给定分片, 比如说 i, 将具有 C 的 (i, i) 分片, 即一个对角线条目. 那么, 在所有分片中没有足够的信息来恢复除对角线条目之外的任何内容, 所以我们不能允许这种分片.
解决这个问题的方法是对某些维度进行 AllGather. 在这里我们有两个选择:
\[\begin{align*} \textbf{AllGather}_X A[I_X, J] \rightarrow &\ A[I, J] \\ A[I, J] \cdot B[J, K_X] \rightarrow &\ C[I, K_X] \end{align*}\]或
\[\begin{align*} \textbf{AllGather}_X B[J, K_X] \rightarrow &\ B[J, K] \\ A[I_X, J] \cdot B[J, K] \rightarrow &\ C[I_X, K] \end{align*}\]在任何一种情况下, 结果在其形状中只会提到 X 一次. 我们选择哪一个将取决于后续操作需要什么样的分片.
前面的 4 种情况介绍了几种用于执行分片矩阵乘法的“核心通信原语”:
还有一种核心通信原语需要提及, 它出现在专家混合 (MoE) 模型和其他计算中: AllToAll.
最后一个基本的集合操作, 在考虑分片矩阵乘法时不会自然出现, 但在实践中经常出现, 是 AllToAll 集合, 或者更准确地说, 是分片转置或重分片操作的特殊情况. 例如
\[\textbf{AllToAll}_{X, J} A[I_X, J] \rightarrow A[I, J_X]\]AllToAll 通常需要重新排列分片计算的不同区域之间的分片布局, 这些区域没有兼容的布局方案. 在考虑分片专家混合模型时, 它们会自然出现. 你可以将 AllToAll 视为将一个下标从一个轴移动到另一个轴. 因为 all to all 不需要复制每个分片的所有数据到环上, 它实际上比 AllGather 便宜 (便宜 1/4)
如果我们推广到 ND AllToAll, 在 AxBxC 网格上, 一个 V 字节数组的总成本是
\[T_\text{comms per AllToAll} = \frac{V \cdot \max(A, B, C, ...)}{4 \cdot N \cdot W_\text{ici}}\]其中, 像往常一样, $W_\text{ici}$ 是双向 ICI 带宽. 对于 1D 网格, 这简化为 $V / (4 \cdot W_\text{ici})$, 这是 AllReduce 成本的 1/4. 在 2D 中, 成本实际上随着最小轴的大小而降低.
旁注: 如果你想要一个粗略的推导, 从一个 1D 环面 $\mathbb{Z} / N\mathbb{Z}$ 开始. 如果我们随机选择一个源节点和目标节点, 它们平均相距 N / 4 跳, 这给了我们一个成本 $(V \cdot N) / (4 * N)$. 现在如果我们考虑一个 ND 环面, 每个轴基本上是独立的. 每个节点有 $1 / Z$ 字节, 平均需要将其数据跳跃 $\max(A, B, C, …) / 4$ 跳.
ReduceScatter 是一个比它最初看起来更基本的操作, 因为它实际上是 AllGather 的导数, 反之亦然. 即, 如果在前向传播中我们有:
\[\textbf{AllGather}_X A[I_X] \rightarrow A[I]\]然后我们对反向模式导数 A’ (通常在每个分片上都不同) 进行 ReduceScatter, 以推导出分片的 A’:
\[\textbf{ReduceScatter}_X A'[I] \{ U_X \} \rightarrow A'[I_X]\]同样, 在前向传播中 \(\text{ReduceScatter}_X(A[I] \{U_X\}) \to A[I_X]\) 意味着在后向传播中 \(\text{AllGather}_{X}(A'[I_X]) \to A'[I]\).
将 AllReduce 转换为 AllGather 和 ReduceScatter 还有一个方便的特性, 即我们可以将最终的 AllGather 推迟到稍后的某个时刻. 我们通常不希望支付在设备上复制完整矩阵乘积的成本. 相反, 我们希望即使在这种组合两个具有分片收缩维度的乘数的情况下, 也能保持分片状态:
\[A[I, J_X] \cdot B[J_X, K] \rightarrow C[I, K_X]\]在这种情况下, 我们也可以执行 ReduceScatter 而不是 AllReduce, 然后可以选择在稍后的某个时间执行 AllGather, 即
\[\begin{align*} A[I, J_X] \cdot_{LOCAL} B[J_X, K] \rightarrow &\ C[I, K] \{ U_X \} \\ \textbf{ReduceScatter}_{X,K} C[I, K] \{ U_X \} \rightarrow &\ C[I, K_X] \end{align*}\]请注意, ReduceScatter 引入了一个分片维度, 因此在这种情况下, 沿着 I 或 K 命名维度进行分片具有天然的自由度. 在使用 ReduceScatter 时, 我们通常需要选择哪个命名维度来引入新的分片 (尽管选择通常由更大的建模上下文强制). 这就是为什么我们使用语法 ReduceScatterX,K 来指定要分片的轴.
使用分片数组的算术运算与使用未分片数组的算术运算完全相同, 除非你沿分片轴执行收缩. 在这种情况下, 我们必须引入一些通信. 我们考虑四种情况:
| 操作 | 描述 | 语法 | 运行时间 |
|---|---|---|---|
| AllGather | 收集分片数组沿一个轴的所有分片, 移除一个下标. | $[A_X, B] \to [A, B]$ | 字节 / (双向 ICI 带宽 * num_axes) |
| ReduceScatter | 对一个部分求和的数组沿一个轴求和, 并将其沿另一个轴分片 (添加一个下标). | $[A, B] \{U_X\} \to [A_X, B]$ | 与 AllGather 相同 |
| AllReduce | 对一个部分求和的数组沿一个轴求和. 移除一个 { Ux }. 结合了 AllGather 和 ReduceScatter. | $[A_X, B]\{U_Y\} \to [A_X, B]$ | 2 * AllGather |
| AllToAll | 收集 (复制) 一个轴, 并将另一个维度沿同一轴分片. | $[A, B_X] \to [A_X, B]$ | 双向环的 AllGather / 4 |
这里有一些基于本节内容的有启发性的问题. 我们目前不会包含所有答案, 但我们会尽可能多地写出答案.
问题 1 [复制分片]: 一个数组被分片为 $A[I_X, J, K, \ldots]$ (即, 仅在 $X$ 上分片), 网格为 Mesh({'X': 4, 'Y': 8, 'Z': 2}). 所有芯片上 $A$ 占用的总字节数与数组一个副本的大小之比是多少?
我们的数组只在 X 上分片, 大小为 4, 所以实际上每个分片的大小为 $[I / 4, J, K, \ldots] = \text{sizeof}(A) / 4$. 由于我们的数组在 Y 和 Z 上复制, 总大小为 $Y \cdot Z \cdot \text{sizeof}(A)$, 所以总大小与单个芯片大小之比为 $Y \cdot Z \cdot \text{sizeof}(A) / \text{sizeof}(A) = 16$.
问题 2 [AllGather 延迟]: 在 TPUv4p 4x4x4 切片上, 网格为 Mesh({'X': 4, 'Y': 4, 'Z': 4}), 如果 $B=1024$ 且 $D=4096$ (bfloat16), $\text{AllGather}_X([B_X, D_Y])$ 需要多长时间? \(\text{AllGather}_{XY}([B_X, D_Y])\) 呢? \(\text{AllReduce}_Z([B_X, D_Y] \{U_Z \})\) 呢?
我们在所有轴上都有一个环绕链接, 因为我们有一个完整的 4x4x4 立方体, 所以我们有 9e10 的双向带宽可用.
因为我们只是在一个轴上收集, 而另一个轴是分片的, 所以我们实际上是在 1 个轴上收集 $2BD / Y$ 字节. 如果你只考虑 Y 轴上的一个分片, X 轴上的 AllGather 看起来就像一个未分片的 AllGather, 字节数为 1 / Y. 由于我们的 TPU v4p 的 ICI 带宽是双向 9e10 字节/秒, 这将需要 $2BD / (\text{9e10} \cdot Y) = 2 \cdot 1024 \cdot 4096 / (\text{9e10} \cdot 4) = 23 \mu s$.
我们的带宽是以前的两倍, 但我们正在 AllGather 整个数组, 所以 T = 2BD / (2 * W) = 2*1024*4096 / (2 * 9e10) = 46us. 这远低于 4us 的延迟限制 (每跳 1us), 所以我们没问题.
AllReduce 的成本是 AllGather 的两倍. 每个分片的大小为 $2BD / (X * Y)$, 所以成本约为 $4BD / (X * Y * W)$, 或大约 4 * 1024 * 4096 / (16 * 9e10) = 11.6us.
问题 3 [受延迟限制的 AllGather]: 假设我们正在执行一个 $\text{AllGather}_X([B_X])$, 但 $B$ 非常小 (比如 128). 在 TPUv4p 4x4x4 切片上, 网格为 Mesh({'X': 4, 'Y': 4, 'Z': 4}), bfloat16 格式, 这需要多长时间? 提示: 你可能受延迟限制.
我们的 bfloat16 数组总共只使用 256 字节, 每个设备只有 64 字节. 由于我们在 TPU v4p 上有一个大小为 4 的轴, 我们有一个环绕链接, 所以我们可以双向发送数组. 单向带宽为 4.5e10, 每次跳跃大约需要 64 / 4.5e10 ~ 0, 所以我们肯定受延迟限制. 计算跳跃次数, 我们只需 2 次跳跃就可以完成整个收集, 所以大约 2us 是一个很好的估计.
问题 4 [矩阵乘法策略]: 为了执行 $X[B, D] \cdot_D Y[D_X, F] \to Z[B, F]$, 在本节中, 我们告诉你执行 $\text{AllGather}_X(Y[D_X, F])$ 并乘以完全复制的矩阵 (情况 2, 策略 1). 相反, 你可以像 $X[B, D_X] \cdot_D Y[D_X, F] \to Z[B, F] \{U_X\}$ (情况 4, 策略 2) 那样乘以本地分片, 然后 $\text{AllReduce}_X(Z[B, F] \{ U_X\})$. 这两种策略各执行多少 FLOPs 和通信? 哪种更好, 为什么?
让我们从我们的基线 (策略 1) 开始. 正如我们已经展示的, AllGather 的成本是 $2DF / W_\text{ici}$. 一旦我们有了完全复制的数组, 总计算时间是 $2BDF / C$ (其中 $C$ 是我们的加速器 FLOPs/s, 因为每个 TPU 执行相同的 FLOPs). 所以我们有
\[T_\text{total (策略 1)} = \max\left(\frac{2BDF}{C}, \frac{2DF}{W_\text{ici}}\right)\]相比之下, 新策略 (策略 2) 对 $2BF$ 字节进行 AllReduce, 成本为 $4BF / W_\text{ici}$, 但执行的 FLOPs 少 $1 / X$ (因为计算是分片的). 这意味着我们执行 $2\cdot B\cdot D\cdot F / X$ FLOPs, 并且由此产生的 AllReduce 通信 \(2 \cdot 2 \cdot B \cdot F\) 字节 (bfloat16). 因此, 策略 2 (没有 AllGather, 只是稍后进行 AllReduce) 的总时间大约是
\[T_\text{total} = \max\left(\frac{2BDF}{X \cdot C}, \frac{4BF}{W_\text{ici}}\right)\]问题是: 哪个更大? 当 $D / (X \cdot C) > 2 / W_\text{ici}$, 或当 $D / 2X > C / W_\text{ici} \approx 2550 \rightarrow X < D / (2 * 2550)$ 时, 策略 (2) 受计算限制. 我们可能合理地期望 $D \approx 8k$, 所以这意味着大约 $X < 2$, 这是不可能的 – 因此我们基本上总是使用策略 2 受通信限制. 使用基线 (策略 1), 当 \(B < C / W_\text{ici} = 2550\) 时, 我们受通信限制, 这通常是正确的, 但并非总是如此.
所以如果 $B < 2550$, 我们在两种情况下都受通信限制, 我们有
\[T_\text{comms for Strategy 2} < T_\text{comms for Strategy 1} \Leftrightarrow \frac{4BF}{W_\text{ici}} < \frac{2DF}{W_\text{ici}}\]当 $D > 2B$ 且 $2B < 5100$ 时, 这是正确的. 这通常是正确的, 所以如果我们的批量很小, 策略 2 有时会更好. 当我们的批量很大 ($B > 2550$) 时, 我们有
\[T_\text{comms for Strategy 2} < T_\text{math for Strategy 1} \Leftrightarrow \frac{4BF}{W_\text{ici}} < \frac{2BDF}{C}\]当 $2 / W_\text{ici} < D / C$, 或当 $D > 2 * 2550 = 5100$ 时, 这是正确的, 这对于大型模型通常是正确的. 所以这种替代策略对于大型模型通常更好, 除非 $D$ 很小.
我们为什么不总是这样做? 嗯, 在实践中我们有时可能会这样做, 但通常很少有一个矩阵乘法的输入的收缩维度在一个轴上分片, 而另一个输入没有在该轴上分片. 例如, 如果我们正在做 FSDP (在第 5 节中解释), 我们将在数据维度上分片我们的参数, 但我们的激活也将在数据维度上分片. 所以从这个意义上说, 这种情况不常出现.
问题 5 [最小延迟]: 假设我想在 TPUv5p 4x4x4 上以尽可能低的延迟进行矩阵乘法 $A[B, D] \cdot_D B[D, F] \to C[B, F]$. 我的输入应该如何分片? 总的 FLOPs 和通信时间是多少?
问题 6: 假设我们想在 TPUv5e 4x4 上执行 $A[I_X, J_Y] \cdot_J B[J_Y, K] \to C[I_X, K]$. 我们执行什么通信? 通信与计算花费的时间各是多少?
问题 7: 一个典型的 Transformer 块有两个矩阵 $B[D, F]$ 和 $C[F, D]$, 其中 $F \gg D$. 批量大小为 B, 整个块是 \(C \cdot B \cdot x\), 其中 \(x[B, D]\). 让我们选择 \(D=8192\), \(F=32768\), 和 \(B=128\), 并假设一切都是 bfloat16. 假设我们正在一个 TPUv5e 2x2 切片上运行, 但假设每个 TPU 只有 300MB 的可用内存. B, C 和输出应该如何分片以保持在内存限制以下, 同时最小化总时间? 通信和 FLOPs 花费的时间各是多少?
问题 8 [挑战]: 使用上面的简短代码片段作为模板, 分配一个分片数组, 并使用 pmap 或 shard_map 对 4 种主要通信原语 (AllGather, AllReduce, ReduceScatter 和 AllToAll) 进行基准测试. 你将需要使用 jax.lax.all_gather, jax.lax.psum, jax.lax.psum_scatter 和 jax.lax.all_to_all. 你理解这些函数的语义吗? 它们需要多长时间?
问题 9 [分片矩阵乘法的另一种策略?]: 上面 我们声称, 当只有一个矩阵乘法的输入沿其收缩维度分片时, 我们应该 AllGather 分片矩阵并在本地执行收缩. 你可能想到的另一种策略是执行分片矩阵乘法, 然后对结果进行 AllReduce (就好像两个输入都沿收缩维度分片一样), 即 $A[I, J_X] *_J B[J, K] \to C[I, K]$ 通过
回答以下问题:
M/K.问题 10: AllToAll 的乐趣: 在上表中, 注意到执行 AllToAll 的时间比执行 AllGather 或 ReduceScatter 的时间低 4 倍 (在我们受吞吐量限制的情况下). 在这个问题中, 我们将看到这 4 倍的来源, 并看到如果我们只有单向 ICI 链接, 而不是双向 ICI 链接, 这个因素会如何改变.
(1) 解决方案: 过程很简单: 在算法的每个步骤中, 每个设备都会将一个单分片“条带”的矩阵 (总共 \(\frac{N}{D} \times N\) 个元素) 发送给其最近的邻居. 这会发生 \(D-1\) 次, 因为每个分片都需要被通信到除其起始设备之外的所有设备. 所以总共, 每个设备传输 \(\frac{N^2(D-1)}{D}\) 个标量, 即流经单个 ICI 链接.
答案: \(N^2 (1-\frac{1}{D})\), 或者当 \(D >> 1\) 时, 简单地为 \(N^2\).
(2) 解决方案: 从通信的角度来看, AllToAll 和 AllGather 之间的关键区别在于, 在 AllToAll 中, 驻留在特定设备上的整个分片不需要被通信到每个其他设备. 想象一下存储在特定设备 (称之为设备 0) 上的分片是 \([A, B, C, D]\) (这里 A,B,C,D 是矩阵, 我们正在想象一个有 4 个设备的环来说明). 现在矩阵 \(A\) 不需要被通信到任何地方, 矩阵 \(B\) 需要最终到达设备 1; 矩阵 \(C\) 最终到达设备 2; 矩阵 \(D\) 最终到达设备 3. 所以在算法的第一步, 我们将 \(B\), \(C\), 和 \(D\) 发送到设备 1; 在下一步, 设备 1 将 \(C\) 和 \(D\) 继续发送到设备 2; 在最后一步, 设备 2 只将 \(D\) 发送到设备 3. 在这种情况下传输的参数总数是 \((\text{A/B/C/D 的大小}) * (3 + 2 + 1)\). A/B/C/D 的大小是 (在一般情况下) \(\frac{N^2}{D^2}\), 并且同样在一般情况下, \((3 + 2 + 1)\) 项变成 \(((D-1) + (D-2) + … + 1)\), 或 \(\frac{(D)(D-1)}{2}\). 所以单个 ICI 链接上传输的总字节数是 \(\frac{N^2(D-1)}{D \times 2}\).
答案: \(\frac{N^2}{2}(1-\frac{1}{D})\), 或者当 \(D >> 1\) 时, 简单地为 \(\frac{N^2}{2}\).
(3) 解决方案: 因子就是 \(\frac{1}{2}\), 即在单向环形拓扑上, AllToAll 的成本是 all-gather/ReduceScatter 的一半. 回顾上面的推导, 这最终来自于这样一个事实, 即在 all-gather 的情况下, 我们每次传输相同大小的块 \((D-1)\) 次, 即我们正在做求和 \(\text{小块大小} * (D + D + D + … + D)\), 而在 AllToAll 的情况下, 我们正在做求和 \(\text{小块大小} * (D + D-1 + D-2 + … + 1)\). 因此, 2 的因子本质上来自于这样一个事实, 即 \(1 + 2 + \ldots + n = n(n+1)/2\).
(4) 解决方案: 现在任何一个链接必须承载的总标量数量减少了 2 倍, 因为在双向环中, 每个“分片条带”可以同时双向发送.
(5) 解决方案: 在这种情况下, 与单向情况相比, 我们赢得了 4 倍的优势. 这最容易通过考虑单个分片条带中每个大小为 (N2/D2) 的块的命运来看出, 比如说源自设备 0 的那个. 与 (单向情况) 发送其中一个块距离 D-1, 另一个块距离 D - 2, 等等一直到 1 不同, 我们现在将条带分成向右或向左移动的块, 最大移动距离为 ceil(D/2). 所以相应的和现在变成 \(D/2 + D/2 - 1 + D/2 - 2 + … = D/2 \cdot (D/2+1)/2\), 或者在 \(D\) 很大时为 \(D^2/8\). 与单向情况下的 \(D^2/2\) 相比, 我们看到我们赢得了 4 倍的优势.
(6) 解决方案: 在单向环中, 我们看到 AllToAll 时间已经比 all-gather 时间快两倍; 这来自于我们不需要将我们的完整条带发送到每个设备的事实. 然后, 当我们添加双向性时, 我们看到对于 AllToAll 来说是 4 倍的胜利, 而对于 all-gathers 来说只有 2 倍的胜利. 将这些比率放在一起, 我们得到了我们所寻求的 4 倍因子.