如何扩展你的模型

TPU 上的 LLM 系统视图 (Part 0: Intro | Part 1: Rooflines)

训练 LLM (大型语言模型) 通常感觉像炼金术, 但理解和优化模型的性能并非必须如此. 本书旨在揭开语言模型扩展科学的神秘面纱: TPU (和 GPU) 如何工作以及它们之间如何通信, LLM 如何在真实硬件上运行, 以及如何在训练和推理过程中并行化你的模型, 使其在大规模上高效运行. 如果你曾想过“训练这个 LLM 应该有多昂贵”或“我自己需要多少内存来服务这个模型”或“什么是 AllGather”, 我们希望这本书对你有所帮助.

深度学习的很多方面仍然可以归结为一种黑魔法, 但优化模型性能并非必须如此 —— 即使是在巨大规模上! 相对简单的原则无处不在 —— 从处理单个加速器到数万个加速器 —— 理解它们可以让你做很多有用的事情:

背景要求: 我们假设你对 LLM 和 Transformer 架构有基本的了解, 但不一定了解它们如何大规模运行. 你应该了解 LLM 训练的基础知识, 最好对 JAX 有一些基本的熟悉. 一些有用的背景阅读可能包括关于 Transformer 架构的 这篇博客文章原始的 Transformer 论文. 另外, 请查看 这个列表 以获取更多有用的同步和未来阅读材料.

目标与反馈: 读完本书后, 你应该能够轻松地为给定硬件平台上的 Transformer 模型估算最佳并行方案, 以及大致的训练和推理时间. 如果你做不到, 请给我们发邮件或留言! 我们很想知道如何能让内容更清晰.

你可能也会喜欢阅读关于 NVIDIA GPU 的新 第 12 节!

你为什么应该关心?

三四年前, 我认为大多数机器学习研究人员不需要理解本书中的任何内容. 但如今, 即使是“小型”模型也运行得如此接近硬件极限, 以至于进行新颖的研究需要你考虑规模化的效率.从历史上看, 机器学习研究遵循着系统创新和软件改进之间的某种“滴答”循环. Alex Krizhevsky 不得不编写复杂的 CUDA 代码来使 CNN 变快, 但在几年内, 像 Theano 和 TensorFlow 这样的库意味着你不再需要这样做了. 也许这里也会发生同样的事情, 本书中的所有内容在几年后都将被抽象掉. 但是, 扩展定律已将我们的模型永久地推向了硬件的最前沿, 而且在不久的将来, 进行前沿研究似乎将与理解如何有效地将模型扩展到大型硬件拓扑结构密不可分. 在基准测试中取得 20% 的胜利, 如果以 20% 的屋顶线效率为代价, 那是无关紧要的. 有前途的模型架构之所以经常失败, 要么是因为它们无法在规模上高效运行, 要么是因为没有人投入精力去实现它们.

“模型扩展”的目标是能够在增加用于训练或推理的芯片数量的同时, 实现吞吐量的成比例线性增长. 这被称为“强扩展”. 尽管增加额外的芯片 (“并行化”) 通常会减少计算时间, 但它也带来了芯片之间通信增加的代价. 当通信时间超过计算时间时, 我们就会变得“受通信限制”, 无法实现强扩展.随着计算时间的减少, 你通常也会在单个芯片级别面临瓶颈. 你闪亮的新 TPU 或 GPU 可能额定每秒执行 500 万亿次操作, 但如果你不小心, 它同样很容易只做到十分之一, 如果它被在内存中移动参数所拖累. 单芯片计算, 内存带宽和总内存之间的相互作用对扩展至关重要. 如果我们对硬件有足够的了解, 能够预测这些瓶颈将在何处出现, 我们就可以设计或重新配置我们的模型以避免它们.硬件设计者面临着相反的问题: 构建能够为我们的算法提供恰到好处的计算, 带宽和内存, 同时最小化成本的硬件. 你可以想象这个“协同设计”问题有多么紧张: 你必须押注于当第一批芯片实际可用时算法会是什么样子, 这通常是 2 到 3 年之后的事情. TPU 的故事是这场博弈中一个响亮的成功. 矩阵乘法是一种独特的算法, 因为它每字节内存使用的 FLOPs (浮点运算次数) 比几乎任何其他算法都多 (每字节 N FLOPs), 早期的 TPU 及其脉动阵列架构在当时实现了比 GPU 好得多的性能/价格比. TPU 是为 ML 工作负载设计的, 而带有 TensorCores 的 GPU 也在迅速改变以填补这一空白. 但你可以想象, 如果神经网络没有兴起, 或者发生了某些根本性的变化, 那代价会有多大... [截断]

我们在本书中的目标是解释 TPU (和 GPU) 硬件如何工作, 以及 Transformer 架构如何演变为在当前硬件上表现良好. 我们希望这对于设计新架构的研究人员和致力于让当前一代 LLM 快速运行的工程师都有用.

内容大纲

本书的总体结构如下:

第 1 节 解释了屋顶线分析以及哪些因素会限制我们的扩展能力 (通信, 计算和内存). 第 2 节第 3 节 详细讨论了 TPU 的工作原理, 既包括作为单个芯片, 也包括 —— 至关重要的 —— 作为一个具有有限带宽和延迟的互连芯片的互连系统. 我们将回答以下问题:

图: 来自 第 2 节 的图表, 显示了 TPU 如何执行逐元素乘积. 根据我们数组的大小和各种链接的带宽, 我们可能会发现自己受计算限制 (使用全部硬件计算能力) 或受通信限制 (受内存加载瓶颈).

五年前, 机器学习拥有丰富多彩的架构格局 —— ConvNets, LSTMs, MLPs, Transformers —— 但现在我们主要只有 Transformer. 我们坚信, 了解 Transformer 架构的每个部分都是值得的: 每个矩阵的确切大小, 归一化发生在哪里, 每个部分有多少参数和 FLOPs浮点运算, 基本上是所需加法和乘法的总数. 虽然许多资料将 FLOPs 理解为“每秒操作数”, 但我们使用 FLOPs/s 来明确表示这一点.. 第 4 节 仔细地讲解了这种“Transformer 数学”, 展示了如何计算训练和推理的参数和 FLOPs. 这告诉我们模型将使用多少内存, 我们将在计算或通信上花费多少时间, 以及注意力何时会相对于前馈块变得重要.

图: 一个标准的 Transformer 层, 每个矩阵乘法 (matmul) 显示为一个圆圈内的点. 所有参数 (不包括范数) 都以紫色显示. 第 4 节 更详细地介绍了这个图.

第 5 节: 训练第 7 节: 推理 是本文的核心, 我们在其中讨论了基本问题: 给定某个大小的模型和一定数量的芯片, 我该如何并行化我的模型以保持在“强扩展”状态? 这是一个简单的问题, 却有着惊人复杂的答案. 从高层次上讲, 有 4 种主要的并行技术用于在多个芯片上拆分模型 (数据, 张量, 流水线专家), 以及许多其他技术来减少内存需求 (重物质化, 优化器/模型分片 (又名 ZeRO), 主机卸载, 梯度累积). 我们在这里讨论其中的许多技术.

我们希望在这些章节结束时, 你应该能够自己为新的架构或设置选择它们. 第 6 节第 8 节 是将这些概念应用于 LLaMA-3 (一个流行的开源模型) 的实践教程.

最后, 第 9 节第 10 节 介绍了如何在 JAX 中实现其中一些想法, 以及在出现问题时如何分析和调试代码. 第 12 节 是一个深入探讨 GPU 的新章节.

在整本书中, 我们都试图给你一些问题让你自己解决. 请不要有压力阅读所有章节或按顺序阅读. 并请留下反馈. 目前, 这是一个草稿, 将继续修订. 谢谢!

我们要感谢 James Bradbury 和 Blake Hechtman, 他们推导出了本文档中的许多想法.

话不多说, 这里是关于 TPU 屋顶线的第 1 节.

各章节链接

这个系列可能比它需要的要长, 但我们希望这不会阻止你. 前三章是预备知识, 如果熟悉可以跳过, 尽管它们介绍了后面使用的符号. 最后三个部分可能是最实用的, 因为它们解释了如何处理真实模型.

第一部分: 预备知识

第二部分: Transformers

第三部分: 实践教程

第四部分: 结论和附加内容

Miscellaneous

*Work done at Google DeepMind, now at MatX.

Citation

For attribution in academic contexts, please cite this work as:

    Austin et al., "How to Scale Your Model", Google DeepMind, online, 2025.

or as a BibTeX entry:

    @article{scaling-book,
      title = {How to Scale Your Model},
      author = {Austin, Jacob and Douglas, Sholto and Frostig, Roy and Levskaya, Anselm and Chen, Charlie and Vikram, Sharad
      and Lebron, Federico and Choy, Peter and Ramasesh, Vinay and Webson, Albert and Pope, Reiner},
      publisher = {Google DeepMind},
      howpublished = {Online},
      note = {Retrieved from https://jax-ml.github.io/scaling-book/},
      year = {2025}
    }