筆記

Pre-training + RLHF 可以讓我們訓練出 GPT-3.5,而這節課我們將學習到如何訓練出 o1/r1

Core Algorithm

在進一步探討 RLHF 之前,我們先來回顧 RL 演算法的演進過程:

  1. Policy Gradient: 這是最基礎的演算法,目標是直接最大化策略下的期望獎勵。但它的缺點是 variance 極高,且屬於「on-policy」方法,也就是每次更新前都必須從當前的 policy 重新 sampling rollouts,效率十分低落。
  2. **TRPO:**為了解決 sampling efficiency 的問題,TRPO 引入了 importance sampling,允許模型圍繞當前的 policy 進行線性化,並利用舊策略的 rollouts 來進行更新。其核心在於限制新舊策略之間的差異不能太大。
  3. PPO: PPO 進一步簡化了 TRPO 的概念,採用 clipping 機制。它將新舊策略 的 probability ratios 限制在一個範圍內(例如加上或減去 $\epsilon$),這能防止策略偏離原有策略太遠,確保訓練的穩定性。

以下是 PPO 的演算法概念

image.png

雖然 PPO 在理論層面的目標函數概念很簡單,但在實務上卻極為複雜。有一篇名為《The 37 Implementation Details of Proximal Policy Optimization》的文章,強調 PPO 實作中充滿了各種可能導致失敗的細微設定。

在將 PPO 應用於語言模型時,系統需要同時運作多個神經網路模型:包含策略模型 (Policy LM)、監督式微調模型 (SFT Model)、獎勵模型 (Reward Model) 以及價值模型 (Value Model)。這不僅架構龐大,更會消耗雙倍的 GPU 記憶體資源

image.png

PPO in practice

接下來我們會透過 AlpacaFarm 的 PPO 實作,拆解了 PPO 的具體運作流程:

Outer loop 與 Rollouts

PPO 透過外層迴圈呼叫內層函數,先使用當前模型進行推理並生成多個回應(即 Rollouts,這是最耗時的步驟),接著對這些樣本計算損失並進行反向傳播梯度更新,同時會應用梯度裁剪 (clip grad norm) 等技巧。

https://github.com/tatsu-lab/alpaca_farm/blob/main/src/alpaca_farm/rl/rl_trainer.py#L128

		def step_with_rollouts(self, rollouts):
        """Based on fixed rollouts, run PPO for multiple epochs."""
        assert isinstance(self.optimizer, AcceleratedOptimizer), (
            "`optimizer` must be pushed through `accelerator.prepare`. "
            "Otherwise the `accelerator.accumulate` context manager won't correctly disable `zero_grad` or `step`."
        )
        rollouts_dataloader = self.get_rollouts_dataloader(rollouts=rollouts)
        stats_list = []
        for epoch_idx in range(self.args.noptepochs):
            for batch_idx, rollouts_batch in tqdm.tqdm(
                enumerate(rollouts_dataloader, 1), disable=not self.accelerator.is_main_process, desc="gradstep"
            ):
                with self.accelerator.accumulate(self.policy):
                    ppo_loss, stats_for_this_step = self.compute_loss(rollouts_batch)
                    self.accelerator.backward(ppo_loss)
                    if self.accelerator.sync_gradients:
                        # Gradient norm almost blows up at some point, but stabilizes eventually, even w/o clipping.
                        if self.args.max_grad_norm is not None:
                            self.accelerator.clip_grad_norm_(self.policy.parameters(), self.args.max_grad_norm)
                        stats_for_this_step["loss/grad_norm"] = self._compute_grad_norm()
                        stats_list.append(stats_for_this_step)
                    self.optimizer.step()
                    self.optimizer.zero_grad(set_to_none=True)
        return common.merge_dict(stats_list, torch.stack)  # list of dict -> dict: str -> 1-D tensor

Loss Computation