Ch 11 · 训练框架
第二部分 · 基础设施 · 11

训练框架 — mHC 的工程脚手架

Muon 与 ZeRO 为什么天然冲突、V4 怎么"稠密 / MoE 二分 + FP32→BF16 SR 通信"折中,mHC 怎么靠融合 kernel + DualPipe overlap 把 6.7% 开销压下去,以及 CSA / HCA 压缩边界让传统 CP 失效后的 two-stage CP 怎么救场。

名词速通 · 一分钟看懂训练框架

训练框架要解决三件事:(1) Muon × ZeRO 的"完整梯度 vs 切分梯度"冲突;(2) mHC 残差路径的 6.7% wall-time 开销;(3) CSA/HCA 把序列压成 1/m 后传统 CP 边界不对齐

一句话:把梯度按"稠密 / MoE"二分别切、把通信精度从 FP32 stochastic round 到 BF16 减半带宽、把 mHC forward/backward 融成单 kernel 嵌进 DualPipe 1F1B 流水、把 CP 切成两阶段先在压缩边界对齐再 all-gather。 每条都是为了让 V4 的架构层创新(Muon、mHC、CSA/HCA)在 1.6T 规模下真能跑得动。

ZeRO(Zero Redundancy Optimizer)
把 optimizer state、grad、param 沿数据维切到不同 rank,每 rank 只存一片。显存压力 ↓ N×,但每步训练需要 all-gather param(forward 用)+ reduce-scatter grad(backward 后)。ZeRO-3 是最激进版本,权重也切。
Muon × ZeRO 冲突
Muon 的 NS 迭代要看完整的更新矩阵 $M$(要做 $MM^T$)。ZeRO 把 $M$ 沿某一维切到各 rank,每个 rank 只看自己那片 → $MM^T$ 算不出来。这是架构与并行的根本冲突。
稠密 / MoE 二分切
V4 的折中方案:稠密参数(attention 投影、共享 FFN)限制 ZeRO 并行度 $\le P_{\max}$,用背包均衡把不同 shape 的稠密层打包到 $\le P_{\max}$ 个 rank;MoE 参数每个 expert 独立优化,flatten 所有 down/up/gate 后跨 rank 平均切,不限 ZeRO 并行度(因为 expert 之间彼此独立做 NS)。
背包均衡(knapsack-balanced sharding)
稠密层的 shape 千差万别(一些是 $7168\times 7168$,一些是 $7168\times 28672$),简单平均会让某 rank 严重超载。V4 把"shape 大小"看作物品体积,用背包近似算法把所有稠密层装到 $P_{\max}$ 个桶里,每桶总体积差距 ≤ 5%。最后 padding 对齐到最大 bucket。
FP32 → BF16 Stochastic Rounding(SR)
跨 rank 同步梯度时的带宽减半方案:发送方把 FP32 梯度随机舍入成 BF16 再发。$E[\text{round}] = $ 原 FP32 值(无偏),跨多步统计平均不引入偏差。比 round-to-nearest(有偏)好得多。
两阶段 all-to-all + 本地 FP32 求和
SR 的搭档:BF16 over wire → 接收方 BF16 升回 FP32 后 本地 求和(保 FP32 精度)→ 第二阶段 all-to-all 同步。这样通信带宽减半 + 数值精度保持。
DualPipe 1F1B
双向流水线(V3 引入,V4 沿用):每个 micro-batch 走"前向 → 后向"循环,1 个 forward 紧跟 1 个 backward 排队。可以把不同 micro-batch 的 forward 与 backward 重叠。给 mHC 的融合 kernel 提供 overlap 窗口
mHC 工程优化(融合 kernel + 选择性 ckpt)
mHC 残差路径的 wall-time 开销原本 ~12%(因为 8 次小 reduce + 8 次 sum 都是细粒度算子)。V4 用 TileLang 把 forward / backward 各融成 1 个 kernel,再用选择性 checkpoint(保留 normalized 输入、重算其余)省显存,再调 DualPipe 顺序让通信和 mHC 计算重叠 —— 压到 6.7%
Tensor 级 Activation Checkpointing
传统 ckpt 是 module 级(保存一个 transformer block 输入,重算整块)。V4 用 TorchFX 反向追溯最小重算子图,按 tensor 决定保留还是重算。无额外开销,更省显存
Two-stage CP(Context Parallelism)
长上下文沿序列维切到多 rank。CSA/HCA 把序列每 $m$ token 压成 1 个 → CP 边界经常落在压缩块中间,传统 CP 无法直接切。V4 的方案:Stage 1: 每 rank 把自己 $m$ 个未压缩 KV 发给下一 rank 合并压缩,固定输出长度 $s/m + 1$;Stage 2: 跨 rank all-gather + select-and-pad 重组完整压缩序列。
一句话定位:架构层(Part 1)每个新元素都在工程层有对应代价:mHC 用融合 kernel + ckpt 抹平、Muon 用稠密 / MoE 二分切 + SR 抹平、CSA/HCA 用 two-stage CP 抹平。这三件事合起来定义了 V4 的训练框架 —— 它的存在感几乎全部用来支付架构层创新的工程帐

1. Muon × ZeRO:完整梯度 vs 切分梯度

回顾 Ch06:Muon 每步要把更新矩阵 $M$ polar 化,用 NS 迭代 $M \leftarrow aM + b\,MM^TM + c(MM^T)^2 M$。问题:

  • $MM^T$ 是 $n \times n$ 的方阵,要 $M$ 完整才能算;
  • ZeRO-3 把 $M$ 沿某一维切到 $P$ 个 rank,每 rank 只有 $M$ 的一片;
  • 各 rank 算自己的 $M_i M_i^T$ 后不能直接相加 —— 因为切的是行还是列决定矩阵乘的语义。
📖 公式白话翻译 · 为什么不能"各算各的再加"

设 $M \in \mathbb{R}^{n\times m}$ 沿 $m$ 维切到 $P$ 个 rank:$M = [M_1 | M_2 | \cdots | M_P]$。要算的是:

$$ MM^T = \sum_{i=1}^P M_i M_i^T $$

这条等式里 $M_i M_i^T$ 是 $n \times n$ —— 即每 rank 都要算一个 full size $n \times n$ 中间矩阵。$n = 7168$ 时单个就是 200MB FP32,乘 $P=64$ rank → 13GB 临时占用。然后 reduce-sum 得到完整 $MM^T$ 才能继续 NS 迭代。

能做但贵。V4 的判断:稠密参数太大、不值得;MoE 的每个 expert 单独 ~$3.6$B 但 expert 之间独立、可以各 rank 一个 expert。所以稠密限制 ZeRO 并行度 $P_{\max}$,MoE 不限

1.1 稠密:背包均衡

稠密层 shape 不一致:QKV 投影 $7168\times 7168$、FFN up $7168\times 28672$、FFN down $28672\times 7168$ ……如果简单"每个层切到 P 个 rank",shape 大的层把 rank 拖慢、shape 小的层 rank 闲。V4 用背包均衡:把所有稠密层 flatten 后按总参数量分配到 $P_{\max}=8$ 个 rank,每 rank 总参数量误差 ≤ 5%。最后padding 对齐到最大 bucket,方便统一 NS。

1.2 MoE:每 expert 独立切

MoE 的 384 个 expert 彼此独立做 NS。所以每 expert 看作一个独立的"小模型",flatten 它的 down/up/gate 三个矩阵后跨 rank 切。不限 ZeRO 并行度 —— 反正每 expert 独立,rank 想切几片切几片,最后做完 NS 再合回来即可。

1.3 通信精度:SR FP32→BF16

两个最大的通信:reduce-scatter 梯度(每步)、all-gather param(每步)。原本 FP32 = 4 字节/数。SR 把发送方先 stochastic round 到 BF16 = 2 字节/数。

📖 公式白话翻译 · Stochastic Rounding

设 FP32 值 $x$ 落在两个 BF16 邻近码点 $\lfloor x \rfloor$ 和 $\lceil x \rceil$ 之间,距离比例 $p = (x - \lfloor x\rfloor) / (\lceil x \rceil - \lfloor x \rfloor)$:

$$ \mathrm{SR}(x) = \begin{cases} \lceil x \rceil & \text{w.p. } p \\ \lfloor x \rfloor & \text{w.p. } 1-p \end{cases} $$
  • 无偏:$E[\mathrm{SR}(x)] = p \cdot \lceil x \rceil + (1-p) \cdot \lfloor x \rfloor = x$;
  • vs round-to-nearest:RTN 是有偏的(小数总是往最近偶数靠),多步累积会漂移;SR 不漂移,跨多步取平均逼近真值
  • 方差:$\mathrm{Var}[\mathrm{SR}(x)] = p(1-p) \cdot (\lceil x \rceil - \lfloor x \rfloor)^2 \le \frac{1}{4} \cdot \mathrm{ULP}^2$,与 RTN 同量级。

对训练的影响:单步梯度噪声 $\uparrow$ 一点(相当于多了一些 SR 噪声),但训练本身就是 SGD,对噪声鲁棒。多步平均后逼近 FP32 同等收敛。带宽减半,质量几乎无损

2. mHC 的 6.7%:怎么从 12% 压下来

mHC 残差路径每层有 8 个 expansion / projection sum + reduce(见 Ch03)。原始 PyTorch 实现下:

  1. 每个 reduce / sum 是独立 kernel,host 开销大(参考 Ch08);
  2. 每层中间 activation 都要存(forward 16 个、backward 用全部),显存压力大;
  3. 反向链与主路径反向串行,无 overlap。

V4 三件事:

  1. 融合 kernel:TileLang 把 mHC forward 全 8 步合一个 kernel,backward 同理。host 开销从 16×30µs 降到 2×1µs
  2. 选择性 checkpoint:只保留 mHC 输入的 normalized 版本,其余 8 个中间 activation backward 时从 normalized 输入重算。重算开销 ~3% 但显存省 ~40%;
  3. DualPipe overlap:把 mHC backward 嵌到 DualPipe 1F1B 流水里和别的 micro-batch 的通信 overlap,掩掉 mHC 计算时间
数值演练 · mHC wall-time 拆账 基线 PyTorch eager mode:
  • 16 个独立 kernel × 30µs host + 16 × ~50µs device = 1.28 ms/层 × 61 层 × 4(fwd+bwd 双向)= 312 ms/step,约 12% wall-time;
融合 kernel + ckpt + DualPipe overlap:
  • 2 个融合 kernel × 1µs host + 2 × ~70µs device = 0.14 ms/层 × 61 × 4 = 34 ms;
  • + ckpt 重算 ~10 ms;
  • + overlap 后实际暴露 wall-time ~150 ms = 6.7%
这就是为什么 V4 敢推 mHC —— 残差成本可控。
Demo · 稠密 / MoE 二分切(拖动 P 看每 rank 显存占用)
交互
稠密参数(限制 P_max = 8) MoE 参数(不限 P) 稠密总 param ≈ 200B · 每 rank 占 25 GB · 总通信 BF16 SR 100 GB/step MoE 总 param ≈ 1.4T · 每 rank 占 175 GB · expert 独立 NS

读图法:每个色块代表 1 个 rank 持有的参数切片。稠密路径限制到 $P_{\max}=8$,超过 8 时每 8 个 rank 共享一份切片(多份冗余、显存不再降);MoE 路径不限,越切越细,显存随 P 线性下降
把 P 拉到 64 你能直观看到稠密区每 8 个 rank 还在重复同一片切片(颜色重),这就是稠密被限制在 $P_{\max}$ 的视觉表达

3. Two-stage CP:让 CSA/HCA 压缩边界与 CP 边界和谐共处

长上下文 CP 的意思:把 1M token 沿序列维切到 $P_{\text{cp}}$ 个 rank,每 rank 处理 $1\text{M}/P_{\text{cp}}$ token。每层 attention 时跨 rank 互通需要 KV 的 ring all-reduce。 问题:CSA / HCA 要把每 $m$ 个 token 压成 1 个,CP 边界经常落在某个压缩块中间 —— 那个边界 token 需要前后 $m$ 个 token 的信息,但前一段在邻居 rank 上。

V4 的两阶段方案:

  1. Stage 1: 每个 rank 把自己最末尾 $m$ 个未压缩 KV 发给下一个 rank。下一个 rank 收到后与本地的 $m$ 个 KV 合并,压缩成 1 个 "桥接 token"。固定输出长度 $s/m + 1$ 每 rank
  2. Stage 2: 所有 rank all-gather 自己的压缩 KV 序列 + bridging token,select-and-pad 重组成完整压缩序列,attention 在这个完整序列上做。
📖 公式白话翻译 · Stage 1 输出长度

每个 rank 输入 $s$ 个 token,本地能压成 $s/m$ 个完整压缩 token 加上 1 个 bridging(含本地末尾 + 邻居首端 $m$ 个 KV 合并而成):

$$ \text{Stage 1 输出长度} \;=\; \frac{s}{m} + 1 $$

"+1" 是 bridging token,它代表跨 rank 边界的 $m$ 个 token 联合压缩。这条精算让所有 rank 输出长度一致,方便 stage 2 的 all-gather。

4. Tensor 级 Activation Checkpointing

传统 module 级 ckpt:每个 transformer block 只存输入,反向时把整块重算。问题:mHC 的 16 个细 activation 重算开销大,占 ~10% 训练时间。

V4 用 TorchFX(PyTorch 的图变换库)做反向 trace:从 backward 计算图反向追溯哪些 tensor 被多次使用,决定保留还是重算的最小集。结果:

  • 对 mHC 的 8 个中间 sum 结果保留 normalized 输入,其余重算;
  • 对 attention 的 score 矩阵不存,反向用 Q/K 重算;
  • 对 FFN 中间 SwiGLU activation存 1 个,省其余两个;

总体显存 ↓ ~40%,重算开销 ~3% wall-time,无需手动标 ckpt 边界

本章小结

  • Muon 要完整梯度,ZeRO 切分梯度 → V4 把"稠密限制 $P_{\max}$ + MoE 不限"二分;通信精度 BF16 SR 减半带宽。
  • mHC 原始开销 12%,靠融合 kernel + 选择性 ckpt + DualPipe overlap 压到 6.7%
  • CSA/HCA 压缩边界让传统 CP 无法切,two-stage CP(邻居桥接 + all-gather)固定输出 $s/m + 1$ 解决。
  • Tensor 级 ckpt 省 ~40% 显存,重算开销 ~3% wall-time。
  • 合起来:架构层每一项创新都有对应工程代价,训练框架的存在感几乎全部用来"支付架构层的工程帐"。