← 返回首頁

克服 TRL 中訓練器與生成器之間的精度不匹配問題

Dirhousssi Amine
Dirhousssi Amine
@DirhousssiAmine
63🔁 6
𝕏 (Twitter)🔥

克服 TRL 中訓練器與生成器之間的精度不匹配問題

數值精度差異導致的「幻影 PPO 裁剪」阻礙了 RL 收斂。

簡短總結:我們發現當訓練的前向傳遞(FP32)與 vLLM 推論伺服器(BF16)使用不同的數值精度時,非同步 GRPO 訓練會失敗。根本原因是精度差距觸發了「幻影 PPO 裁剪」(phantom PPO clipping),導致那些策略實際上並未改變的 token,其梯度訊號被歸零。

我們最近在 TRL 中實作了 AsyncGRPO 演算法,旨在解耦推論與訓練,以實現更大規模的快速 RL 訓練。為了驗證實作,我們設置了最簡單的測試案例:

  • 任務:Reward = -len(completion_tokens)。最佳策略是立即輸出 EOS(獎勵 = -1)。

  • 模型:Qwen3-0.6B(28 層,hidden_dim=1536,詞彙表=151,936)。

def negative_length_reward(completion_ids, **kwargs):
    return [-len(ids) for ids in completion_ids]

trainer = AsyncGRPOTrainer(
    model="Qwen/Qwen3-0.6B",
    args=config,
    train_dataset=dataset,
    reward_funcs=negative_length_reward,
)
trainer.train()

任何有效的 RL 演算法都應該在幾步之內收斂。令人驚訝的是,使用預設的 FP32 精度執行此腳本卻無法收斂。

這種現象並非個案。近期的研究已指出數值精度是 RL 微調中不穩定性的來源。Qi 等人 (2025) 證明,由 BF16 捨入(rounding)引起的訓練-推論不匹配,破壞了生成 rollout 的策略與計算梯度的策略之間的一致性,並表明改回 FP16 可以消除該問題。Megatron-Core MoE 報告同樣指出:「在強化學習訓練期間,半精度浮點數(FP16)在某些超參數選擇下可以提供更高的數值穩定性」,並提供了一條專用的 FP16 訓練路徑。然而,這些研究都沒有提供關於為什麼這種不匹配會導致訓練失敗的機制解釋。例如,Qi 等人將問題追溯到兩個相互交織的現象:注意力機制中出現了相似的低秩表示,以及低精度算術中固有的偏差捨入誤差的累積效應。該論文正確地識別了這些現象,但未能提供完整的因果鏈。我們在此的目標是找出「為什麼」:逐步剖析 BF16 精度不匹配如何破壞 GRPO 梯度並阻止收斂的確切機制。那麼根本原因是什麼?僅僅是模型權重之間的精度不匹配,還是優化器中更深層的問題?正如我們將在這篇(長篇!)部落格文章中展示的那樣,答案是 PPO 的裁剪機制與 BF16 捨入引入的數值雜訊之間微妙的互動:精度差距觸發了我們稱之為「幻影裁剪」的現象,優化器會將那些策略實際上並未改變的 token 的梯度訊號消除。

使我們的設置特別適合研究此問題的原因在於其簡單性。立即輸出 EOS 的任務具有已知的最佳策略、無歧義的稠密標量獎勵,且收斂(或缺乏收斂)在 100 步內即可觀察到。結合 TRL 中簡潔、最小化的 AsyncGRPO 實作,這為我們提供了一個完全可重現、易於探測的環境,讓我們可以隔離並精確測量 BF16 精度損失進入訓練管道的位置,以及它如何阻止收斂。

我們研究的架構 AsyncGRPO 解耦了生成與訓練:vLLM 推論伺服器以非同步方式生成 BF16 格式的補全內容,而訓練程序則計算梯度並更新權重。當訓練的前向傳遞使用與 vLLM 不同的數值精度時,精度不匹配就會進入訓練管道。我們將在後續章節中詳細說明這具體是如何發生的。

在展示結果之前,讓我們定義控制訓練管道數值行為的兩個精度旋鈕:

  • DTYPE(模型權重載入 dtype):儲存的模型參數精度。當設置為 float32 時,優化器維持全精度權重。當設置為 bfloat16 時,權重本身以 BF16 儲存,這意味著優化器更新也會在 BF16 中累積。

  • Autocast(torch.amp BF16=True/False):控制前向傳遞矩陣乘法是否使用硬體加速的 BF16 GEMM。當 BF16=True 時,所有矩陣乘法在執行前都會將運算元轉換為 BF16,與 vLLM 的推論精度相匹配。

我們進行了實驗,改變了基礎權重 dtype、autocast 精度和學習率。

注意:我們所說的「收斂」是指模型獲得足夠的梯度訊號以穩定提高其獎勵,而不是指它在 100 步內達到 -1 的最佳值。

以下是實驗總結表:

作為健全性檢查,我們使用標準的同步 GRPOTrainer(TRL 中經過實戰驗證的實作)而不是我們的非同步變體重複了相同的實驗。下方的結果證實了這些發現:相同的收斂行為出現在相同的精度配置中,證實了失敗並非非同步架構的產物,而是 FP32/BF16 精度不匹配如何與 GRPO 損失互動的基本屬性。

模式很明顯:當訓練前向傳遞與推論引擎使用不同的有效精度,且學習率太小而無法克服由此產生的不匹配時,收斂就會失敗。本報告的其餘部分將詳細剖析這種失敗機制。

這是為你早晨喝咖啡時準備的輕鬆精簡版 ☕。若要查看包含審閱章節、附錄、證明和互動式動畫圖表的深入版本,請查看包含所有精華內容的完整版本。



精度誤差來源

在訓練期間,前向和後向傳遞中的大多數算術運算都是以 BF16 進行的,儘管一些對數值敏感的運算(例如歸一化或歸約)通常以更高的精度計算。由於 BF16 將每個值捨入到僅 8 個有效位元,因此數值誤差可能會在 GRPO 管道的每個階段悄悄出現:在前向傳遞計算 Logits 和對數機率時、在後向傳遞傳播梯度時,以及在每次與推論引擎進行權重同步時將 FP32 權重截斷為 BF16 時。我們在下方依次描述這些誤差來源。

前向傳遞 Logit 誤差

BF16 矩陣乘法將每個運算元捨入到 8 個有效位元,注入每個值 2^-8 的相對誤差。這些逐層誤差透過殘差流累積,根據 L 層和隱藏維度 d 的中心極限定理(CLT)論證,總隱藏狀態誤差縮放為 √L · √d · σ · 2^-8。對於 Qwen3-0.6B(L=28, d=1536),這預測 Logit 誤差約為 0.8σ,與第 6 節中測得的 |β| ≈ 0.076 相符。

對數機率誤差

Token a_t 的對數機率是整個 Logit 向量 z = [z_1, z_2, ..., z_|V|] 的函數:

令 z 為 FP32 Logits,δz = [δz_1, ..., δz_|V|] 為來自 BF16 的每個 Token 的 Logit 誤差,因此 z_bf16 = z + δz。一階泰勒展開:

微分並代入(完整推導請見完整版本):

對數機率誤差是所選 Token 的 Logit 誤差減去整個詞彙表中機率加權的平均 Logit 誤差。雖然 Log-softmax 是平移不變的(添加到所有 Logits 的常數偏移 C 會抵消),但實際上 BF16 捨入誤差從來都不是均勻的。BF16 網格是一個階梯函數,其步長(ULP)取決於每個值的指數。不同大小的 Logits 位於不同的指數區間中,並以不同的步長進行捨入。

權重同步截斷

在每次權重同步時,訓練程序會將 FP32 權重發送給 vLLM:

每個權重的誤差:|W_train - W_vllm| ≤ ½ ULP(W_train)。

Adam 的更新規則是 ΔW = -η · m̂_t / (√v̂_t + ε)。梯度大小會抵消,使得 |ΔW| ≈ η,無論損失函數的形狀如何。當 |W| = 1.0, lr = 10^-6,且 ULP(1.0) = 2^-7 = 0.0078 時,BF16 表示僅在累積更新跨越中點時才會改變:

在 100 步的訓練執行中,此權重在 BF16 中永遠不會改變,這正是 BF16 量化的邊界跨越問題(完整版本中有涵蓋)。

α/β 分解

上一節列出了 BF16 捨入誤差進入 GRPO 管道的三個位置。無論它們源自何處,最終都體現在同一個地方:模型分配給每個 Token 的對數機率。由於 GRPO 損失取決於當前策略與 Rollout 策略之間的對數機率「比率」,因此每個 BF16 誤差來源最終都會匯入一個單一的差異:log π_θ(a_t) - log π_old(a_t)。

這促使我們將該對數比率分解為一個即使在精確的 BF16 算術下也會存在的組分,以及一個純粹由訓練與推論之間的精度不匹配引起的殘差。

由於 W_j(vLLM 在 Rollout 時使用的權重)在訓練進展到 W_k 後已不存在,我們透過插入樞紐 f^(bf16)(a_t; W_k)(即在當前權重下的局部 BF16 前向傳遞)來進行分解:

其中:

  • P 是訓練精度(FP32 或 BF16 autocast)。

  • W_k 是當前的訓練權重(FP32)。

  • W_j 是 vLLM 用於生成 Rollout 的權重(j < k,其中 k - j 是可容忍的滯後時間)。

這種分解是可測量的:α_t 和 β_t 都可以透過在當前批次上執行局部的 BF16 影子前向傳遞(shadow forward pass)在每個訓練步驟中計算出來。我們在第 4.1 節中詳細說明了此影子前向傳遞的實作。

對數比率分解為 α(合法的 BF16 對齊策略變更)和 β(精度差距)。當 BF16=True 時,β 消失,比率是乾淨的。當 BF16=False 時,β 佔比率的很大一部分。

項 α_t:BF16 對齊比率

α_t 捕捉了自 Rollout 以來在 BF16 空間中發生的所有變化:BF16 可見的策略變更、vLLM 計算路徑不匹配等。

關鍵洞察:Async GRPO 中合法的權重採樣(importance-sampling)校正透過 α_t 運作。

項 β_t:精度差距

β_t 是純粹的局部精度差距:訓練前向傳遞(精度 P)與 BF16 前向傳遞在相同權重 W_k 上計算對數機率的差異。如果 P = BF16(autocast 或 bf16=True):

請注意,β_t 並非完全為零,因為 vLLM 的計算路徑與訓練端的 Transformer 實作略有不同(不同的注意力核心、不同的融合模式),但在實務上殘差可以忽略不計。

如果 P = FP32(無 autocast 或 bf16=False):

β_t 是與 Token 相關的:不同的 Token 會啟用 LM 頭中不同的權重行,從而產生不同的捨入模式。

該理論預測匹配精度時 β = 0,而對於不匹配精度時 |β| ~ O(0.01, 0.1)。但 β 真的是會被平均掉的隨機雜訊,還是具有系統性破壞學習結構的雜訊?我們現在有了測量儀器,一種在每個訓練步驟中分離訊號與雜訊的方法。是時候檢查證據了。


在即時訓練中測量 α 和 β

在立即輸出 EOS 任務上進行兩次執行(Qwen3-0.6B,100 步,lr=1e-6,BF16 vLLM):

  • 執行 A(收斂):DTYPE=float32,BF16=True

  • 執行 B(失敗):DTYPE=float32,BF16=False

在每個訓練步驟中,同一批次上的 BF16 影子前向傳遞會分解對數比率:

# 在同一批次上計算 BF16 影子對數機率(模擬 vLLM 評估)
lp_lowp = self._compute_low_precision_log_probs(model, input_ids, attention_mask, completion_mask)
lp_lowp = lp_lowp[:, : log_probs.shape[1]]

# log_ratio = alpha + beta,其中:
#   alpha = lp_lowp - old_log_probs  (訊號:自 Rollout 以來的 BF16 策略變更)
#   beta  = log_probs - lp_lowp      (雜訊:訓練與 BF16 函數不匹配)
alpha = (lp_lowp - old_log_probs)[valid_mask].float()
beta = (log_probs - lp_lowp)[valid_mask].float()

# 記錄每步統計數據
beta_mean = beta.abs().mean().clamp(min=1e-12)
snr = alpha.abs().mean() / beta_mean

_compute_low_precision_log_probs 輔助函數:

@torch.no_grad()
def _compute_low_precision_log_probs(self, model, input_ids, attention_mask, completion_mask):
    """執行 BF16-autocast 前向傳遞以模擬 vLLM 的評估方式。"""
    original_forward = getattr(model, "_original_forward", None)
    fwd_fn = original_forward if original_forward is not None else model.forward
    with torch.amp.autocast("cuda", dtype=self._low_precision_dtype):
        outputs = fwd_fn(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)
    logits = outputs.logits[:, :-1, :].float()
    logits.div_(self.temperature)
    return selective_log_softmax(logits, input_ids[:, 1:])

精度差距 β

每步的平均 |β|。BF16=True 精確產生 β=0;BF16=False 顯示持續的精度差距約為 0.076。

對於 BF16=True,β 精確為 0。Autocast 訓練前向傳遞與 BF16 影子前向傳遞產生相同的對數機率,證實了第 3.2 節的理論預測。

對於 BF16=False,β 是顯著且結構化的:

  • 平均幅度 0.076,個別 Token 最大可達 3.05。

  • 有符號平均值 -0.0105:FP32 前向傳遞系統性地產生比 BF16 更低的對數機率,這是一種一致的負偏差,而非以零為中心的雜訊。

  • 離散度 std = 0.149:每個 Token 的 β_t 差異很大。有些 Token 得到 β ≈ +0.15(比率膨脹約 16%),有些得到 β ≈ -0.15(比率緊縮約 14%)。

  • 與優勢(advantage)的相關性 +0.0094:精度不匹配系統性地過度加權了具有良好優勢的 Token,並低估了具有不良優勢的 Token。

BF16 對齊比率 α

我們現在轉向 α,這是對數比率中反映 BF16 空間中實際策略變更的組分。如果優化器正在進行有效的更新,|α| 應該隨著策略偏離 Rollout 策略而隨訓練增加。α 增加(或未能增加)的程度告訴我們訓練訊號在多大程度上被「部署」到了 vLLM 所服務的 BF16 模型中。

兩次執行開始時都有相似的 |α|(約 0.035)。執行 A 的 α 隨著時間推移增長得更大(高達 0.92),表明 BF16 策略正在積極偏離舊的 Rollout,模型正在學習。執行 B 的 α 增長較慢(高達 0.33),表明訓練訊號未能有效地傳遞到已部署的 BF16 權重中。

訊號雜訊比(SNR)

α 和 β 的個別幅度是有資訊量的,但決定訓練能否成功的量是它們的「比率」。如果 |β| > |α|,精度雜訊就會主導真實的策略變更訊號,優化器本質上是在雜訊中導航。相反,如果 |α| ≫ |β|,精度差距只是一個微小的擾動,訓練可以容忍它。

對於 BF16=False,SNR 開始時低於 1.0(雜訊主導),訓練期間平均約為 3。

對於 BF16=False,|α|/|β| ≈ 3。精度差距約為總對數比率幅度的 1/3。在訓練初期(第 1 到 3 步),SNR 低於 1.0,這意味著精度差距佔據主導地位。

每步的部署改進

到目前為止的指標描述了優化器所看到的內容。但重要的問題是:每個優化器步驟是否真的有助於已部署的 BF16 策略?

為了直接測量這一點,我們使用了第 4.1 節中介紹的相同 BF16 影子前向傳遞(_compute_low_precision_log_probs 輔助函數)。在每個優化器步驟之前,我們記錄每個 Token 的 BF16 對數機率。步驟完成後,我們測量 BF16 對數機率如何變化,以及該變化是否與優勢方向一致:

# on_step_end 回呼。模型權重已更新
lp_after = t._compute_low_precision_log_probs(t.model, input_ids, attention_mask, completion_mask)
delta = (lp_after - lp_before).float()
adv_sign = torch.sign(advantages)
n_valid = valid.sum().clamp(min=1)

# deployed_improvement:BF16 對數機率是否朝優勢方向移動?
aligned = (delta * adv_sign * valid.float()).sum() / n_valid
t._metrics["train"]["qat/deployed_improvement"].append(aligned.item())

BF16=True 每個步驟實現的有效改進是 BF16=False 的 5.5 倍。

每個優化器步驟在 BF16=True 下比 BF16=False 更有效地改進了 BF16(已部署)策略 5.5 倍。兩種設置每步移動 BF16 函數的絕對量相似(約 0.016),但 BF16=True 的移動與優勢方向的對齊程度要好得多。BF16=False 的 deployed_improvement 幾乎為正(+0.00023),本質上是零附近的雜訊。

β 如何破壞梯度

封閉形式的梯度失真

定義分數函數 s_t = ∇_W log π_θ(a_t),即在權重空間中使 Token a_t 更有可能出現的方向。

簡化假設:完整梯度包含裁剪指標 C_t。在本節中,我們分析梯度時假設所有 Token 都有貢獻(對所有 t,C_t = 1)。這隔離了 β 的乘法效應和分數函數效應。我們將在第 7 節重新審視此假設。在此簡化下,乾淨梯度(BF16=True,β = 0):

實際梯度(BF16=False,β ≠ 0):

代入 e^(α_t + β_t) = e^(α_t) · e^(β_t) 和 s_t^(fp32) = s_t^(bf16) + δs_t(完整推導請見完整版本)

其中:

有效優勢失真

比率失真可以被吸收進優勢中:

當 corr(β_t, A_t) > 0 時(測得:+0.0094):

梯度失去了良好補全與不良補全之間的對比度。嚴重程度取決於 β_t 的符號和幅度:由於 e^(β_t) 是凸函數,正的 β 值會乘法放大優勢(例如,e^(0.15) = 1.16,增加 16%),而負的 β 值會減弱它(例如,e^(-0.15) = 0.86,減少 14%)。因為 corr(β_t, A_t) > 0,具有良好優勢的 Token 傾向於得到正的 β(過度強化),而具有不良優勢的 Token 傾向於得到負的 β(過度抑制)。淨效應是有效優勢分佈的系統性壓縮。優化器看到的最佳補全與最差補全之間的差異比實際存在的要小。

當我們在這種簡化分解下實證測量梯度(去掉裁剪指標 C_t)時,比率誤差和分數誤差向相反方向推動並在很大程度上抵消了:整體梯度方向保持在與乾淨梯度 cos > 0.95 的範圍內。透過這種衡量標準,訓練幾乎不會受到影響。事實並非如此。下一節重新測量了實際訓練梯度(恢復了裁剪指標),並得到了非常不同的答案。

深入探討 β 與干預措施

我們目前的進度:上一節表明,在簡化模型(無裁剪)下,整體梯度方向仍然出奇地接近乾淨梯度(cos > 0.95)。然而,幾個問題仍然懸而未決:β 如何隨著訓練進展而演變?所有 Token 受到的影響是否相同?最關鍵的是,當我們包含實際的 PPO 裁剪機制時,cos > 0.95 的發現是否成立?

罕見 Token 具有數量級更大的 |β|

最揭示性的發現:具有非常負的 log_probs(罕見、低機率)的 Token 具有比常見 Token 大得多的 |β|。在第 50 步:

  • 常見 Token (log_prob > -5):|β| ≈ 0.02

  • 中等 Token (log_prob ~ -10):|β| ≈ 0.1

  • 罕見 Token (log_prob < -20):|β| ≈ 0.5 到 1.0

這是 50 倍的不匹配幅度差異。對數機率誤差為 β_v = δz_v - δ(logsumexp),其中 logsumexp 由高機率 Token 主導,因此 δ(logsumexp) 相對穩定。對於常見 Token,誤差會抵消。對於罕見 Token,δz_v 可能與 δ(logsumexp) 非常不同,留下巨大的殘差。

幾何分解:訊號 vs 雜訊

第 5 節中的簡化分析使用自定義後向傳遞去掉了裁剪指標,發現 cos > 0.95。在這裡,我們測量優化器使用的「實際訓練梯度」,包括 PPO 的裁剪機制。

# 從正常訓練步驟中儲存損壞的梯度
corrupted_grads = {name: param.grad.float().clone()
                   for name, param in model.named_parameters() if param.grad is not None}

# 重新計算乾淨的 REINFORCE 損失(無權重採樣比率)
model.zero_grad()
clean_loss = -(advantages * log_probs * completion_mask).sum() / global_n
clean_loss.backward()

# 比較:所有參數的餘弦相似度和相對 L2 誤差
for name, param in model.named_parameters():
    g_corrupt = corrupted_grads[name]
    g_clean = param.grad.float()
    overall_cos_num += (g_corrupt * g_clean).sum()
    overall_cos_den_a += (g_corrupt * g_corrupt).sum()
    overall_cos_den_b += (g_clean * g_clean).sum()

*到第 10 步時,雜訊組分超過了訊號組分(81% vs 58%)。這與第 5 節的簡化分析相矛盾,後者發現 cos > 0.95。從 cos > 0.95 到 cos ≈ 0.55 的劇烈下降告訴我們,裁剪機制的某些方面以簡化分析未捕捉到的方式與 β 互動。

我們已經針對 β 建立了一個令人信服的間接證據案例。但相關性並不等於因果關係。為了給 β 定罪,我們需要一個對照實驗:一個將比率從梯度中隔離出來並分別測試每個組分的實驗。比率中的 β 真的是原因嗎?還是 FP32 梯度方向本身就阻止了學習?

為了確定比率中的 β 污染是否是失敗的主要原因,或者 FP32 梯度方向是否獨立地阻止了學習,我們對失敗的配置進行了兩項干預:

  • 執行 A(基準):BF16=True,無干預。參考收斂執行。

  • 執行 B(失敗):BF16=False,無干預。參考失敗執行。

  • 執行 F (ratio_one):BF16=False,但權重採樣比率被強制為 1。這將 GRPO 簡化為純 REINFORCE,完全從比率中移除了 α 和 β。

  • 執行 G* (ratio_bf16):BF16=False,但比率是根據 BF16 影子對數機率計算的,而不是 FP32 訓練前向傳遞。這僅從比率中移除了 β,同時保留了合法的滯後校正 α。

關鍵在於,執行 F 和 G 保留了用於梯度計算的 FP32 後向傳遞;只有比率被更改。這將比率效應與梯度方向效應隔離開來。

執行 A/B/F/G 的每步部署改進。執行 F 和 G 顯示持續的正向改進;執行 B 在零附近震盪:

兩項干預都收斂了!從比率中移除 β 恢復了訓練,儘管梯度方向仍然是 FP32。執行 F 和 G 實現的部署改進比執行 B 高出 16 到 19 倍,比執行 A 高出 2.9 到 3.5 倍

當 FP32 梯度方向從比率污染中解放出來時,它實際上比 BF16 梯度更能有效地改進 BF16 策略。這徹底排除了 FP32 後向傳遞獨立阻止學習的假設。

KL 散度

一個重要的問題是這些干預措施是否以犧牲訓練穩定性為代價。PPO 的裁剪機制存在是為了強制執行信任區域(trust region),以防止策略偏離 Rollout 策略太遠。執行 F(ratio=1)完全繞過了這一點,簡化為沒有信任區域約束的純 REINFORCE。執行 G(ratio_BF16)透過 α 保留了信任區域,但具有乾淨的比率。追蹤當前策略與 Rollout 策略之間的 KL 散度告訴我們每個執行偏離行為策略的激進程度。

執行 F 學習激進,KL 達到 8.5(無 PPO 裁剪約束)。執行 G 具有中等 KL,與執行 A 相似。BF16 影子比率提供了正確的權重採樣和裁剪。

我們已經確認了「是什麼」(從比率中移除 β 修復了訓練),但還沒有確認「如何」。第 5 節的簡化梯度分析預測乾淨梯度與損壞梯度之間的 cos > 0.95,但包含 PPO 裁剪的實際梯度顯示 cos 僅為 0.55。裁剪機制的某些方面以我們尚未解釋的方式與 β 互動。下一節重點在於隔離確切的機制。

機制:幻影裁剪

我們目前的進度與未解釋的部分

第 6 節(干預措施)確立了比率中的 β 是必要原因,但沒有說明它是「如何」破壞訓練的。工作假設是乘法優勢失真:e^(β_t) 對梯度進行加權,梯度失去了對比度。然而,當我們觀察包含 PPO 裁剪的實際 β 梯度影響時(第 6.2 節),餘弦相似度從 0.95 下降到 0.55,這與第 5 節的簡化分析存在巨大差異。這指向了與裁剪機制的互動,而簡化分析完全忽略了這一點。

損失結構實驗

為了隔離裁剪互動,我們在保持 β 不變的情況下測試了四種損失變體:

標準 PPO(基準,失敗):β 流經比率幅度和 min/clamp 裁剪決策。

clipped = torch.clamp(ratio, 1 - eps, 1 + eps)
per_token_loss = -torch.min(ratio * advantages, clipped * advantages)

Detach + center 和 Detach only:梯度權重從計算圖中分離,消除了來自 min/clamp 的零梯度死區。比較兩者測試了中心化(修正 μ_W 偏差)或分離(移除死區)哪一個才是重點。

W = torch.min(ratio * advantages, clipped * advantages)
mu_W = (W * completion_mask).sum() / n_valid
W_centered = W - mu_W
per_token_loss = -W_centered.detach() * log_probs

No-clip(ε = 10):標準 PPO,ε 大到沒有 Token 會觸及裁剪邊界。β 像在失敗的基準中一樣,即時流經比率和梯度。唯一的區別是 clamp 從未飽和。

所有三項干預都收斂了!ε = 10 的結果最有資訊量:標準 PPO,β 在比率和梯度中完全保留,但因為禁用了裁剪,它收斂了。如果它收斂,那麼裁剪互動就是機制所在。

被推翻的假設:權重分佈偏差

損失結構實驗表明裁剪參與其中,但前幾節也確立了 β 系統性地偏差了有效優勢(corr(β, A) > 0)。如果乘法失真假設是正確的,這種偏差應該體現在每個 Token 的梯度權重中:對於 BF16=False,μ_W 應該更正(因為 e^(β_t) 膨脹了良好優勢 Token 並緊縮了不良優勢 Token)。我們對訓練器進行了儀器化以記錄 W_t = min(r_t(θ) A, clip(r_t(θ)) A) 並直接測試此預測。

乘法失真理論被推翻:μ_W ≈ -0.24 在所有執行中完全相同。權重分佈完全不受 β 影響。沒有不良 Token 被強化;沒有良好 Token 被抑制。

正確機制:幻影裁剪

理解失敗的關鍵在於 PPO 的裁剪機制。當 torch.min 選擇裁剪分支時,torch.clamp 產生零梯度,因為其輸出是常數。裁剪決策取決於比率是否超過了信任區域:

當比率反映真實的策略變更時,邏輯是合理的:「如果策略已經為此 Token 移動了很多,就停止推動。」但當 r_t(θ) = e^(α_t + β_t) 時,裁剪決策使用損壞的對數比率。在訓練初期 α_t ≈ 0(策略幾乎沒有改變),因此裁剪決策簡化為一個簡單的問題:|β_t| > 0.2 嗎?

考慮一個具體例子。一個 Token 的 α ≈ 0 但 β = 0.25。損壞的比率 r = e^(0.25) = 1.28 > 1.2。PPO 斷定此 Token 已經改進了 28%,並關閉了其梯度。實際上,該 Token 根本沒有移動。這 28% 的「改進」純粹是精度雜訊。

幻影裁剪:PPO 從 β 中看到了幻影策略移動,並為那些仍然需要學習的 Token 歸零了梯度。

這就是我們一直在尋找的機制。不是梯度方向損壞(參見第 5 節關於無裁剪的分析),不是乘法優勢失真(第 7.2 節,μ_W 在執行間相同),而是對優化器仍然需要從中學習的 Token 進行二元、全有或全無的消除。裁剪指標 C_t(簡化分析忽略了它)正是 β 造成真正損害的地方。

我們可以量化有多少 Token 受到影響。從 BF16=False 的執行來看,穩定狀態下 β ~ N(-0.01, 0.15),得出 P(|β| > 0.2) ≈ 18%。實證裁剪比率證實了這一預測:

在第 3 步,策略幾乎沒有移動(α ≈ 0),但 BF16=False 已經裁剪了 13.5% 的 Token,而 BF16=True 僅為 1.0%。額外的 12.5% 是幻影裁剪:其梯度純粹被精度雜訊而非真實策略變更所消除的 Token。

為了使這一點直接可見,我們透過比較實際比率 r_t(θ) = exp(α_t + β_t) 與乾淨比率 r_t^α(θ) = exp(α_t) 來對每個 Token 進行分類。如果一個 Token 在實際比率下落在裁剪邊界之外,但在乾淨比率下落在邊界之內,則它是幻影裁剪的。

*視覺化顯示每個 Token 根據其權重採樣比率定位。使用實際比率(包含 β)時,Token 散佈在裁剪邊界之外:在第 5 步,17.2% 的 Token 是幻影裁剪的,而只有 0.4% 是合法裁剪的。從比率中移除 β 後,僅使用 α 重新計算,幾乎所有 Token 都緊密聚集在 r = 1.0 周圍,位於信任區域內。到第 30 步,隨著策略開始移動,合法裁剪出現,但幻影裁剪(23.7%)仍然主導合法裁剪(9.5%)。

回到第 5.1 節的梯度分解,我們現在可以恢復之前被刪除的裁剪指標 C_t。這引入了第三個誤差項,捕捉幻影裁剪效應:

其中 Δg_clip = -(1/N) Σ_t A_t · r_t(θ) · s_t^(P) · (C_t^(β) - C_t^(0)) 捕捉了當 β 翻轉裁剪決策時獲得或損失的梯度訊號。在訓練初期(α ≈ 0),大約 13% 的 Token 完全失去了梯度。

三條證據證實 Δg_clip 是主要的失敗機制:

  • 移除裁剪(ε = 10)修復了收斂,同時保持 Δg_ratio 和 Δg_score 完全完整。來自 β 的乘法失真仍然存在於梯度中,但模型收斂了。

  • 權重分佈 μ_W 不受 β 影響,完全排除了乘法優勢失真通道(第 7.2 節)。

  • 分離(Detach)透過後向路徑中移除 min/clamp 的不同途徑消除了 Δg_clip,也恢復了收斂。

部署改進

我們現在觀察損失結構執行中的部署改進:

No-clip 執行恢復到 7.8% 的效率,與 BF16=True 的 7.9% 相匹配,儘管 β_abs_mean 高達 1.2。

關鍵結論:來自 β 的乘法失真是可以容忍的。幻影裁剪則不然。

比較修復方案

所有成功的修復方案都有一個共同點:它們防止了 β 產生零梯度死區。

正確的機制是幻影裁剪:β_t 將權重採樣比率推過了 PPO 的裁剪邊界,針對那些策略實際上並未改變的 Token,觸發了 torch.clamp 飽和,並為這些 Token 產生了精確的零梯度。

結論

TL;DR

  • 根本原因:訓練前向傳遞與 vLLM 推論伺服器之間的 BF16 精度不匹配產生了精度差距 β,該差距進入了權重採樣比率。

  • 失敗機制:β 將比率推過了 PPO 的裁剪邊界,針對那些策略實際上並未改變的 Token,消除了約 18% 的梯度訊號。

  • 修復:匹配精度(到處使用 FP16,或 BF16 autocast),或從策略比率中移除 β。

根本原因

非同步 GRPO 訓練在訓練前向傳遞(FP32)與 vLLM 推論伺服器(BF16)使用不同數值精度時會失敗。精度差距 β_t = f^(fp32)(a_t; W) - f^(bf16)(a_t; W) 進入權重採樣比率 r_t(θ) = exp(α_t + β_t) 並觸發幻影 PPO 裁剪:優化器將那些策略實際上並未改變的 Token 的梯度訊號歸零。在 Qwen3-0.6B 的受控立即輸出 EOS 任務上,此機制在學習率 10^-6 下完全阻止了收斂,而匹配精度的訓練在 100 步內收斂。

精度差距不僅僅是數值雜訊。它源於 28 個 Transformer 層累積的捨入差異,產生了平均 |β| 為 0.076,尾部達到 3.05。該差距與 Token 相關(罕見 Token 具有 50 倍大的 |β|),與優勢訊號系統性相關(Cov(A, β) > 0),且大到足以將大約 18% 的 Token 推過 PPO 的裁剪邊界(ε = 0.2)。在訓練初期,當策略幾乎沒有移動(α ≈ 0)時,這些被幻影裁剪的 Token 儘管確實包含有用的學習訊號,卻收到了精確的零梯度。由此導致的每步部署改進減少 7 倍,結合 RL 反饋迴路,將系統鎖定在永久停滯狀態。

我們排除了什麼

一個最初合理的假設認為 β 透過乘法優勢失真(A_t^eff = A_t · e^(β_t))破壞訓練,這會壓縮有效優勢分佈並破壞梯度對比度。我們仔細測量了這一點並最終推翻了它:無論 β 如何,每個 Token 的梯度權重分佈在所有執行中都是相同的。決定性的實驗是設置 ε = 10(禁用裁剪),同時讓 β 在比率和梯度中保持完整。此執行收斂到 7.8% 的部署改進效率,與 BF16=True 的 7.9% 相匹配。乘法失真是可以容忍的;幻影裁剪則不然。

為什麼是 RL?

這種失敗模式是 RL 特有的。在預訓練和微調中,β 以加法方式進入交叉熵損失,產生平均值約為零且保持方向的梯度雜訊(完整版本中有詳細分析)。在 RL 中,權重採樣比率中的 exp() 將此加法誤差轉換為與 PPO 裁剪機制產生破壞性互動的乘法擾動。

此失敗的三個條件。這三個條件必須同時發生:

  1. 跨系統比率:權重採樣比率耦合了不同精度下的計算(訓練 vs 推論)。

  2. 裁剪代理損失:PPO 的裁剪產生了 β 可以觸發的零梯度死區。

  3. 閉環資料:訓練資料取決於已部署的模型,因此退化的更新會隨時間累積。

建議

按從最強到最方便排序:

  1. FP16 訓練與 FP16 推論。這是當你的硬體和框架支援時的最佳選擇。FP16 具有 10 個尾數位元(相對於 BF16 的 7 個),提供顯著更好的數值穩定性,同時仍受益於硬體加速的矩陣乘法。由於訓練和推論都在 FP16 中,精度不匹配在結構上為零。我們第 1 節的收斂表證實了這一點:匹配 vLLM 的 FP16 在 lr = 10^-6 下乾淨地收斂。

  2. BF16=True 與 FP32 主權重(master weights)。這是大多數 LLM 訓練框架使用的標準混合精度配方,也是我們的預設建議。Autocast 將訓練前向傳遞與 vLLM 的 BF16 相匹配,產生 β ≈ 0。FP32 主權重確保優化器以全精度累積更新。這是最安全且支援最廣泛的選項。

  3. ratio_BF16(影子前向傳遞)。當 FP16 和 BF16 autocast 均不可用時,從 BF16 影子前向傳遞計算權重採樣比率,而不是 FP32 訓練前向傳遞。這從比率中移除了 β,同時保留了 FP32 梯度,當從比率污染中解放出來時,它實際上比 BF16 梯度稍微有效一些(正如我們的干預實驗所顯示的那樣)。代價是每個訓練步驟多進行一次前向傳遞。

  4. 禁用裁剪(ε = 10)。設置足夠大的 ε,使沒有 Token 達到裁剪邊界,可以在零成本下消除幻影裁剪。β 仍然存在於比率和梯度中,但僅乘法失真是可以容忍的。在我們的簡單任務上這效果很好;在具有獎勵駭客(reward hacking)或分佈偏移的更難任務上,缺乏信任區域可能會引入不穩定性。

  5. 分離梯度權重。從後向路徑中移除 min/clamp 可以透過不同的途徑消除零梯度死區。這有效,但會產生高 KL 散度(高達 15.9),在實務上是最不穩定的選項。