FP4 QAT — 显存的最后一刀
为什么 V4 敢把 MoE 专家权重和 CSA Indexer QK 路径降到 FP4,"FP4 → FP8 反量化无损"是怎么靠 sub-block scale 嵌套做到的,以及训练用模拟 FP4 / 推理用真 FP4 是怎么做到 bit-identical。
FP4 QAT = 量化感知训练(Quantization-Aware Training),把 MoE 专家权重存成 4-bit 浮点,forward 在 FP8 下算,反向梯度用直通估计回灌到 FP32 master 权重
一句话:用 4 bit 存权重(每个数 16 个码点),forward 把 4 bit dequant 到 FP8 算 GEMM,backward 梯度直接打回 FP32 master。 关键工程发现:FP4 sub-block (1×32) 和 FP8 量化块 (128×128) 的 scale 比值正好落在 FP8 E4M3 的动态范围内,因此 dequant 完全无损 —— 训练时模拟 FP4 与推理时真 FP4 字节相同。 收益:MoE 权重显存 ÷ 4,推理 throughput +1.5×。
- FP4 (E2M1)
- 4 bit 浮点:1 sign + 2 exponent + 1 mantissa。16 个可表示值:±0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6(含特殊编码)。动态范围窄(max 6.0,min 0.5),所以必须配 scale 才能表达广动态权重。
- FP8 (E4M3)
- 8 bit 浮点:1+4+3。动态范围 ~448,常用于 forward GEMM 输入/输出。V4 GEMM 的内部精度。
- QAT(Quantization-Aware Training)
- 训练时就模拟量化:每步 forward 前把权重量化再反量化(fake quant),让模型自己学到量化噪声下的稳定参数。对比 PTQ(Post-Training Quant):PTQ 训完才量,质量损失大;QAT 训中量,质量几乎无损。
- FP32 Master 权重
- 每个参数还有一份 FP32 副本作为梯度累加目标。FP4 仅是"前向使用的视图"。这样小梯度不会被 FP4 的 0.5 步长吞掉。代价:每个参数多 4 字节内存;收益:训练稳定。
- STE(Straight-Through Estimator · 直通估计)
- 量化 $Q(\cdot)$ 的导数处处为 0(阶梯函数),梯度无法回传。STE 的 trick:反向时假装 $\partial Q / \partial w = 1$,把梯度直接穿过量化算子。粗暴但有效,是所有 QAT 的根基。
- Sub-block 嵌套 scale
- V4 的核心 trick:FP4 用 1×32 的小块,每块共享一个 FP8 scale;FP8 用 128×128 的大块,每块共享一个 FP32 scale。FP4 块的 FP8 scale 落在 FP8 E4M3 范围内 → 反量化时 scale 本身能精确表达 → 整个嵌套不引入舍入误差。
- STE 等价:梯度打到 FP8 权重
- V4 的具体实现:forward 用 FP8 反量化的权重,backward 把梯度直接累加到 FP8 中间表示(不是 FP4)。然后再把 FP8 累加结果同步回 FP32 master。这等价于对 FP4 量化算子的 STE,但实现更高效。
- Indexer FP4 + BF16 score
- CSA 的 Lightning Indexer 也走 FP4:QK 投影用 FP4 权重,indexer score $I_{t,s}$ 由 FP32 降到 BF16。实测top-k selector 提速 2×、99.7% recall 保留。这是把 FP4 推到注意力路径的勇敢一步。
- 真 FP4 推理 vs 模拟 FP4 训练
- 训练时硬件可能没原生 FP4 GEMM,于是用 FP8 模拟(dequant 到 FP8 再算);推理硬件(Hopper FP4 / Blackwell)有原生 FP4 GEMM,权重直接以 FP4 存盘。因为 sub-block 嵌套无损,两者输出 bit-identical。
1. 万亿参数的显存账
V4-Pro 1.6T 参数的存储压力极端:
- BF16: 1.4T × 2 B = 2.8 TB,需要 35× 80GB H100 仅装权重;
- FP8: 1.4T × 1 B = 1.4 TB,需要 18× H100;
- FP4: 1.4T × 0.5 B = 700 GB,9× H100 即可装下,配合 MoE only-load-active-experts 可以推到单节点 8 卡部署。
除了显存,FP4 GEMM 在新硬件(Hopper Tensor Core FP4、Blackwell)上吞吐是 FP8 的 2×,推理 throughput +1.5–2×。这是把 FP4 推上量化感知训练的另一动力。
2. FP4 (E2M1) 长什么样
FP4 只有 16 个可表示值。E2M1 的标准编码:
4 bit 编码 = sign(1) + exponent(2) + mantissa(1):
当 $e = 0$ 时是 subnormal,$v = (-1)^s \cdot 2^{-1} \cdot (m/2)$ —— 含 ±0, ±0.5。
16 个值的具体清单:$\{0, \pm 0.5, \pm 1, \pm 1.5, \pm 2, \pm 3, \pm 4, \pm 6\}$。
注意:没有 5(exp 跳跃)、有两个零(±0)、最大值 6.0。动态范围 6 / 0.5 = 12,远窄于 FP8 (动态范围 ~890)。所以不配 scale 直接表达不了 LLM 权重 —— 必须 sub-block scaling。
读图法:上轴是 FP4 16 个固定码点 —— 你能直观看到它们非均匀分布(小值密、大值稀,浮点的本性)。下轴是一个 sub-block 内 32 个真实权重;蓝→红 = 量化前→量化后。
点"长尾分布"看到一旦某个权重远大于其他(如 0.95 vs 其余 0.05),scale 被它单挑撑大,所有小权重一起塌进 0 附近的码点 —— 这就是异常值 outlier 问题,QAT 训练时模型自己学着不让权重出现长尾,所以训完的模型对 FP4 友好。
3. 关键 trick:sub-block 嵌套 scale
FP4 动态范围太窄,不能用 per-tensor 单一 scale(异常值会压垮所有正常权重)。常见做法 per-channel 或 per-group。V4 走更细:
- FP4 sub-block:每 1×32 个权重共享一个 scale,这个 scale 用 FP8 E4M3 存;
- FP8 大块:每 128×128 个权重(含 4 行 × 128 列 = 16 个 FP4 sub-block)共享一个 FP32 主 scale;
- 反量化:$w_{\text{FP8}} = \text{FP4}(w) \cdot s_{\text{FP8 sub}} \cdot s_{\text{FP32 主}}$。
关键观察:FP4 sub-block 的 scale 比例是相邻几行权重最大值之比,实测落在 $[2^{-7}, 2^{7}]$ 内(最多 2^7 = 128 倍差)。FP8 E4M3 的动态范围是~448(exp 4 bit, max ~$2^{8}$),刚好覆盖。
- scale 本身能精确存:不舍入;
- FP4 → FP8 dequant = 整数索引 × 精确 scale:每个 FP4 码点本身在 FP8 中也精确表达(FP4 的码点都是 FP8 的子集);
- 整体无损:dequant 后的 FP8 值 = 量化前的 FP4 视图原值,没有第二次舍入。
这是工程精算 —— 并不是 FP4 量化本身无损(它当然损失),而是 FP4 这个低精度视图"展开成 FP8"这一步无损。模型质量损失全部发生在"FP32 → FP4"那一次量化,QAT 让模型自己学着接受这个噪声。
4. QAT 训练循环
每步训练做的事:
- Forward:从 FP32 master 权重量化到 FP4(fake quant),dequant 到 FP8 输入 GEMM;
- Backward:梯度对FP8 表示的权重求导,等价 STE(绕过量化算子的 0 导数);
- Update:梯度累加到 FP32 master;
- 下一步:再次从 master 重新量化。
正常链式法则在量化算子 $Q$ 处会断:
$Q$ 是阶梯函数,处处导数为 0,梯度直接挂掉。STE 的 hack:反向时假装 $\partial Q / \partial w = 1$,让梯度原样穿过量化算子。等价于:
这条粗暴假设在数学上不严格,但在 QAT 实践中非常稳定:因为 $Q(w) - w$ 量级小($\le$ 半个码点间隔),梯度近似几乎一致;模型学几个 batch 后会自动靠近"量化友好"的权重。STE 是从 BinaryNet (2016) 一路继承下来的工程技。
5. CSA Indexer 的 FP4 路径
Ch04 的 CSA Lightning Indexer 是 V4 长上下文的命脉。每层一次 indexer 调用,每次要算整个 prefix 与当前 query 的 $QK^T$,然后选 top-k。FP4 化的关键:
- Indexer Q/K 投影权重:FP4 存储 + FP8 计算;
- Indexer score $I_{t,s}$:原本 FP32 累加,降到 BF16;
- top-k selector:BF16 排序,相对 FP32 提速 2×。
实测recall@k 99.7% —— 即和 FP32 indexer 选出来的 top-k 重合 99.7%。剩下 0.3% 的差异是 "BF16 score 排序产生的边界 token 翻号",由于 sparse attention 本身对 top-k 边缘不敏感,下游 logits 损失忽略不计。
本章小结
- 1.6T MoE 权重在 BF16 下 2.8 TB,FP4 砍到 700 GB,单节点 8×H100 可装。
- FP4 (E2M1) 只有 16 个码点、动态范围 12,必须 sub-block 嵌套 scale。
- FP4 sub-block (1×32) + FP8 块 (128×128) 的 scale 比例落在 FP8 E4M3 范围内 → FP4→FP8 反量化无损。
- STE 让 QAT 训练通过量化算子,反向把梯度打回 FP32 master。
- Indexer 的 FP4 + BF16 score 让 top-k selector 2× 提速 + 99.7% recall。
- "训练用模拟 FP4、推理用真 FP4"两侧 bit-identical,靠的就是上面这套嵌套精算。