RL / OPD 工程 — 后训练的发动机
Ch18 把"OPD 为什么可行"讲完了;这一章只回答"它怎么跑得起"。trillion 级 teacher × 全词表 logit × 1M context × on-policy rollout 这四件事任一单独都能压垮集群,V4 用四个工程支柱(FP4 / Teacher Scheduling / WAL Rollout / Million-Token RL)把它们同时跑起。
四个支柱 = (1) FP4 在 rollout/teacher/reference forward 全量启用;(2) Teacher Scheduling 把 N 个 trillion teacher 拆到中央存储 + 中央 hidden 缓冲 + GPU 上动态装载的 prediction head;(3) Token 粒度 WAL 让 rollout 抢占可恢复且无 length bias;(4) Million-Token RL 把 1M trajectory 切轻量元数据 + 重 per-token 字段
一句话:把"原理上可行"翻译成"工程上能跑"。 OPD 的物理障碍是显存(trillion teacher)+ 带宽(全词表 logit)+ 容错(on-policy rollout 抢占)+ 长度(1M context) 四个独立瓶颈。每个支柱单独解决一个,缺一不可。 这一章是 V4 区别于其它开源蒸馏工作的真正护城河 —— 没有这一节,trillion teacher × 全词表 KL 是 PPT。
- FP4 全量启用(在 rollout / teacher / reference forward)
- Ch10 的 FP4 QAT 让训练时的 forward能用 FP4 权重 + 无损 dequant。OPD 阶段把这件事推到极致:所有 inference-only 路径(rollout、N 个 teacher 前向、reference 模型前向)全部走原生 FP4,只在 backward 路径维持 FP8。动机:teacher 与 reference 永远不更新,没必要 FP8 精度,FP4 砍内存 + 砍带宽。
- Teacher Scheduling(教师调度)
- OPD 框架要支持任意多个 trillion teacher同时算 logit。简单地"全部装显存"不可能(10 个 trillion ~ 10TB 权重)。V4 的方案:把 teacher 状态拆到三处—— ① 权重在中央分布式存储;② 前向 hidden state 在中央 buffer;③ 仅在需要时把对应 teacher 的 prediction head 动态装到 GPU。配合按 teacher index 排序的 batching,每个 head 整个训练 step 只装载一次。
- Centralized Weight Offload(权重托管)
- 所有 teacher 权重写到分布式存储(3FS 或类似),用 ZeRO-like 方式按需切片加载到 GPU。对比常驻 GPU:常驻需要每张卡至少 ~100GB teacher 显存;offload 让 teacher 与 student 共享显存预算,训练时显存不爆。代价是 I/O,但与计算 overlap 后净开销可控。
- 不实例化 logits(never materialize)
- 关键工程 trick。$|V| > 100\text{K}$ 全词表 logits × batch × seq 体积比模型本身还大(见 Ch18 数值演练 2.6 TB/step)。V4 在 teacher forward 期间只缓存最后一层 hidden state到中央 buffer;训练时再把它过对应 teacher 的 prediction head 现场重建 logits + 直接算 KL,logits 张量从来不在显存上完整出现。
- 按 Teacher Index 排序的 batching
- 训练样本在 dispatch 阶段按"它属于哪个 teacher"排序后再 batch。结果:每个 GPU 每个 step 只看一个 teacher 的样本 → 这个 teacher 的 prediction head 装载一次后整步复用 → head 装载成本被 batch 摊薄到接近零。是把"动态装载"做成可承担的关键。
- Async I/O(异步 I/O)
- 权重 / hidden state 的 load/unload 全部走后台流,不阻塞前向计算。等价于把 I/O 时间藏到计算下面(与 Ch7 MegaMoE 同思路)。需要双 buffer + 预取,但工程上是经典模式。
- TileLang KL Kernel
- student / teacher logits 的精确 KL 由专门 TileLang kernel(Ch8)计算,避免 PyTorch 动态显存分配带来的碎片。动机:每 step 算上万次 KL,PyTorch 的 op-by-op 调度会让显存 fragment 在几小时内累积爆炸;TileLang 的静态 layout 让 KL 张量始终在固定地址,0 fragmentation。
- WAL(Write-Ahead Log,写前日志)
- 数据库经典技术。每次状态变更先写到顺序日志,再改主数据。崩溃恢复时凭日志重放。V4 RL 借这个思路:rollout 每生成 1 个 token 就 append 到 trajectory log,preempt 时凭 log + 落盘 KV 直接续 decode,不必从头重生。
- Length Bias(长度偏差)
- 关键概念。短 trajectory 容易在 preempt 窗口内"完成",长 trajectory 容易被中断。如果中断后整体丢弃重跑,训练数据中"成功生成"的样本被系统性偏向短回答。policy 学到"短答案更可能成功"的伪信号 → distribution shift。Token 粒度 WAL 是为了让长 trajectory 也能"续上",从根上消掉这个 bias。
- Preemptible & Fault-Tolerant Rollout
- 大规模集群里 rollout 服务常被高优任务抢占(GPU 资源争用)+ 常发硬件故障。RL 训练的 rollout 必须能在两种事件下保持训练数据无偏。WAL 同时解决这两个:preempt 时凭 WAL + 落盘 KV 续 decode;硬件错误时凭 WAL 已生成 token 重做 prefill 重建 KV。
- 1M Context RL 的 trajectory 切两类
- 1M token 的单条 rollout trajectory 含 ~60MB KV cache + ~10MB hidden states + 元数据。如果整体作为 one record 在 dispatch 阶段全部 load,单节点显存秒爆。V4 把 trajectory 拆 lightweight metadata(prompt、reward、长度等)和 heavy per-token field(每 token 的 logit / hidden / KV),dispatch 只 load 元数据做 shuffle/packing,重字段用共享内存 data loader 按需加载。
1. FP4 在 OPD 的全量启用
Ch10 已经讲过 FP4 QAT 的核心 trick:FP32 master → FP4 → FP8 dequant 无损,让训练 forward 也能用 FP4 权重。OPD 阶段把这件事推到极致 —— 所有不需要 backward 的 forward 路径(rollout 推理、N 个 teacher 算 logit、reference 模型算 KL 基线)全部走原生 FP4。
- rollout 与所有 inference-only forward(teacher、reference)使用原生 FP4 (MXFP4);
- 训练 step 的 backward 仍走 FP8 主路径,FP4 → FP8 dequant 无损(Ch10 sub-block scale 嵌套),与现有 mixed-precision pipeline 无缝对接;
- 不用动 backward kernel,显存与带宽却整体砍半;
- 未来硬件如果做到原生 FP4×FP8 GEMM,单 token 还能再降 1/3 FLOPs。
- 纯 FP8:1 个 student 训练(设占用 1.0×)+ 10 个 teacher(10×)+ 1 个 reference(1×)+ rollout(1×)= 13×;
- FP4 全量 inference + FP8 backward:student 训练 1.0× + 10 teacher × 0.5 = 5× + 1 reference × 0.5 = 0.5× + rollout × 0.5 = 0.5× = 7×;
- 显存压力降到 7/13 ≈ 54%,等价于 GPU 数 ÷2。
2. Teacher Scheduling:让 N 个 trillion 教师同框
OPD 的核心瓶颈不是算 KL(那是廉价 GEMM),是把 N 个 trillion teacher 同时供给 KL 计算。论文给的工程套路是把 teacher 状态拆到三处,让任意时刻 GPU 上至多挂一个 teacher head:
- 中央分布式存储装 teacher 权重(按 ZeRO-like 切片),按需加载;
- 中央 buffer装最后一层 hidden state(teacher forward 的产物),训练时再过对应 prediction head 重建 logits;
- GPU 显存动态装载当前 batch 用到的 teacher prediction head(仅最后一层)。
配合按 teacher index 排序的 batching:训练样本在 dispatch 阶段先按"属于哪个 teacher"排序,再分发到 GPU。每个 GPU 在一个 step 里只看一个 teacher → prediction head 装载一次后整步复用 → 装载成本被摊薄到几乎零。
这里的核心是把"全部 teacher 全部 logit"这个不可能装进 GPU 的对象,分时分空间地散到三处:
- 权重:从中央存储按需切片加载 → I/O 是瓶颈,但与计算 overlap 可吃掉;
- Hidden state(每 token 一个 $d$ 维向量,$d=7168$):体积是 logits 的 $1/|V| \approx 1/18$,能装进 GPU 中央 buffer;
- Logits:从来不在显存里完整出现,用一次算一次扔一次,靠 hidden + prediction head 现场重建。
合起来:"GPU 永远只看到 1 个 teacher 的 1 个 head + 该 teacher 这一 batch 样本的 logit 流"。其它 N-1 个 teacher 的状态都在中央存储里"沉睡"。这就是把 trillion-teacher OPD 跑得起的物理基础。
3. WAL Rollout:可抢占可容错 + 消除 length bias
OPD 是 on-policy 的(Ch18 §4),每个 step 都要先用当前最新 student rollout 一批 trajectory,再算 KL。在大规模集群下两件事必然发生:
- 抢占:rollout 服务被高优训练任务挤占(GPU 资源争用),随时可能被 evict;
- 故障:硬件错误(GPU ECC、NVLink 抖动、磁盘故障)是常态,1000 卡 / 月 几十次。
最朴素的容错方案是把未完成的 rollout 整体丢弃重跑。这个方案有一个系统性偏差,是整章最容易被忽视但也最重要的洞察:
假设抢占窗口长度 $T$(比如 5 分钟),rollout 平均生成速率 $r$ token/s。任何 trajectory 长度 $L > rT$ 都必然被中断;长度 $L \le rT$ 才有机会"完成"。
- 长 trajectory(深度推理、长代码生成)大概率被丢;
- 短 trajectory 大概率"完成"被纳入训练数据;
- policy 看到的"成功完成"样本平均长度系统性低估;
- 反向梯度引导 policy 偏向短答案 —— 这是训练数据被 evict 机制污染。
这个偏差不是"性能损失"那么简单,它是分布层面的偏移。你训完一个 OPD model 发现它在长答题(奥数 / debug)上表现差,根因可能不是 OPD 损失设计错,而是length bias 偷偷重塑了训练数据。token 粒度 WAL 就是为了从根上斩断这条因果链。
V4 的设计:
- 每个生成请求维护 token 粒度 Write-Ahead Log —— 每生成 1 个 token 立刻 append 到 trajectory log;
- preempt 时暂停 inference,把整个 KV cache 落盘到分布式存储;
- resume 时凭 WAL + 落盘 KV 直接续 decode,从中断的 token 接着生成;
- 致命硬件错误(KV cache 不可恢复)时,凭 WAL 中已生成的 token 重做 prefill,重建 KV。
读图法:上排 6 条 trajectory 走"丢弃重跑"路径 —— 每次抢占(红段)整条丢、下次从 0 重生。越长的 trajectory 越难穿越完整抢占窗口,所以长 trajectory 完成率低,整体数据偏向短样本(length bias)。
下排同样 6 条走"WAL 续"路径 —— 抢占时落盘 KV(黄段),恢复后从 WAL 记录的 token 位置接着 decode。即使 trajectory 跨多个抢占窗口,最终都能完成。
把 trajectory 长度从 5K 拉到 100K 看上下完成率差距:丢弃重跑下长 trajectory 完成率断崖式下降,WAL 续下保持 ~100%。这就是 length bias 的视觉证据。
4. Million-Token RL 框架:把 trajectory 拆轻重
1M token 的单条 trajectory 一旦完整 load 进显存就秒爆。V4 的优化分两步:
- 切两类字段:每条 trajectory 拆成
- lightweight metadata:prompt、reward、长度、teacher 归属、状态码 等;
- heavy per-token field:每 token 的 hidden state、KV、logit;
- 分层加载:
- dispatch 阶段只 load metadata,做 global shuffle 与 packing layout 计算;
- per-token 重字段通过共享内存 data loader 按需加载,节点内多 GPU 共享一份,消除冗余;
- mini-batch 消费完立即释放,CPU/GPU 内存压力随时间稳定;
- on-device mini-batch 数量按 workload 动态决定,在计算吞吐与I/O overlap之间取最优解。
- metadata:~10 KB(prompt 头 + reward + 几十个标记字段);
- hidden states:1M × $d$=7168 × FP8 = ~7 GB;
- KV cache:~60 GB(CSA/HCA 压缩后已经省了 90%);
- logits(中央 buffer 缓存):根本不实例化(Ch19 §2);
5. 与 DSec Sandbox 的接缝
本章四节解决的是 RL/OPD 训练 + rollout 推理的工程问题。但 agentic rollout 还要执行外部命令(bash、文件、网页、单元测试),这部分的承载者是下一章的 DSec —— 整个 post-training 流水线由 "训练框架 + rollout 引擎 + sandbox 平台"三件套共同承担。
蒸馏在算法层不复杂 —— forward / reverse KL、on-policy 这些都是教科书级。但同时跑 10+ 个 trillion teacher + student + reward model + sandbox 仍能 24/7 不掉链子,这是开源社区从未做过的工程量级。
Ch19 + Ch20 合起来的工程深度,是 V4 区别于其它"做开源蒸馏"工作的真正分水岭 —— 不是论文里的公式新,是公式底下那层能让公式跑起来的水管新。
6. 一句话总结
四个支柱解决四个独立物理瓶颈:FP4 砍 inference-only forward 的显存与带宽、Teacher Scheduling 把 N 个 trillion teacher 拆到中央存储 + buffer + 动态 head 三层、token 粒度 WAL 让 rollout 抢占可恢复同时消除 length bias、Million-Token RL 把 1M trajectory 拆轻重字段分层加载。这一章是 V4 在后训练领域真正的护城河。