<aside> 💡
TL;DR
我们将 speculative decoding 引入到了 RL 的采样流程中,在 batch size 合适的情况下,采样速度得到了显著提升;并且,draft model 也会在训练过程中更新。相较于冻结 draft model 的做法,accepted length 持续维持在较高水平,产生长期稳定的正收益。
</aside>
slime/docs/zh/advanced/speculative-decoding.md at main · THUDM/slime
Speculative Decoding 大名远扬,它是一种巧妙的推理加速技术。具体来说,推理过程中不再让昂贵的 Target Model 逐个 token 进行 decode,而是先由一个轻量级的 draft model 先进行 decode,生成多个 token 后,再由大模型进行批量验证。验证通过的 token 直接作为最终的推理结果,而验证失败的 token 重新用大模型进行采样。理想情况下,如果 draft model 生成的 token 都通过验证,系统就能一次性接受这 k 个 token,显著提高了推理效率。然而,如果 draft model 和 target model 的差异过大,能通过验证的 token 过少,反而可能会产生负面效果。
Speculative Decoding 这把双刃剑在工业级的推理引擎上已经得到了显著的应用,在 On-Policy RL 中也有着诱人的潜力。一方面,speculative decoding 能显著加速 Rollout 采样过程,并且采样得到的 token 在概率上和 target model 是完全一致的;此外,长尾轨迹会显著拉低 RL rollout 过程中的有效并发度,难以打满硬件的计算瓶颈,而这天然地适应了 speculative decoding 的应用场景。当然,这一切的前提是 draft model 和 target model 的采样概率差异在合理范围内——如果两者策略差异过大,draft model 推测出的 token 接受率会暴跌。
这就是本文解决的问题——我们将 speculative decoding 引入到了 RL 的采样流程中,并且随着训练的进行同步更新 draft model,稳定提高了采样速度。
目前已经合入 slime 主干,一键使用,参考文档
Megatron 在 v0.12.0rc3 中支持了 EAGLE MTP 的 SFT,基于此,我们考虑在训练过程中对 draft model 进行 online SFT。具体来说,我们在 Megatron backend 内部增加了一条新的 cross entropy loss(CE Loss),使用 target model 的 hidden state 和 generated token 作为 MTP 层(也即 draft model)的输入,以期望 MTP 层能准确预测出 target model 真正会生成的下一个 token。当 target model 的 GPRO Loss 计算完成并调用 backward() 时,同时触发 target model 的 GPRO Loss 和的 MTP CE Loss 的反向传播。具体流程如下:
为了分析我们对于 loss 的构造,我们首先来描述 draft model 的训练目标。和标准的 autoregressive model 中,我们的训练目标是用 $t$ 时刻的输入来预测 $t+1$ 时刻的 token,也即 input(t) -> output(t+1)。与此不同的是,主流的 speculative decoding 采用 eagle MTP 作为 draft model,其预测目标是 $t+2$ 时刻的 token,也即 Input(t) + Input(t+1) → Output(t+2)。