← 返回首頁

MoE 模型中的訓練與推論一致性:數值漂移發生之處

Fireworks AI
Fireworks AI
@FireworksAI_HQ
178🔁 17
𝕏 (Twitter)🔥

MoE 模型中的訓練與推論一致性:數值漂移發生之處

當「更快」不等於「相同」:部署 MoE 模型時的數值陷阱

在數學上等價的 Kernel 融合(Kernel fusions),在數值上仍可能產生漂移。以下是我們在 Kimi K2.5 推論服務與 Qwen3.5-MoE 訓練啟動過程中遇到的一致性錯誤(Parity bugs)。

為什麼這很重要

當你訓練一個模型並將其部署用於推論時,你會期望兩者表現一致:相同的權重、相同的輸入、相同的輸出分佈。這種訓練與推論之間的數值一致性(Numerical parity)比聽起來更重要:

  • RLHF / GRPO 獎勵完整性:參考模型(Reference model)的 logprobs 是 KL 懲罰的錨點。如果推論產生的 logprobs 與訓練時針對相同權重產生的結果不同,策略(Policy)可能會在不實際改進的情況下利用這個差距。

  • 可重現性:由 Kernel 融合引起的數值漂移是隱形的——權重相同,架構看起來也一樣,但輸出卻出現分歧。

  • 客戶信任:使用者在我們的平台上進行微調,並期望部署後的結果與訓練優化的目標相符。

對於稠密模型(Dense models),達成一致性相對容易。但像 Kimi K2.5、Qwen3.5-MoE 和 DeepSeek V3 這類 Mixture-of-Experts (MoE) 模型則更困難。由於存在路由專家(Routed experts)、共享專家路徑,以及在深層堆疊中每層兩次的 all-reduce 通訊,許多「數學上等價」的優化手段都會產生數值上的差異。

這篇文章整理了我們發現的陷阱。每一項都是推論引擎為了效能而採用的優化類別,但它們可能會悄悄破壞數值對齊。我們在將 Kimi K2.5 導入我們的推論堆疊時發現了大部分問題,隨後在除錯 Qwen3.5-MoE 時又看到了相同的失效模式。我們將以 FlashInfer 和 TRT-LLM 風格的融合 Kernel 作為具體範例。

根本問題:浮點數加法不具結合律

以下提到的每個陷阱都歸結為一個事實:浮點數加法不具結合律。即使在 FP32 中也是如此:

(a + b) + c ≠ a + (b + c)

每次加法都會將結果四捨五入到最接近的可表示值。不同的運算順序會產生不同的中間值,進而導致不同的捨入誤差。這些誤差在單次運算中微乎其微,但會在 61 層 Transformer 層中累積——而 MoE 路由會放大這些誤差(隱藏狀態的微小變化可能會改變專家選擇的結果,並在網路的其餘部分產生連鎖反應)。

陷阱 1:All-Reduce 拓撲差異

它是什麼

在張量並行(Tensor-parallel)推論中,每個線性層的輸出必須透過 all-reduce 在各個 GPU 之間進行加總。這在每層中會發生兩次:在注意力輸出投影之後,以及在 MLP/MoE 之後。

訓練通常使用 NCCL,它將 all-reduce 實作為 reduce-scatter 後接 all-gather。在環狀拓撲(Ring topology)的 reduce-scatter 階段,資料被分成多個區塊(每個 GPU 一個)。隨著部分和(Partial sums)在環中流動,每個區塊會從「擁有」該區塊的 GPU 開始進行累加。對於 8 個 GPU,這意味著隱藏向量的不同部分會看到不同的加總順序:

NCCL ring reduce-scatter (8 GPUs):
  chunk 0 (owned by GPU0): r0 + r7 + r6 + r5 + r4 + r3 + r2 + r1
  chunk 1 (owned by GPU1): r1 + r0 + r7 + r6 + r5 + r4 + r3 + r2
  chunk 2 (owned by GPU2): r2 + r1 + r0 + r7 + r6 + r5 + r4 + r3
  ...each chunk starts from its owner, accumulates around the ring

推論服務引擎通常會為了降低延遲,將 NCCL 替換為自定義的 all-reduce Kernel。FlashInfer 的 Lamport IPC Kernel(源自 TRT-LLM)採用了不同的方法:每個 GPU 透過 CUDA IPC 將資料寫入所有其他 GPU 的緩衝區,然後每個 GPU 在本地讀取所有貢獻值並以固定順序加總:

Lamport kernel (all elements, on every GPU):
  every chunk:              r0 + r1 + r2 + r3 + r4 + r5 + r6 + r7

兩者都在 FP32 中進行累加。兩者在精確算術下都會產生正確的總和。但 NCCL 中的區塊旋轉意味著隱藏向量的不同元素會看到不同的加法順序,而 Lamport Kernel 對所有內容都使用統一的 r0..r7 順序。由於浮點數加法不具結合律,這些操作會產生不同的結果。

為什麼容易被忽略

輸出看起來是正確的。模型生成了連貫的文字。分歧在 KL 散度上小於 0.001。只有當你逐個 token 對比 logprobs 與參考值時才會發現——而這正是 RLHF 獎勵計算所做的事情。


陷阱 2:通訊與計算的融合

它是什麼

在未融合的路徑中,all-reduce 和 RMSNorm 是兩個獨立的 Kernel 啟動,中間夾著一次 HBM 來回傳輸:

融合路徑將它們合併為單一 Kernel。all-reduce 的結果保留在暫存器中並直接流入正規化(Normalization)——沒有 HBM 來回傳輸,也沒有第二次 Kernel 啟動。

full_out = all_reduce(partial_out)              # kernel 1: writes result to HBM
normed, residual = rmsnorm(full_out, residual)  # kernel 2: reads from HBM

效能提升非常顯著(每次運算節省約 3 TB/s 的 HBM 頻寬)。但融合後的 Kernel 在計算 RMSNorm 時,其執行緒佈局(Thread layout)與獨立的 Norm Kernel 不同。

要了解為什麼這很重要,請考慮 RMSNorm 如何計算隱藏維度上的平方和。隱藏狀態分佈在各個執行緒中,部分和必須縮減為單一純量。GPU Kernel 使用蝶形縮減(Butterfly reduction)來完成此操作——每個執行緒透過 __shfl_xor_sync 與夥伴交換數值,在每一步將參與者的數量減半:

這是在每個 32 執行緒 Warp 內部的 5 步二元樹。之後,Warp 層級的結果透過共享記憶體(Shared memory)進行合併(在 Hopper 架構上,則是透過叢集共享記憶體在區塊間合併)。最終數值被饋送到 rsqrtf 以產生正規化比例。

Step 1: thread 0 ↔ thread 16, thread 1 ↔ thread 17, ...  (mask=16)
Step 2: thread 0 ↔ thread 8,  thread 1 ↔ thread 9,  ...  (mask=8)
Step 3: thread 0 ↔ thread 4,  thread 1 ↔ thread 5,  ...  (mask=4)
Step 4: thread 0 ↔ thread 2,  thread 1 ↔ thread 3,  ...  (mask=2)
Step 5: thread 0 ↔ thread 1                                (mask=1)

關鍵洞察:不同的區塊大小意味著不同的元素會落在不同的執行緒和 Warp 中,因此蝶形運算在每一步配對的部分和也不同。中間的加法順序發生變化,產生了不同的 rsqrtf 輸入——這隨後會以不同的方式縮放隱藏狀態的每個元素。融合 Kernel 的區塊大小是由 all-reduce 的需求決定的,而不是由 RMSNorm 單獨的最佳化需求決定的。(實作細節請參見 trtllm_allreduce_fusion.cuh 中的 blockReduceSumV2。)

對於 DeepSeek V3 / Kimi K2.5,這種融合在注意力路徑的所有 61 層上執行,並且在 MoE 路徑上更加激進(陷阱 3)。

為什麼容易被忽略

與陷阱 1 相同——融合後的 Kernel 執行的是「相同的數學運算」。分歧只會在仔細比對 logprob 時顯現。


陷阱 3:MoE 中的多重運算融合

它是什麼

MoE 層比稠密層有更多的運算需要融合。在每個 MoE 區塊的末尾,FlashInfer 的 MoE 融合 Kernel 將三個運算合併為一個:

  1. MoE finalize — 每個 token 的 Top-8 專家輸出加權和

  2. All-reduce — 在 GPU 之間加總部分結果(Lamport IPC)

  3. 下一個區塊的輸入 RMSNorm — 正規化並加上下一個 Transformer 區塊的殘差(Residual)

在未融合的路徑中,每個運算都是一個擁有獨立執行緒佈局的獨立 Kernel:

expert_out = moe_finalize(expert_outputs, weights)    # kernel 1
expert_out += shared_expert(x)                        # kernel 2
full_out = nccl_all_reduce(expert_out)                # kernel 3
normed = rmsnorm(full_out + residual)                 # kernel 4

融合後的 Kernel 一次完成所有這些操作。每個運算都使用由整體 Kernel 設計決定的執行緒佈局,而不是每個運算獨立選擇的佈局。

這在 58 個 MoE 層(第 3-60 層)上執行,分歧會不斷累積:第 3 層輸出的微小差異會透過後續所有層的注意力與 MoE 計算傳播。MoE 路由對此特別敏感——隱藏狀態的微小變化可能會改變 256 個專家中的選擇結果,從而產生連鎖反應。

測量影響

我們在 Kimi K2.5 上使用 25 個提示詞(每個生成 200 個 token)測量了分歧,比較了參考組(關閉所有融合)與各種配置之間的 logprob 分佈。我們的指標是 k3,這是一種始終非負且穩定的 KL 散度變體:

基準線(k3 = 0.000070)是雜訊底噪。僅 MLP 權重串接融合一項就使 k3 提高了約 2.7 倍。所有配置都通過了 k3 < 0.001 的閾值,但對於 RLHF/GRPO 而言——推論引擎即是參考策略——將 k3 最小化至接近底噪是非常值得的。


案例研究:Qwen3.5-MoE 影像 Token 漂移

我們在啟動使用 DeepEP 專家並行(Expert parallelism)的 Qwen3.5-MoE 訓練時,再次看到了同類問題。文字 token 的 k3 保持相對較小,但影像 token 的 k3 在 bf16 中出現了劇烈分歧。模型權重沒問題;聚合路徑有問題。

在此對比中,Hugging Face 的參考路徑是官方的 Qwen/Qwen3.5-397B-A17B 模型,在 Transformers 中實作為 Qwen3_5MoeForConditionalGeneration

我們如何隔離問題

我們將每一層的 MoE 區塊替換為 Hugging Face 的參考實作,同時保持堆疊的其餘部分不變:DeltaNet、GatedAttention、融合編碼器、嵌入(Embeddings)和 Norms 都保留在 Fireworks 的路徑上。

該替換使兩個指標都歸零,這告訴我們分歧完全存在於 MoE 聚合內部。

我們還進行了逐層縮減測試:將 Hugging Face 的隱藏狀態獨立饋送到我們的每一層中,然後逐層比較輸出。每一層在隔離狀態下實際上都是乾淨的。影像 token 的 k3 僅在經過約 40 層微小的文字 token 誤差透過稠密雙向注意力累積後才出現。

精確的數值不匹配

Hugging Face 的參考路徑將每個專家輸出乘以 float32 的路由權重,將每個專家貢獻值轉型(Cast)為 bf16,並透過 index_add_ 在 bf16 中進行累加。

# HF reference path
scored = expert_output_bf16 * score_float32
final.index_add_(0, token_idx, scored.to(torch.bfloat16))

我們標準的 Fireworks 路徑則是將所有八個專家貢獻值保留在 float32 中,執行批次加權和,最後才進行一次轉型。

# Fireworks standard path
out = torch.bmm(scores_f32, all_expert_outputs.float()).to(torch.bfloat16)

DeepEP 再次擴大了差距,因為 combine_tokens 在乘法前將路由分數從 float32 轉型為 bf16,並且仍然在 bf16 中進行加總。

這與本文其餘部分的教訓相同,但在更痛苦的情境中:數學上等價的 Kernel 一旦誤差累積,數值差異就足以產生影響。

經驗教訓

  1. 「相同的數學」並不代表「相同的位元」。上述每個陷阱在數學上都與參考路徑等價。分歧純粹來自不同的浮點數累加順序——無論是在 all-reduce 拓撲、Warp-shuffle 縮減樹,還是在 cuBLAS 的分塊啟發式演算法中。即使使用 FP32 累加,這點依然成立。

  2. MoE 模型特別脆弱。路由器的 Top-k 選擇意味著隱藏狀態的微小變化可能會改變專家指派,從而在後續層中產生連鎖反應。稠密模型沒有這種放大機制。

  3. 使用正確的指標進行測量。我們使用 k3(一種 KL 散度變體),閾值為 0.001。如果沒有定量測量,「模型生成合理的文字」是你所能做的最好情況——但這對於 RLHF 來說是不夠的。

  4. 給予使用者細粒度的控制。單一的「停用所有優化」旗標太過粗糙。RLHF 使用者需要保真度;推論使用者需要吞吐量。透過每個融合項的旗標,讓每個工作負載都能做出選擇。

  5. 複合效應佔主導地位。沒有單一陷阱會導致巨大的分歧。但是 61 層的 all-reduce 拓撲差異 + 58 層的 MoE finalize 融合 + 每個 MLP 中的 cuBLAS 分塊差異——這些微小的層級誤差加總起來就非常可觀。


參考資料

  • FlashInfer trtllm_allreduce_fusion.cuh — 融合 all-reduce + RMSNorm Kernel

  • FlashInfer trtllm_moe_allreduce_fusion.cuh — 融合 MoE finalize + all-reduce + RMSNorm Kernel

  • Qwen3.5-MoE 參考實作 — 用於 Qwen/Qwen3.5-397B-A17B 案例研究的 Hugging Face Transformers 參考路徑實作

  • 本文最初發表於 Fireworks AI 部落格,你可以在該處找到我們 AI 研究團隊發布的類似技術文章。