Part 7 of How To Scale Your Model 中文版 (Part 6: Training LLaMA | Part 8: Serving LLaMA)
对 Transformer 进行推理可能与训练非常不同. 部分原因在于推理增加了一个需要考虑的新因素: 延迟. 在本节中, 我们将从模型中采样一个新 token 开始, 一直到作为推理引擎的一部分, 在多个加速器切片上高效扩展大型 Transformer.
所以你已经训练了一个 Transformer, 并且你想用它来生成一些新的序列. 归根结底, 基准分数上升和损失曲线下降只是衡量一旦付诸实践是否会发生有趣事情的代理指标!
采样在概念上很简单. 我们输入一个序列, 我们最喜欢的 Transformer 会吐出 \(\log p(\text{下一个 token}_i \vert \text{之前的 token})\), 即所有可能的下一个 token 的对数概率. 我们可以从这个分布中采样并获得一个新的 token. 附加这个 token 并重复这个过程, 我们就得到了一个作为提示延续的 token 序列.
我们刚刚描述了 Transformer 采样的朴素实现, 虽然它有效, 但我们在实践中从不这样做, 因为我们每次生成一个 token 都会重新处理整个序列. 这个算法在 FFW 上是 \(O(n^2)\), 在注意力机制上是 \(O(n^3)\), 以生成 \(n\) 个 token!
我们如何避免这种情况? 与其每次都进行完整的前向传播, 事实证明我们可以保存每次前向传播的一些中间激活, 从而避免重新处理之前的 token. 具体来说, 由于给定的 token 在点积注意力期间只关注之前的 token, 我们可以简单地将每个 token 的键和值投影写入一个名为 KV 缓存的新数据结构中. 一旦我们为过去的 token 保存了这些键/值投影, 未来的 token 就可以简单地计算它们的 \(q_i \cdot k_j\) 乘积, 而无需对早期的 token 执行任何新的 FLOPs. 太棒了!
考虑到这一点, 推理有两个关键部分:
<EOS> token 或达到某个最大长度限制.这是一个带 KV 缓存的采样图:
通过使用 KV 缓存进行采样, 我们将生成 $n$ 个 token 的时间复杂度在 FFW 上降低到 \(O(n)\), 在注意力上降低到 \(O(n^2)\), 因为我们从不重新处理之前的 token. 然而, 生成一个序列仍然需要许多次前向传播 —— 这就是当你查询 Gemini 或 ChatGPT 并且结果流式传输给你时发生的事情. 每个 token (通常) 都是对一个巨大模型的单独 (但部分缓存的) Transformer 调用.
我们很快就会看到, 预填充和生成是截然不同的野兽 —— Transformer 推理是伪装成两个任务! 与训练相比, KV 缓存也是一个新颖且重要的复杂性来源.
在继续之前, 值得强调一下推理中一个全新的方面: 延迟. 虽然在训练期间我们只关心吞吐量 (每秒每个芯片处理的总 token 数), 但在推理期间, 我们必须担心我们生成 token 的速度 (包括首个 token 的时间 (TTFT) 和每个 token 的延迟). 例如:
llama.cpp) 只需要以尽可能低的延迟一次为一个用户提供服务, 可能会有严格的硬件限制.最大化硬件利用率仍然至关重要, 有助于降低成本和 TTFT, 但与训练不同, 它不一定在所有情况下都能转化为更好的用户体验. 加速器, 系统和模型架构级别的许多优化都在延迟, 吞吐量, 上下文长度甚至模型质量之间进行权衡.
到目前为止, 我们主要将 Transformer 视为一堆前馈块. 虽然从 FLOPs 和内存的角度来看, 这通常是合理的, 但它不足以正确地建模推理.
在接下来的几节中, 我们将结合预填充和生成来看待这些操作, 并询问什么可能成为我们性能的瓶颈. 在单个加速器内, 我们是受计算限制还是受内存限制? 我们想强调一下, 对于预填充和生成, 答案会有多大的不同.
我们所有的线性操作在概念上都是相同的, 无论它们是在 MLP 块还是在注意力中. 它们的算术强度取决于批量大小. 我们在第 1 节中做过这个数学计算, 但值得重复一下. 让我们看一个 $\text{bf16[B, D]}$ 批次与一个 $\text{bf16[D, F]}$ 矩阵的单个矩阵乘法. 这可能是大的 MLP 块 ($W_\text{in}$ 或 $W_\text{out}$) 或较小的注意力投影之一 ($W_Q$, $W_K$, $W_V$, $W_O$). 为了进行这个矩阵乘法, 我们需要从 HBM 将这两个数组加载到 MXU 中, 进行乘法, 然后将结果写回 HBM. 和以前一样, 我们有:
\[T_\text{math} = \frac{\text{计算 FLOPs}}{\text{加速器 FLOPs/s}} = \frac{2BDF}{\text{加速器 FLOPs/s}}\] \[T_\text{comms} = \frac{\text{通信字节数}}{\text{带宽 字节/s}} = \frac{2BD + 2FD + 2BF}{\text{带宽 字节/s}}\]TPU 或 GPU 可以通过在进行计算时进行加载来重叠这些操作, 因此要受计算限制, 我们需要 \(T_\text{math} \geq T_\text{comms}\), 或者:
\[\frac{2BDF}{2BD + 2DF + 2BF} \geq \frac{\text{加速器 FLOPs/s}}{\text{带宽 字节/s}} \underset{\text{TPU v5e}}{=} \frac{1.97E+14}{8.20E+11} = 240\]其中 RHS 是我们硬件的算术强度. 现在让我们假设 $D$ 和 $F$ 相对于 $B$ 非常大 (通常我们的批次最多为 500, 而 $D$ 和 $F > 10k$), 我们可以通过使用 $\small{2BD + 2DF + 2BF \approxeq 2DF}$ 的事实来简化分母, 这给了我们
\(\begin{align*} \frac{2BDF}{2BD + 2DF + 2BF} \approxeq \frac{2BDF}{2DF} \geq \frac{\text{加速器 FLOPs/s}}{\text{带宽 字节/s}} \underset{\text{TPU v5e}}{=} \frac{1.97E+14}{8.20E+11} \implies B \geq 240 = B_\text{crit} \end{align*}\)$
如果我们对权重进行量化或对矩阵乘法使用较低精度的 FLOPs, 这个临界批量大小可能会改变. 例如, 如果我们将权重 量化为 int8 或 fp8, $B_\text{crit}$ 会减少 2 倍. 如果我们用 int8 或 fp8 进行 FLOPs, $B_\text{crit}$ 会增加 2 倍. 因此, 如果我们令 $\beta = \text{每个参数的位数} / \text{每个激活的位数}$, $\alpha_\text{hbm} = C / W_\text{hbm}$, 我们的临界批量大小实际上是 $B_\text{crit} = \beta \alpha_\text{hbm}$.
要点: Transformer 矩阵乘法是受计算限制的, 当且仅当每个副本的token批量大小大于 $B_\text{crit} = C / W_\text{hbm} \cdot (\text{每个参数的位数} / \text{每个激活的位数}) = \beta \cdot \alpha_\text{hbm}$. 对于 TPU v5e 上的 bf16 激活, 这是 240 个 token. 对于 H100, 大约是 280 个 token.
在训练期间, 我们在所有矩阵乘法中都会有很高的强度, 因为我们在一个非常大的批次上重复使用相同的权重. 这种高算术强度会延续到预填充, 因为用户提示通常有数百甚至数千个 token. 正如我们之前看到的, TPUv5e 的硬件算术强度是 240, 所以如果一个长度超过 240 个 token 的序列被输入到一个在 bf16 下运行的密集模型中, 我们会期望受计算限制, 一切都很好. 比这短的提示在技术上可以批处理在一起以实现更高的利用率, 但这通常不是必需的.
要点: 在预填充期间, 所有矩阵乘法基本上总是受计算限制. 因此, 简单地最大化硬件利用率或 MFU (模型 FLOPs 利用率) 就足以最大化每个芯片的吞吐量 (成本) 和延迟 (以 TTFT 的形式). 除非提示非常短, 否则在每个提示级别进行批处理只会增加延迟, 而对预填充吞吐量的改进很小.
然而, 在生成期间, 对于每个请求, 我们只能一次一个 token 地进行前向传播, 因为步骤之间存在顺序依赖关系! 因此, 我们只能 (容易地) 通过将多个请求批处理在一起, 在批处理维度上进行并行化来获得良好的利用率. 我们稍后会更多地讨论这个问题, 但实际上在不影响延迟的情况下将许多并发请求批处理在一起是很困难的. 因此, 用生成来饱和硬件 FLOPs 要困难得多.
要点: 在生成期间, 总 token 批量大小必须大于 $B_\text{crit}$ 才能在线性/前馈操作上受计算限制 (对于 TPU v5e 上的 bf16 参数为 240). 因为生成是串行地, 逐个 token 地发生的, 这要求我们将多个请求批处理在一起, 这很困难!
值得注意的是这有多大! 生成批量大小为 240 意味着 240 个并发请求同时生成, 以及 240 个单独的密集模型 KV 缓存. 这意味着这在实践中很难实现, 除了一些批量推理设置. 相比之下, 在预填充期间推送超过 240 个 token 是相当常规的, 尽管随着稀疏性的增加需要一些小心.
请注意, 这个确切的数字会因量化和硬件的类型而异. 加速器通常可以在较低的精度下提供更多的算术运算. 例如, 如果我们有 int8 参数但在 bf16 中进行计算, 临界批量大小会降至 120. 使用 int8 激活和 int8 参数, 它会跳回到 240, 因为 TPUv5e 可以提供 400 TOPs/s 的 int8 x int8.
当我们看点积注意力操作时, 事情变得更加复杂, 特别是由于我们必须考虑 KV 缓存. 让我们只看一个具有纯多头注意力的注意力头. 在单个 Flash Attention 融合中, 我们
把它们放在一起, 我们得到:
\[\text{多头注意力算术强度} = \frac{4BSTD}{4BSD + 4BTD} = \frac{ST}{S+T}\]对于预填充, $S=T$ 因为我们正在进行自注意力, 所以这简化为 $T^2 / 2T = T / 2$. 这很棒, 因为这意味着预填充期间注意力的算术强度是 $\Theta(T)$. 这意味着很容易受计算限制. 只要我们的序列长度相当大, 我们就没问题!
但是由于生成的序列维度很小, 并且 $B$ 和 $D$ 维度抵消了, 我们可以做近似:
\[S \gg T = 1 \implies \frac{ST}{S+T} \approx 1\]这很糟糕, 因为这意味着我们无法做任何事情来提高生成期间注意力的算术强度. 我们在加载一个巨大的 KV 缓存的同时只做了少量的 FLOPs. 所以我们在注意力期间基本上总是受内存带宽限制!
要点: 在预填充期间, 对于任何合理的序列长度 (大约 $\gt 480$ 个 token), 注意力通常受计算限制, 而在生成期间, 我们的算术强度低且恒定, 因此我们总是受内存带宽限制.
从概念上讲, 这是为什么? 主要地, 我们在模型的线性部分受计算限制, 因为参数 (内存带宽密集型组件) 被许多批处理项重用. 然而, 每个批处理项都有自己的 KV 缓存, 所以更大的批量大小意味着更多的 KV 缓存. 我们几乎总是会在这里受内存限制, 除非架构被积极地调整.
这也意味着一旦参数内存与 KV 缓存内存相当, 你将从增加批量大小中获得递减的吞吐量回报. 递减回报对你的伤害程度取决于单个序列的参数与 KV 缓存字节之比, 即大约是 $2DF / SHK$ 的比率. 由于 $HK\approx D$, 这大约取决于 $F$ 与 $S$ (序列长度) 的比率. 这也取决于使 KV 缓存更小的架构修改 (我们稍后会详细说明).
从这个数学中, 我们可以得到相当好的优化时应该瞄准的步骤时间的界限. (注意: 如果我们希望读者从本章中带走一件事, 那就是以下内容). 对于生成期间的小批量大小 (这很常见), 我们可以通过假设我们在注意力和 MLP 块中都受内存带宽限制来为我们的每步延迟设定下限:
\(\begin{equation*} \text{理论最小步骤时间} = \frac{\text{批量大小} \times \text{KV 缓存大小} + \text{参数大小}}{\text{总内存带宽}} \end{equation*}\)$
同样, 对于吞吐量:
\(\begin{equation*} \text{理论最大 token/秒} = \frac{\text{批量大小} \times \text{总内存带宽}}{\text{批量大小} \times \text{KV 缓存大小} + \text{参数大小}} \end{equation*}\)$
最终, 随着我们批量大小的增长, FLOPs 开始主导参数加载, 所以在实践中我们有更通用的方程:
\(\begin{align} \tiny \text{理论步骤时间 (通用)} = \underbrace{\frac{\text{批量大小} \times \text{KV 缓存大小}}{\tiny \text{总内存带宽}}}_{\text{注意力 (总是受带宽限制)}} + \underbrace{\max\left(\frac{2 \times \text{批量大小} \times \text{参数数量}}{\text{总 FLOPs/s}}, \frac{\text{参数大小}}{\text{总内存带宽}}\right)}_{\tiny \text{MLP (可以受计算限制)}} \end{align}\)$
其中注意力组件 (左) 从不受计算限制, 因此不需要 FLOPs 屋顶线. 这些对于粗略的计算相当有用, 例如
小测验: 假设我们想在 TPU v5e 4x4 切片上, 用 int8 和 bf16 FLOPs, 8192 上下文和 100 kB / token KV 缓存, 从一个 30B 参数的密集模型中生成一个批量大小为 4 个 token 的步骤. 这个操作的延迟的合理下限是多少? 如果我们想采样一个 256 个 token 的批次呢?
答案: 在 int8 中, 我们的参数将使用 30e9 字节, 并且根据给定的规格, 我们的 KV 缓存将每个使用 100e3 * 8192 = 819MB. 我们有 16 个芯片, 每个芯片有 8.1e11 字节/秒的带宽和 1.97e14 bf16 FLOPs/s. 从上面的方程中, 由于我们的批量大小很小, 我们预计我们的步骤时间至少为 (4 * 819e6 + 30e9) / (16 * 8.1e11) = 2.5 ms. 在 256 个 token 时, 我们将完全进入 MLP 块的计算限制状态, 所以我们的步骤时间大约是 (256 * 819e6) / (16 * 8.1e11) + (2 * 256 * 30e9) / (16 * 1.97e14) = 21ms.
正如你所看到的, 这里在吞吐量和延迟之间存在明显的权衡. 小批量速度快, 但不能很好地利用硬件. 大批量速度慢, 但效率高. 这是为一些较早的 PaLM 模型计算的延迟-吞吐量帕累托前沿 (来自ESTI 论文
我们不仅用批量大小作为旋钮来权衡延迟和吞吐量, 如果我们发现自己受 HBM 限制, 我们也可能更喜欢更大的拓扑而不是更小的拓扑, 以便容纳更大的批次. 下一节将更详细地探讨这一点.
要点: 如果你关心生成吞吐量, 请使用尽可能大的每个芯片的批量大小. 任何高于 TPU 算术强度 ($B_\text{crit}$, 通常为 120 或 240) 的每个芯片的批量大小都将最大化吞吐量. 你可能需要增加你的拓扑来实现这一点. 较小的批量大小将允许你以吞吐量为代价来提高延迟.
这一切都相当理论化. 在实践中, 我们通常不会看到一个清晰的屋顶线, 原因有几个:
我们花了一些时间来看带宽和 FLOPs, 但没有看内存. 在推理时, 内存情况看起来大不相同, 这要归功于我们的新数据结构, KV 缓存. 对于本节, 让我们选择一个真实模型 (LLaMA 2-13B) 来演示情况有多么不同:
| 超参数 | 值 |
|---|---|
| L (层数) | 40 |
| D (d_model) | 5,120 |
| F (ffw_dimension) | 13,824 |
| N (头数) | 40 |
| K (num_kv_heads) | 40 |
| H (qkv_dim) | 128 |
| V (num_embeddings) | 32,000 |
在推理期间什么在使用内存? 嗯, 显然, 是我们的参数. 计算这些, 我们有:
| 参数 | 公式 | 大小 (字节) |
|---|---|---|
| FFW 参数 | d_model2 x ffw_multiplier x 3 (用于 gelu + 输出投影) x n_layers | 5,120 x 5,120 x 2.7 x 3 x 40 = 8.5e9 |
| 词汇表参数 | 2 (输入和输出嵌入) x n_embeddings x d_model | 2 x 32,000 x 5,120 = 0.3e9 |
| 注意力参数 | [2 (q 和输出) x d_model x n_heads x d_qkv + 2 (k 和 v) x d_model x n_kv_heads x d_qkv] x n_layers | (2 x 5,120 x 40 x 128 + 2 x 5,120 x 40 x 128) x 40 = 4.2e9 |
将这些参数加起来, 我们得到 8.5e9 + 4.2e9 + 0.3e9 = 13e9 总参数, 正如预期的那样. 正如我们在前面的章节中看到的, 在训练期间, 我们可能会将参数存储在 bfloat16 中, 优化器状态存储在 float32 中. 这可能会使用大约 100GB 的内存. 这与我们的梯度检查点相比相形见绌, 后者可以使用几个 TB.
推理有什么不同? 在推理期间, 我们存储一个参数的副本, 比如说在 bfloat16 中. 这使用了 26GB —— 在实践中, 我们通常可以通过量化做得更好. 没有优化器状态或梯度需要跟踪. 因为我们不进行检查点 (为后向传播保留激活), 我们的激活占用空间对于预填充8,192 x 5,120 x 2 字节 = 80MB 的内存. 更长的预填充可以分解成许多更小的前向传播, 所以对于更长的上下文也不是问题. 生成使用的 token 比这还少, 所以激活可以忽略不计.
主要区别在于 KV 缓存. 这些是所有过去 token 的键和值投影, 大小仅受最大允许序列长度的限制. \(T\) 个 token 的总大小是
\[\text{KV 缓存大小} = 2 \cdot \text{每个浮点数的字节数} \cdot H \cdot K \cdot L \cdot T\]其中 \(H\) 是每个头的维度, \(K\) 是 KV 头的数量, \(L\) 是层数, 2 来自于同时存储键和值.
这会很快变得很大, 即使批量大小和上下文长度适中. 对于 LLaMA-13B, 一个 8192 序列在 bf16 下的 KV 缓存是
\(8192\ (T) \times 40\ (K) \times 128\ (H) \times 40\ (L) \times 2\ (\text{字节}) \times 2 = 6.7 \text{GB}\)$
仅仅 4 个这样的缓存就超过了我们参数的内存使用量! 需要明确的是, LLaMA 2 在较长上下文下的 KV 缓存大小方面没有进行优化 (情况并不总是这么糟, 因为通常 $K$ 要小得多, 就像 LLaMA-3 中那样), 但这仍然具有说明性. 我们在内存或延迟估计中不能忽略这些.
让我们看看如果我们尝试在 8xTPU v5es 上以不同的批量大小完美高效地执行生成, 直到达到前面为最大理论吞吐量推导出的临界批量大小 (240), 会发生什么.
| 批量大小 | 1 | 8 | 16 | 32 | 64 | 240 |
|---|---|---|---|---|---|---|
| KV 缓存内存 (GiB) | 6.7 | 53.6 | 107.2 | 214.4 | 428.8 | 1608 |
| 总内存 (GiB) | 32.7 | 79.6 | 133.2 | 240.4 | 454.8 | 1634 |
| 理论步骤时间 (ms) | 4.98 | 12.13 | 20.30 | 36.65 | 69.33 | 249.09 |
| 理论吞吐量 (token/秒) | 200.61 | 659.30 | 787.99 | 873.21 | 923.13 | 963.53 |
8x TPU v5es 给了我们 128GiB 的 HBM, 6.5TiB/s 的 HBM 带宽 (每个 0.82TiB/s) 和 1600TF/s 的计算能力.
对于这个模型, 增加批量大小确实能给我们带来更好的吞吐量, 但我们很快就会遇到收益递减. 我们在批量大小超过 16 时就会 OOM, 并且需要数量级更多的内存才能接近 240. 更大的拓扑可以改善延迟, 但我们在每个芯片的吞吐量上遇到了瓶颈.
假设我们保持总参数数量不变, 但神奇地将 KV 缓存缩小 5 倍 (比如说, 使用 1:5 的 GMQA, 这意味着我们有 8 个 KV 头在 40 个 Q 头上共享 —— 更多细节见下一节).
| 批量大小 | 1 | 8 | 16 | 32 | 64 | 240 |
|---|---|---|---|---|---|---|
| KV 缓存内存 (GiB) | 1.34 | 10.72 | 21.44 | 42.88 | 85.76 | 321.6 |
| 总内存 (GiB) | 27.34 | 36.72 | 47.44 | 68.88 | 111.76 | 347.6 |
| 理论步骤时间 (ms) | 4.17 | 5.60 | 7.23 | 10.50 | 17.04 | 52.99 |
| 理论吞吐量 (token/秒) | 239.94 | 1,429.19 | 2,212.48 | 3,047.62 | 3,756.62 | 4,529.34 |
使用更小的 KV 缓存, 我们仍然有递减的回报, 但理论上每个芯片的吞吐量会持续扩展到批量大小 240. 我们可以容纳更大的批量 64, 并且在所有批量大小下延迟也始终更好. 延迟, 最大吞吐量和最大批量大小都得到了显著改善! 事实上, 后来的 LLaMA 代使用了这个确切的优化 —— LLaMA-3 8B 有 32 个查询头和 8 个 KV 头 (来源).
要点: 除了参数, KV 缓存的大小对模型的最终推理性能有很大影响. 我们希望通过架构决策和运行时优化的组合来控制它.
自最初的 Attention is All You Need 论文以来, 已经开发了许多技术来使模型更高效, 通常专门针对 KV 缓存. 一般来说, 更小的 KV 缓存使得在不影响延迟的情况下更容易增加生成步骤的批量大小和上下文长度, 并使 Transformer 周围的系统 (如请求缓存) 的生活更轻松. 忽略对质量的影响, 我们可能会看到:
分组多查询注意力 (又名 GMQA, GQA): 我们可以减少 KV 头的数量, 并在注意力机制中与许多 Q 头共享它们. 在极端情况下, 可以在所有 Q 头之间共享单个 KV 头. 这将 KV 缓存减少了 Q:KV 比率的倍数, 相对于纯 MHA, 并且已经观察到模型的性能对这种变化相对不敏感.
这实际上也增加了注意力计算的算术强度 (参见第 4 节中的问题 4).
混合一些局部注意力层: 局部注意力将上下文限制在一个小到中等大小的最大长度. 在训练时间和预填充时间, 这涉及到将注意力矩阵掩码为一个对角条带, 而不是一个三角形. 这有效地限制了局部层的 KV 缓存的最大长度. 通过将一些局部层与一些全局层混合到模型中, 在上下文长度超过局部窗口时, KV 缓存的大小会大大减小.
跨层共享 KV: 模型可以学习以某种模式跨层共享相同的 KV 缓存. 虽然这确实减小了 KV 缓存的大小, 并在增加批量大小, 缓存, 离线存储等方面提供了好处, 但共享的 KV 缓存可能需要多次从 HBM 读取, 因此它不一定能改善步骤时间.
量化: 推理通常对参数和 KV 的精度不太敏感. 通过对参数和 KV 缓存进行量化 (例如, int8, int4, fp8 等), 我们可以节省两者的内存带宽, 减少达到计算屋顶线所需的批量大小, 并节省内存以在更大的批量大小下运行. 量化还有一个额外的好处, 即即使模型没有用量化进行训练, 也通常可以在训练后应用.
使用不规则 HBM 读取和分页注意力: 我们在上面的计算中为每个 KV 缓存分配了 8k 的上下文, 但通常不需要从内存中读取整个 KV 缓存 —— 请求具有广泛的长度分布, 并且不使用模型的最大上下文, 因此我们通常可以实现仅读取 KV 缓存的非填充部分的内核 (例如, Flash Attention 变体).
分页注意力
大局: 总而言之, 与标准 MHA Transformer 相比, 这些 KV 缓存优化可以将 KV 缓存大小减少一个数量级以上. 这可以导致 Transformer 的总成本提高一个数量级.
到目前为止, 我们一直在含糊地讨论如何扩展到单个芯片之外. 遵循第 5 节, 让我们探讨可用的不同策略及其权衡. 与往常一样, 我们将分别看待预填充和生成.
从屋顶线的角度来看, 预填充几乎与训练相同, 几乎所有相同的技术和权衡都适用 —— 模型 (Megatron) 并行, 序列分片 (对于足够长的上下文), 流水线, 甚至 FSDP 都是可行的! 你只需要保留 KV, 以便稍后可以进行生成. 与训练一样, 增加芯片数量可以让我们获得更多的 FLOPs/s (可能降低 TTFT), 但会增加通信开销 (可能降低每个芯片的吞吐量).
分片预填充的一般规则: 这是预填充的一般规则集. 我们假设我们只在一个序列上进行预填充 (没有批处理维度):
要点: 在预填充期间, 几乎任何在训练期间可行的分片都可以正常工作. 进行模型并行直到 ICI 限制, 然后进行序列并行.
生成比预填充更复杂. 一方面, 获得大的批量大小更难, 因为我们需要将许多请求批处理在一起. 延迟目标更低. 总之, 这意味着我们通常更受内存限制, 对通信开销更敏感, 这限制了我们的分片策略:
FSDP 是不可能的: 由于我们在从 HBM 加载参数和 KV 缓存到 MXU 时受内存限制, 我们不希望通过比 HBM 慢几个数量级的 ICI 来移动它们. 我们希望移动激活而不是权重. 这意味着类似于 FSDP 的方法通常对于生成是完全不可行的.
没有理由进行数据并行: 纯数据并行没有帮助, 因为它复制了我们的参数, 并且无助于我们更快地加载参数. 你最好启动多个模型的副本.
没有序列 = 没有序列分片. 祝你好运, 序列分片.
这主要给我们留下了密集模型生成的模型分片变体. 与预填充一样, 我们可以做的最简单的事情是简单的模型并行 (激活完全复制, 权重在 MLP 的隐藏维度上完全分片), 直到我们受 ICI 限制的 4-8 路. 然而, 由于我们经常受内存带宽限制, 我们实际上可以超越这个限制来提高延迟!
关于生成 ICI 限制的说明: 在训练期间, 我们希望受计算限制, 所以我们的屋顶线着眼于我们的 ICI 通信何时比我们的 FLOPs 花费更长的时间. 然而, 在生成期间, 如果我们受参数加载的内存带宽限制, 我们可以将模型分片增加到超过这个点, 并以最小的吞吐量成本 (以 token/秒/芯片为单位) 提高延迟. 更多的模型分片给了我们更多的 HBM 来加载我们的权重, 我们的 FLOPs 并不重要.
\(\begin{align*}T_\text{HBM comms} = \frac{2DF}{Y \cdot W_\text{hbm}} && T_\text{ICI comms} = \frac{2BD}{W_\text{ici}} \end{align*}\)$
\[T_\text{ICI comms} > T_\text{HBM comms} \rightarrow \frac{W_\text{hbm}}{W_\text{ici}} > \frac{F}{Y \cdot B} \rightarrow Y > F / (B \cdot \beta)\]其中 $\beta = W_\text{hbm} / W_\text{ici}$. 对于 TPU v5e 和 TPU v6e, 这个数字通常在 8 左右. 这意味着例如, 如果 $F$ 是 16,384, $B$ 是 32, 我们理论上可以进行模型并行, 直到 16384 / (32 * 8) = 64 路, 而不会对吞吐量产生有意义的影响. 这假设我们可以完全将我们的 KV 缓存分片 64 路, 这很困难: 我们将在下面讨论这个问题.
对于注意力层, 我们也对注意力 \(W_Q\) 和 \(W_O\) 在头 Megatron 风格上进行模型分片. KV 权重相当小, 复制它们通常比分片超过 $K$ 路分片更便宜.
要点: 我们在生成期间唯一的选择是模型并行的变体. 我们的目标是移动激活而不是 KV 缓存或参数, 它们更大. 当我们的批量大小很大时, 我们进行模型并行直到 FLOPs-ICI 限制 ($F / \alpha$). 当我们的批量大小较小时, 我们可以通过更多的模型分片来提高延迟 (以适度的吞吐量成本). 当我们想要模型分片的数量超过我们拥有的 KV 头时, 我们也可以在批处理维度上对我们的 KV 进行分片.
我们还有一个需要分片的新数据结构 —— KV 缓存. 同样, 我们几乎总是倾向于避免复制缓存, 因为它是注意力延迟的主要来源. 为此, 我们首先在头维度上对 KV 进行 Megatron 分片. 这仅限于 $K$ 路分片, 因此对于头数较少的模型, 我们尽可能多地对头维度进行分片, 然后在批处理维度上进行分片, 即 $\text{KV}[2, B_Z, S, K_Y, H]$. 这意味着 KV 缓存是完全分布式的.
这样做的代价是每个注意力层需要两次 AllToAll —— 一次是将 Q 激活转移到批处理分片, 以便我们可以用批处理分片计算注意力, 另一次是将批处理分片的注意力输出转移回纯模型分片.
在这里, 我们将写出在 $Y$ 和 $Z$ 上都进行模型并行的完整注意力算法. 我为同时使用 $K$ 表示键张量和 KV 头维度而道歉. 设 $M=N/K$.
这相当复杂, 但你大致可以看到它是如何工作的. 新的通信成本适中, 因为它们作用于我们的小激活, 作为回报, 我们节省了大量的内存带宽来加载 KV (它们是静止的).
到目前为止, 我们已经研究了如何独立地优化和分片单个预填充和生成操作. 为了有效地使用它们, 我们需要设计一个推理引擎, 它可以将这两个操作馈送到我们选择的延迟/吞吐量帕累托前沿上的一个点.
最简单的方法是简单地运行一批预填充, 然后运行一批生成:
这很容易实现, 并且是大多数代码库中的第一个推理设置, 但它有多个缺点:
因此, 这种方法只推荐用于边缘应用 (通常只关心为一个用户提供服务并使用具有较少 FLOPs/字节的硬件) 和 Transformer 代码库生命周期早期的快速迭代 (由于其简单性).
一个稍微好一点的方法是在批量大小为 1 时执行预填充 (此时它受计算限制但具有合理的延迟), 但在生成期间将多个请求批处理在一起:
这将避免批量预填充造成的 TTFT 浪费, 同时保持较高的生成吞吐量. 我们称之为交错配置, 因为我们“交错”了预填充和生成步骤. 这对于批量生成应用 (如评估) 非常强大, 在这些应用中, 吞吐量是主要目标. 编排器可以配置为在任何生成槽位空闲时优先进行预填充, 即使对于非常大的生成批量大小, 也能确保高利用率. 我们还可以避免将预填充填充到最大长度, 因为它没有与其他请求批处理.
主要缺点是, 当服务器执行预填充时, 所有其他请求的生成都会暂停, 因为所有计算资源都将被预填充消耗. 正在解码其响应的用户 A 将被正在进行预填充的用户 B 阻塞. 这意味着即使 TTFT 得到了改善, token 生成平均也会抖动和缓慢, 这对于许多应用来说不是一个好的用户体验 —— 其他用户的预填充处于请求总延迟的关键路径上.
为了解决这个问题, 我们将解码和预填充分开. 虽然 Transformer 推理可以在一台服务器上完成, 但从延迟的角度来看, 在两组 TPU/GPU 上执行这两个不同的任务通常更好. 预填充服务器生成 KV 缓存, 这些缓存通过网络发送到生成服务器, 生成服务器将多个缓存批处理在一起并为每个缓存生成 token. 我们称之为“分离式”服务.
这提供了一些优势:
大规模低延迟: 用户的请求永远不会阻塞在另一个用户的请求上, 除非预填充容量不足. 请求应该立即被预填充, 然后发送到生成服务器, 然后立即插入到生成缓冲区中. 如果我们预计会有许多并发请求进来, 我们可以独立于生成服务器的数量来扩展预填充服务器的数量, 这样用户就不会长时间停留在预填充队列中.
专业化: 通常, 预填充和生成的延迟最优参数分片策略/硬件拓扑非常不同 (例如, 更多的模型并行对生成有用, 但对预填充没用). 将这两个操作限制为使用相同的分片会损害两者的性能, 并且拥有两套权重会占用内存. 此外, 通过将预填充移动到自己的服务器上, 它不需要持有任何 KV 缓存, 除了它当前正在处理的那个. 这意味着我们有更多的内存可用于历史缓存 (见下一节) 或优化预填充延迟.
一个缺点是 KV 缓存现在需要在网络上传输. 这通常是可以接受的, 但再次为减小 KV 缓存大小提供了动力.
要点: 对于延迟敏感, 高吞吐量的服务, 我们通常必须将预填充和生成分离到不同的服务器上, 预填充以批量 1 运行, 生成则将许多并发请求批处理在一起.
上面的问题 (2) 激发了连续批处理的概念. 我们优化并编译:
然后, 我们将这些函数与一个编排器结合起来, 该编排器对传入的请求进行排队, 根据可用的生成槽位调用预填充和生成, 处理历史缓存 (见下一节) 并流式传输 token.
由于预填充成本高昂且受计算限制 (给了我们更少的余地), 减少其成本的最佳方法之一是减少预填充的次数. 因为 LLM 是自回归的, 查询 [“我”, “喜欢”, “狗”] 和 [“我”, “喜欢”, “猫”] 产生的 KV 缓存的前两个 token 是相同的. 这意味着, 原则上, 如果我们先计算“我喜欢狗”的缓存, 然后再计算“我喜欢猫”的缓存, 我们只需要进行 1/3 的计算. 我们可以通过重用缓存来节省大部分工作. 这在一些特定情况下特别强大:
这之所以难以做到, 唯一的限制是内存. 正如我们所见, KV 缓存很大 (通常是几个 GB), 为了使缓存有用, 我们需要将它们保留到后续查询到达. 通常, 预填充服务器上任何未使用的 HBM 都可以用于本地缓存系统. 此外, 加速器通常在其 CPU 主机上有很多内存 (例如, 一个 8xTPUv5e 服务器有 128GiB 的 HBM, 但大约有 450GiB 的主机 DRAM). 这个内存比 HBM 慢得多 —— 通常慢到无法进行生成步骤 —— 但对于缓存读取来说足够快. 在实践中:
谷歌开源了一个实现此逻辑的库, 名为 JetStream. 该服务器有一组“预填充引擎”和“生成引擎”, 通常在不同的 TPU 切片上, 由单个控制器进行编排. 预填充发生在“预填充线程”中, 而生成发生在“生成线程”中. 我们还有一个“传输线程”, 负责编排从预填充切片到生成切片的 KV 缓存复制.
Engine 接口 (在这里实现) 是任何 LLM 都必须提供的通用接口. 关键方法是:
我们还有一个 PyTorch 版本的 JetStream, 可在这里找到.
我将为本节发明一个基于 LLaMA-2 13B 的新模型. 以下是详细信息:
| 超参数 | 值 |
|---|---|
| L (层数) | 64 |
| D (d_model) | 4,096 |
| F (ffw_dimension) | 16,384 |
| N (头数) | 32 |
| K (num_kv_heads) | 8 |
| H (qkv_dim) | 256 |
| V (num_embeddings) | 32,128 |
问题 1: 上述模型有多少参数? 每个 token 的 KV 缓存 (int8) 有多大? 你可以假设我们共享输入和输出投影矩阵.
参数数量:
因此, 我们的总参数数量是 $L * D * (3F + 2H * (N + K)) + D * V$. 代入上面的数字, 我们有 64 * 4096 * (3*16384 + 2 * 256 * (32 + 8)) + 4096 * 32128 = 18.4e9. 因此, 这个模型大约有 184 亿个参数.
KV 缓存是每个 token $L * K * H$, 即 64 * 8 * 256 = 131kB.
问题 2: 假设我们想在 TPUv5e 4x4 切片上服务这个模型, 并且可以完全在这个拓扑上分片我们的 KV 缓存. 假设我们对所有东西都使用 int8, 并且想要支持 128k 序列, 我们可以容纳的最大批量大小是多少? 如果我们将 KV 头的数量减少到 1 呢?
我们的 KV 缓存每个 token 在 int8 中有 $L * K * H$ 的大小, 或 64 * 8 * 256 = 131kB. 对于 128k 序列, 这意味着每个批处理条目 131e3 * 128e3 = 16.8GB. 由于每个 TPU 有 16GB 的 HBM, 包括我们的参数, 我们可以容纳的最大批量大小是 (16 * 16e9 - 18.4e9) / 16.8e9 = 14. 如果我们有 $K=1$, 我们将有这个的 8 倍, 即大约 112.
问题 3: 假设参数在 TPU v5e 4x4 切片上完全分片, 将所有参数从 HBM 加载到 MXU 需要多长时间? 假设 int8 参数. 这是每步延迟的一个很好的下限.
我们总共有 18.4B 个参数, 或 18.4e9 字节 (int8). 我们每个芯片有 8.1e11 HBM 带宽, 所以大约需要 18e9 / (8.1e11 * 16) = 1.3ms, 假设我们可以完全利用我们的 HBM 带宽.
问题 4: 假设我们想在 TPUv5e 4x4 切片上使用 int8 FLOPs 和参数/激活来服务这个模型. 我们将如何为预填充和解码进行分片? 提示: 也许先回答这些问题:
对于这个分片, 生成的粗略每步延迟是多少?
问题 5: 让我们假装上面的模型实际上是一个 MoE. MoE 模型实际上是一个具有 E 个 FFW 块副本的密集模型. 每个 token 通过 k 个 FFW 块, 这 k 个块被平均以产生输出. 让我们使用 E=16 和 k=2 以及上面的设置.
(1) 作为一个 MoE, 每个 MLP 块现在有 $3 * E * D * F$ 个参数, 比密集变体增加了 $E$ 倍. 因此, 它现在总共有 $L * D * (3EF + 2H * (N + K)) + D * V$ 或 64 * 4096 * (3*16*16384 + 2 * 256 * (32 + 8)) + 4096 * 32128 = 212e9 个参数, 增加了约 12 倍. 对于激活的参数, 我们有 $k$ 而不是 $E$ 个激活的参数, 总共 64 * 4096 * (3*2*16384 + 2 * 256 * (32 + 8)) + 4096 * 32128 = 31.2e9, 比密集变体增加了不到 2 倍.
(2) 因为我们有 $E$ 倍的参数, 但只有 $k$ 倍的 FLOPs, 我们的 HBM 屋顶线增加了 $E/k$ 倍. 这意味着在 TPU v5e 上, 我们需要大约 240 * (16 / 2) = 1920 个 token.
(3) KV 缓存大小保持不变, 因为 MoE 特性不会改变注意力机制的任何东西.
(4) 这仍然是 $2ND$, 其中 $D$ 是激活的参数数量. 因此这是 $2 * \text{32.2e9} * T$.
问题 6: 对于 MoE, 我们可以进行“专家分片”, 即我们将我们的专家分布在我们网格的一个轴上. 在我们的标准符号中, 我们的第一个 FFW 权重的形状为 [E, D, F], 我们将其分片为 [EZ, DX, FY], 其中 X 仅在训练期间用作我们的 FSDP 维度. 假设我们想在 TPU v5e 上进行推理:
问题 7 [2D 模型分片]: 在这里, 我们将详细讲解 ESTI 论文所谓的 2D 权重固定分片的数学原理. 我们在附录 B 中简要描述了这一点, 但请先尝试做这个问题, 看看你是否能算出数学原理. 2D 权重固定分片的基本思想是沿 $D$ 和 $F$ 轴对我们的权重进行分片, 以便每个块大致呈方形. 这减少了通信负载, 并允许我们进一步扩展.
这是 2D 权重固定的算法:
你的目标是为这个算法计算出 $T_\text{math}$ 和 $T_\text{comms}$, 并找出它何时会优于传统的 3D 模型分片?
让我们计算 $T_\text{math}$ 和 $T_\text{comms}$. 我们所有的 FLOPs 都是完全分片的, 所以和以前一样, 我们有 $T_\text{math} = 4BDF / (N \cdot C)$, 但我们的通信现在是
\(\begin{align*} T_\text{2D comms} = \frac{2BD}{2X \cdot W_\text{ici}} + \frac{4BF}{YZ \cdot W_\text{ici}} + \frac{2BD}{2X \cdot W_\text{ici}} = \frac{2BD}{X \cdot W_\text{ici}} + \frac{4BF}{YZ \cdot W_\text{ici}} \end{align*}\)$
其中我们注意到 AllReduce 的成本是两倍, 并且我们根据每个操作执行的轴数来缩放我们的通信. 假设我们可以自由选择我们的拓扑, 并且假设 $F=4D$ (如 LLaMA-2), 我们声称 (通过一些基本的微积分) $X$, $Y$ 和 $Z$ 的最佳值是 $X = \sqrt{N / 8}$, $YZ = \sqrt{8N}$, 所以总通信是
\[T_\text{2D comms} = \frac{2B}{W_\text{ici}} \left(\frac{D}{X} + \frac{8D}{YZ}\right) = \frac{\sqrt{128} BD}{\sqrt{N} \cdot W_\text{ici}} \approx \frac{11.3 BD}{\sqrt{N} \cdot W_\text{ici}}\]首先, 从上面复制, 普通的 1D 模型并行将有 $T_\text{model parallel comms} = 4BD / (3 \cdot W_\text{ici})$, 那么新的通信何时更小? 我们有
\(\begin{align*} T_\text{model parallel comms} > T_\text{2D comms} \iff \frac{4BD}{3 \cdot W_\text{ici}} > \frac{\sqrt{128} BD}{\sqrt{N} \cdot W_\text{ici}} \iff N > 128 \cdot \left(\frac{3}{4}\right)^2 = 81 \end{align*}\)$
对于一个通用的 $F$, 我们声称这个条件是
\[N > 32 \cdot \left(\frac{F}{D}\right) \cdot \left(\frac{3}{4}\right)^2\]所以这告诉我们, 如果我们有超过 81 个芯片, 我们最好使用这个新方案. 现在这是一个有点奇怪的结果, 因为我们历史上发现自己在 ~20 路张量并行时受 ICI 限制. 但在这里, 即使我们受通信限制, 我们的总通信也会随着总芯片数量的增加而持续减少! 这告诉我们, 我们可以继续增加我们的芯片, 增加我们的批量大小, 进行更多的参数扩展, 并看到延迟降低.
我们上面提供的简单规则, 即我们的批量大小必须大于 240 个 token 才能受计算限制, 大致是正确的, 但忽略了 TPU 在其他操作不使用所有可用 HBM 时预取权重的一些能力, 比如在进行设备间通信时.
这是一个小型 Transformer 的层时间 (以微秒为单位) 的实证图, dmodel 为 8192, dff 为 32768, 每层只有 2 个矩阵乘法. 这来自这个 Colab 笔记本. 你会看到步骤时间增长非常缓慢, 直到大约批量 240, 然后线性增长.
这是实际的吞吐量, 以 token/us 为单位. 这很清楚地说明了论点. 由于我们的层大约是 600M 参数, 在这里分片 4 路, 我们预计最小延迟大约是 365us.
所以至少在这个模型中, 我们确实看到吞吐量增加, 直到每个数据并行分片大约 BS240.
随着拓扑的增长, 如果我们可以访问更高维度的网格 (如 TPU 的网格), 就可以通过“2D 权重分片”进一步完善. 通过引入第二个分片轴. 我们称之为“2D 权重固定”, 在高效扩展 Transformer 推理论文中有更详细的描述.
因为我们只在 Megatron 中对隐藏的 \(F\) 维度进行分片, 所以随着芯片数量随着 1D 分片而增长, 它可能会变得比 \(E\) ( \(d_\text{model}\) 维度) 小得多. 这意味着在更大的批量大小时, 在应用 MLP 的第一层后, 在隐藏维度上执行一部分集合操作可能更经济.
该图显示:
对于注意力层, Megatron 风格的分片对于较少数量的芯片也相对简单. 然而, Megatron 发生在 \(n_\text{heads}\) 维度上, 这对可能的分片数量设置了限制. 通过修改 2D 分片 (不是分片隐藏维度, 而是分片 \(n_\text{heads}\) 维度), 我们获得了进一步扩展的能力.
作为回顾, 在第 3 节中, 我们推导了在 1D 环形链路上, 在每个 TPU 上对大小为 B 的张量执行 AllGather 所需的时间, 该链路具有全双工带宽 WICI 和延迟 Tmin.
\[T_\text{total} = \max\left(\frac{T_\text{min} \cdot |X|}{2}, \frac{B}{W_\text{ICI}}\right)\]对于大的 B, 时钟时间相对恒定, 因为当你向系统添加更多芯片时, 你同时扩展了执行操作所需的数据移动量和可用的总带宽.
由于在延迟优化的推理期间移动的数据量相对较少, 激活上的集合操作通常受延迟项的限制 (特别是对于小批量大小). 通过计算我们需要完成的跳跃次数, 可以很容易地可视化延迟.
在 TPU 上, 如果通信的张量大小相关部分每跳 (一跳是两个相邻设备之间的通信) 小于 1 微秒, 我们可能会受到实际分派集合的固定开销的瓶颈. 单向 ICI 带宽为 4.5e10, 当 \((\text{字节} / n_\text{shards}) / 4.5e10 < 1e-6\) 时, ICI 通信变得受延迟限制. 对于 8 路 Megatron 分片, 这是当 buffer_size < 360kB 时. 这在推理期间实际上并不那么小: BS=16 和 D=8192 (int8), 我们的激活将使用 16*8192=131kB, 所以我们已经受延迟限制了.
要点: 当 \(\text{总字节数} < W_{ICI} \times 1e-6\) 时, 我们的通信会受到延迟限制. 例如, 在 \(Y\) 上进行模型并行时, 当 \(Y > BD / 45,000\) 时, 我们在 int8 中会受到限制.
这里可以与计算屋顶线进行类比 —— 我们正在承担一些小操作的固定成本 (通信的延迟, 矩阵乘法的内存带宽).
当我们真正关心端到端延迟时, 我们可以采用一个额外的技巧, 称为推测性采样
通过推测性采样, 我们使用一个更小, 更便宜的模型来生成 token, 然后用大模型检查结果. 这最容易用贪婪解码来理解:
为什么这是一个延迟上的胜利? 这个方案仍然要求我们为每个 token 执行相当于大模型一次前向传播的 FLOPs, 但因为我们可以将一堆 token 批处理在一起, 我们可以将所有这些 FLOPs 放在一次前向传播中完成, 并利用我们不受计算限制的事实来免费评分更多的 token.
平均而言, 每个被接受的 token 在 FLOPs 方面变得更加昂贵 (因为有些会被拒绝, 而且我们必须调用一个草稿模型), 但我们从硬件中榨取了更多的 FLOPs, 而且小模型很便宜, 所以我们总体上是赢的. 我们还在多个步骤之间共享 KV 缓存加载, 所以对于长上下文, 推测性解码也可以在吞吐量上取得胜利. 由于所有内容都经过了大模型的检查, 我们根本没有改变采样分布 (尽管对于非贪婪解码, 确切的轨迹会有所不同).
传统上, 推测性解码依赖于存在一个与目标模型具有相似采样分布的较小模型, 例如 LLaMA-2 70B 的 LLaMA-2 2B, 这通常不存在. 即使存在, 如果接受率很低, 较小的草稿模型仍然可能过于昂贵. 相反, 将草稿模型嵌入到主模型中可能会有所帮助, 例如通过在基础模型的后期层中添加一个专用的草稿头
对于正常的自回归采样, token/s 与步骤时间相同. 我们仍然受限于这里算术强度部分的理论最小步骤时间 (事实上, 推测性采样的步骤时间通常比正常的自回归采样慢得多, 但因为我们平均每步得到超过 1 个 token, 我们可以得到更好的 token/s).
这对于非贪婪解码是如何工作的? 这有点复杂, 但本质上归结为一个受 Metropolis-Hastings 启发的算法, 其中我们有从 logits 派生出的 \(P_\text{草稿模型}(\text{选择的 token})\) 和 \(P_\text{目标模型}(\text{选择的 token})\), 如果这些概率的比率小于某个阈值, 则概率性地拒绝选择的 token.
这两篇 论文 同时推导出了这一点, 并提供了很好的例子来说明这在实践中是如何工作的.
要点: 推测性采样是另一个强大的杠杆, 用于用吞吐量换取更好的每个 token 延迟. 然而, 在批量大小受限的情况下 (例如, 硬件占用空间小或 KV 缓存大), 它变成了双赢.