Ch 19 · RL / OPD 工程
第四部分 · 后训练 · 19

RL / OPD 工程 — 后训练的发动机

Ch18 把"OPD 为什么可行"讲完了;这一章只回答"它怎么跑得起"。trillion 级 teacher × 全词表 logit × 1M context × on-policy rollout 这四件事任一单独都能压垮集群,V4 用四个工程支柱(FP4 / Teacher Scheduling / WAL Rollout / Million-Token RL)把它们同时跑起。

名词速通 · 一分钟看懂"OPD 工程支柱"

四个支柱 = (1) FP4 在 rollout/teacher/reference forward 全量启用;(2) Teacher Scheduling 把 N 个 trillion teacher 拆到中央存储 + 中央 hidden 缓冲 + GPU 上动态装载的 prediction head;(3) Token 粒度 WAL 让 rollout 抢占可恢复且无 length bias;(4) Million-Token RL 把 1M trajectory 切轻量元数据 + 重 per-token 字段

一句话:把"原理上可行"翻译成"工程上能跑"。 OPD 的物理障碍是显存(trillion teacher)+ 带宽(全词表 logit)+ 容错(on-policy rollout 抢占)+ 长度(1M context) 四个独立瓶颈。每个支柱单独解决一个,缺一不可。 这一章是 V4 区别于其它开源蒸馏工作的真正护城河 —— 没有这一节,trillion teacher × 全词表 KL 是 PPT。

FP4 全量启用(在 rollout / teacher / reference forward)
Ch10 的 FP4 QAT 让训练时的 forward能用 FP4 权重 + 无损 dequant。OPD 阶段把这件事推到极致:所有 inference-only 路径(rollout、N 个 teacher 前向、reference 模型前向)全部走原生 FP4,只在 backward 路径维持 FP8。动机:teacher 与 reference 永远不更新,没必要 FP8 精度,FP4 砍内存 + 砍带宽。
Teacher Scheduling(教师调度)
OPD 框架要支持任意多个 trillion teacher同时算 logit。简单地"全部装显存"不可能(10 个 trillion ~ 10TB 权重)。V4 的方案:把 teacher 状态拆到三处—— ① 权重在中央分布式存储;② 前向 hidden state 在中央 buffer;③ 仅在需要时把对应 teacher 的 prediction head 动态装到 GPU。配合按 teacher index 排序的 batching,每个 head 整个训练 step 只装载一次。
Centralized Weight Offload(权重托管)
所有 teacher 权重写到分布式存储(3FS 或类似),用 ZeRO-like 方式按需切片加载到 GPU。对比常驻 GPU:常驻需要每张卡至少 ~100GB teacher 显存;offload 让 teacher 与 student 共享显存预算,训练时显存不爆。代价是 I/O,但与计算 overlap 后净开销可控。
不实例化 logits(never materialize)
关键工程 trick。$|V| > 100\text{K}$ 全词表 logits × batch × seq 体积比模型本身还大(见 Ch18 数值演练 2.6 TB/step)。V4 在 teacher forward 期间只缓存最后一层 hidden state到中央 buffer;训练时再把它过对应 teacher 的 prediction head 现场重建 logits + 直接算 KL,logits 张量从来不在显存上完整出现
按 Teacher Index 排序的 batching
训练样本在 dispatch 阶段按"它属于哪个 teacher"排序后再 batch。结果:每个 GPU 每个 step 只看一个 teacher 的样本 → 这个 teacher 的 prediction head 装载一次后整步复用 → head 装载成本被 batch 摊薄到接近零。是把"动态装载"做成可承担的关键。
Async I/O(异步 I/O)
权重 / hidden state 的 load/unload 全部走后台流,不阻塞前向计算。等价于把 I/O 时间藏到计算下面(与 Ch7 MegaMoE 同思路)。需要双 buffer + 预取,但工程上是经典模式。
TileLang KL Kernel
student / teacher logits 的精确 KL 由专门 TileLang kernel(Ch8)计算,避免 PyTorch 动态显存分配带来的碎片。动机:每 step 算上万次 KL,PyTorch 的 op-by-op 调度会让显存 fragment 在几小时内累积爆炸;TileLang 的静态 layout 让 KL 张量始终在固定地址,0 fragmentation
WAL(Write-Ahead Log,写前日志)
数据库经典技术。每次状态变更先写到顺序日志,再改主数据。崩溃恢复时凭日志重放。V4 RL 借这个思路:rollout 每生成 1 个 token 就 append 到 trajectory log,preempt 时凭 log + 落盘 KV 直接续 decode,不必从头重生
Length Bias(长度偏差)
关键概念。短 trajectory 容易在 preempt 窗口内"完成",长 trajectory 容易被中断。如果中断后整体丢弃重跑,训练数据中"成功生成"的样本被系统性偏向短回答。policy 学到"短答案更可能成功"的伪信号 → distribution shift。Token 粒度 WAL 是为了让长 trajectory 也能"续上",从根上消掉这个 bias。
Preemptible & Fault-Tolerant Rollout
大规模集群里 rollout 服务常被高优任务抢占(GPU 资源争用)+ 常发硬件故障。RL 训练的 rollout 必须能在两种事件下保持训练数据无偏。WAL 同时解决这两个:preempt 时凭 WAL + 落盘 KV 续 decode;硬件错误时凭 WAL 已生成 token 重做 prefill 重建 KV。
1M Context RL 的 trajectory 切两类
1M token 的单条 rollout trajectory 含 ~60MB KV cache + ~10MB hidden states + 元数据。如果整体作为 one record 在 dispatch 阶段全部 load,单节点显存秒爆。V4 把 trajectory 拆 lightweight metadata(prompt、reward、长度等)和 heavy per-token field(每 token 的 logit / hidden / KV),dispatch 只 load 元数据做 shuffle/packing,重字段用共享内存 data loader 按需加载。
一句话定位:这章是把 Ch18 OPD 的"为什么"和实际能跑之间的砍刀。四个支柱解决四个独立瓶颈:FP4 砍显存、Teacher Scheduling 砍 logit 体积、WAL 砍 length bias、Million-Token RL 砍 trajectory I/O。读完你应该能解释为什么 token-WAL 不是"普通断点续跑"而是消除 length bias 的关键。

1. FP4 在 OPD 的全量启用

Ch10 已经讲过 FP4 QAT 的核心 trick:FP32 master → FP4 → FP8 dequant 无损,让训练 forward 也能用 FP4 权重。OPD 阶段把这件事推到极致 —— 所有不需要 backward 的 forward 路径(rollout 推理、N 个 teacher 算 logit、reference 模型算 KL 基线)全部走原生 FP4。

  • rollout 与所有 inference-only forward(teacher、reference)使用原生 FP4 (MXFP4)
  • 训练 step 的 backward 仍走 FP8 主路径,FP4 → FP8 dequant 无损(Ch10 sub-block scale 嵌套),与现有 mixed-precision pipeline 无缝对接;
  • 不用动 backward kernel,显存与带宽却整体砍半
  • 未来硬件如果做到原生 FP4×FP8 GEMM,单 token 还能再降 1/3 FLOPs。
数值演练 · FP4 在 OPD 阶段省的显存 一次 OPD step:student backward 占 forward 显存的 ~50%,teacher × N + reference + rollout 都是 inference-only forward。
  • 纯 FP8:1 个 student 训练(设占用 1.0×)+ 10 个 teacher(10×)+ 1 个 reference(1×)+ rollout(1×)= 13×;
  • FP4 全量 inference + FP8 backward:student 训练 1.0× + 10 teacher × 0.5 = 5× + 1 reference × 0.5 = 0.5× + rollout × 0.5 = 0.5× =
  • 显存压力降到 7/13 ≈ 54%,等价于 GPU 数 ÷2。
这部分省下来的预算正好抵消 teacher 数量增加带来的成本,不增加集群规模就能多挂几个 teacher。FP4 的"砍精度但不砍能力"在 OPD 阶段是关键。

2. Teacher Scheduling:让 N 个 trillion 教师同框

OPD 的核心瓶颈不是算 KL(那是廉价 GEMM),是把 N 个 trillion teacher 同时供给 KL 计算。论文给的工程套路是把 teacher 状态拆到三处,让任意时刻 GPU 上至多挂一个 teacher head:

  1. 中央分布式存储装 teacher 权重(按 ZeRO-like 切片),按需加载;
  2. 中央 buffer最后一层 hidden state(teacher forward 的产物),训练时再过对应 prediction head 重建 logits;
  3. GPU 显存动态装载当前 batch 用到的 teacher prediction head(仅最后一层)。

配合按 teacher index 排序的 batching:训练样本在 dispatch 阶段先按"属于哪个 teacher"排序,再分发到 GPU。每个 GPU 在一个 step 里只看一个 teacher → prediction head 装载一次后整步复用 → 装载成本被摊薄到几乎零。

📖 三处状态拆分的物理账

这里的核心是把"全部 teacher 全部 logit"这个不可能装进 GPU 的对象,分时分空间地散到三处:

  1. 权重:从中央存储按需切片加载 → I/O 是瓶颈,但与计算 overlap 可吃掉;
  2. Hidden state(每 token 一个 $d$ 维向量,$d=7168$):体积是 logits 的 $1/|V| \approx 1/18$,能装进 GPU 中央 buffer
  3. Logits:从来不在显存里完整出现,用一次算一次扔一次,靠 hidden + prediction head 现场重建。

合起来:"GPU 永远只看到 1 个 teacher 的 1 个 head + 该 teacher 这一 batch 样本的 logit 流"。其它 N-1 个 teacher 的状态都在中央存储里"沉睡"。这就是把 trillion-teacher OPD 跑得起的物理基础。

3. WAL Rollout:可抢占可容错 + 消除 length bias

OPD 是 on-policy 的(Ch18 §4),每个 step 都要先用当前最新 student rollout 一批 trajectory,再算 KL。在大规模集群下两件事必然发生:

  • 抢占:rollout 服务被高优训练任务挤占(GPU 资源争用),随时可能被 evict;
  • 故障:硬件错误(GPU ECC、NVLink 抖动、磁盘故障)是常态,1000 卡 / 月 几十次。

最朴素的容错方案是把未完成的 rollout 整体丢弃重跑。这个方案有一个系统性偏差,是整章最容易被忽视但也最重要的洞察:

为什么"丢弃重跑"会引入 length bias

假设抢占窗口长度 $T$(比如 5 分钟),rollout 平均生成速率 $r$ token/s。任何 trajectory 长度 $L > rT$ 都必然被中断;长度 $L \le rT$ 才有机会"完成"。

  1. 长 trajectory(深度推理、长代码生成)大概率被丢;
  2. 短 trajectory 大概率"完成"被纳入训练数据;
  3. policy 看到的"成功完成"样本平均长度系统性低估
  4. 反向梯度引导 policy 偏向短答案 —— 这是训练数据被 evict 机制污染

这个偏差不是"性能损失"那么简单,它是分布层面的偏移。你训完一个 OPD model 发现它在长答题(奥数 / debug)上表现差,根因可能不是 OPD 损失设计错,而是length bias 偷偷重塑了训练数据。token 粒度 WAL 就是为了从根上斩断这条因果链。

V4 的设计:

  • 每个生成请求维护 token 粒度 Write-Ahead Log —— 每生成 1 个 token 立刻 append 到 trajectory log;
  • preempt 时暂停 inference,把整个 KV cache 落盘到分布式存储;
  • resume 时凭 WAL + 落盘 KV 直接续 decode,从中断的 token 接着生成;
  • 致命硬件错误(KV cache 不可恢复)时,凭 WAL 中已生成的 token 重做 prefill,重建 KV。
Demo · 抢占下"丢弃重跑" vs "WAL 续 decode" 的时间线 + 成功完成比例(拖动 trajectory 长度)
交互
"丢弃重跑" 方案:抢占即整段丢,下次重头来过 "WAL 续 decode" 方案:抢占时落盘,恢复时从 WAL 接 完成比例 = 0% · length bias = 严重 完成比例 = 0% · length bias = 生成中 抢占丢弃 落盘 KV 完成

读图法:上排 6 条 trajectory 走"丢弃重跑"路径 —— 每次抢占(红段)整条丢、下次从 0 重生。越长的 trajectory 越难穿越完整抢占窗口,所以长 trajectory 完成率低,整体数据偏向短样本(length bias)。
下排同样 6 条走"WAL 续"路径 —— 抢占时落盘 KV(黄段),恢复后从 WAL 记录的 token 位置接着 decode。即使 trajectory 跨多个抢占窗口,最终都能完成。
把 trajectory 长度从 5K 拉到 100K 看上下完成率差距:丢弃重跑下长 trajectory 完成率断崖式下降,WAL 续下保持 ~100%。这就是 length bias 的视觉证据

4. Million-Token RL 框架:把 trajectory 拆轻重

1M token 的单条 trajectory 一旦完整 load 进显存就秒爆。V4 的优化分两步:

  1. 切两类字段:每条 trajectory 拆成
    • lightweight metadata:prompt、reward、长度、teacher 归属、状态码 等;
    • heavy per-token field:每 token 的 hidden state、KV、logit;
  2. 分层加载
    • dispatch 阶段只 load metadata,做 global shuffle 与 packing layout 计算;
    • per-token 重字段通过共享内存 data loader 按需加载,节点内多 GPU 共享一份,消除冗余;
    • mini-batch 消费完立即释放,CPU/GPU 内存压力随时间稳定;
    • on-device mini-batch 数量按 workload 动态决定,在计算吞吐I/O overlap之间取最优解。
数值演练 · 1M token trajectory 的体积分布 单条 1M token trajectory:
  • metadata:~10 KB(prompt 头 + reward + 几十个标记字段);
  • hidden states:1M × $d$=7168 × FP8 = ~7 GB;
  • KV cache:~60 GB(CSA/HCA 压缩后已经省了 90%);
  • logits(中央 buffer 缓存):根本不实例化(Ch19 §2);
metadata : heavy = 1 : $10^7$。简单做"整体 load"等于把 metadata 也带上 GB 级负担,shuffle / packing 阶段的 dispatch 流水会被卡住。切两类后 dispatch 只过 metadata(KB 级),重字段按需 stream,工程效率提升数千倍。

5. 与 DSec Sandbox 的接缝

本章四节解决的是 RL/OPD 训练 + rollout 推理的工程问题。但 agentic rollout 还要执行外部命令(bash、文件、网页、单元测试),这部分的承载者是下一章的 DSec —— 整个 post-training 流水线由 "训练框架 + rollout 引擎 + sandbox 平台"三件套共同承担。

为什么这套基建是 V4 的真正护城河

蒸馏在算法层不复杂 —— forward / reverse KL、on-policy 这些都是教科书级。但同时跑 10+ 个 trillion teacher + student + reward model + sandbox 仍能 24/7 不掉链子,这是开源社区从未做过的工程量级。
Ch19 + Ch20 合起来的工程深度,是 V4 区别于其它"做开源蒸馏"工作的真正分水岭 —— 不是论文里的公式新,是公式底下那层能让公式跑起来的水管新。

6. 一句话总结

把整章压成一句话

四个支柱解决四个独立物理瓶颈:FP4 砍 inference-only forward 的显存与带宽、Teacher Scheduling 把 N 个 trillion teacher 拆到中央存储 + buffer + 动态 head 三层、token 粒度 WAL 让 rollout 抢占可恢复同时消除 length bias、Million-Token RL 把 1M trajectory 拆轻重字段分层加载。这一章是 V4 在后训练领域真正的护城河