# 策展 · X (Twitter) 🔥

> 📖 本站完整內容索引（documentation index）：[llms.txt](/llms.txt)

> 作者：Fireworks AI (@FireworksAI_HQ) · 平台：X (Twitter) · 日期：2026-04-18

> 原始來源：https://x.com/fireworksai_hq/status/2045366426819768794

## 中文摘要

# MoE 模型中的訓練與推論一致性：數值漂移發生之處

# 當「更快」不等於「相同」：部署 MoE 模型時的數值陷阱

在數學上等價的 Kernel 融合（Kernel fusions），在數值上仍可能產生漂移。以下是我們在 Kimi K2.5 推論服務與 Qwen3.5-MoE 訓練啟動過程中遇到的一致性錯誤（Parity bugs）。

## 為什麼這很重要

當你訓練一個模型並將其部署用於推論時，你會期望兩者表現一致：相同的權重、相同的輸入、相同的輸出分佈。這種訓練與推論之間的數值一致性（Numerical parity）比聽起來更重要：

- **RLHF / GRPO 獎勵完整性**：參考模型（Reference model）的 logprobs 是 KL 懲罰的錨點。如果推論產生的 logprobs 與訓練時針對相同權重產生的結果不同，策略（Policy）可能會在不實際改進的情況下利用這個差距。

- **可重現性**：由 Kernel 融合引起的數值漂移是隱形的——權重相同，架構看起來也一樣，但輸出卻出現分歧。

- **客戶信任**：使用者在我們的平台上進行微調，並期望部署後的結果與訓練優化的目標相符。

對於稠密模型（Dense models），達成一致性相對容易。但像 Kimi K2.5、Qwen3.5-MoE 和 DeepSeek V3 這類 Mixture-of-Experts (MoE) 模型則更困難。由於存在路由專家（Routed experts）、共享專家路徑，以及在深層堆疊中每層兩次的 all-reduce 通訊，許多「數學上等價」的優化手段都會產生數值上的差異。

這篇文章整理了我們發現的陷阱。每一項都是推論引擎為了效能而採用的優化類別，但它們可能會悄悄破壞數值對齊。我們在將 Kimi K2.5 導入我們的推論堆疊時發現了大部分問題，隨後在除錯 Qwen3.5-MoE 時又看到了相同的失效模式。我們將以 FlashInfer 和 TRT-LLM 風格的融合 Kernel 作為具體範例。

## 根本問題：浮點數加法不具結合律

以下提到的每個陷阱都歸結為一個事實：浮點數加法不具結合律。即使在 FP32 中也是如此：

(a + b) + c  ≠  a + (b + c)

每次加法都會將結果四捨五入到最接近的可表示值。不同的運算順序會產生不同的中間值，進而導致不同的捨入誤差。這些誤差在單次運算中微乎其微，但會在 61 層 Transformer 層中累積——而 MoE 路由會放大這些誤差（隱藏狀態的微小變化可能會改變專家選擇的結果，並在網路的其餘部分產生連鎖反應）。

# 陷阱 1：All-Reduce 拓撲差異

## 它是什麼

在張量並行（Tensor-parallel）推論中，每個線性層的輸出必須透過 all-reduce 在各個 GPU 之間進行加總。這在每層中會發生兩次：在注意力輸出投影之後，以及在 MLP/MoE 之後。

訓練通常使用 NCCL，它將 all-reduce 實作為 reduce-scatter 後接 all-gather。在環狀拓撲（Ring topology）的 reduce-scatter 階段，資料被分成多個區塊（每個 GPU 一個）。隨著部分和（Partial sums）在環中流動，每個區塊會從「擁有」該區塊的 GPU 開始進行累加。對於 8 個 GPU，這意味著隱藏向量的不同部分會看到不同的加總順序：

```
NCCL ring reduce-scatter (8 GPUs):
  chunk 0 (owned by GPU0): r0 + r7 + r6 + r5 + r4 + r3 + r2 + r1
  chunk 1 (owned by GPU1): r1 + r0 + r7 + r6 + r5 + r4 + r3 + r2
  chunk 2 (owned by GPU2): r2 + r1 + r0 + r7 + r6 + r5 + r4 + r3
  ...each chunk starts from its owner, accumulates around the ring
```

推論服務引擎通常會為了降低延遲，將 NCCL 替換為自定義的 all-reduce Kernel。FlashInfer 的 Lamport IPC Kernel（源自 TRT-LLM）採用了不同的方法：每個 GPU 透過 CUDA IPC 將資料寫入所有其他 GPU 的緩衝區，然後每個 GPU 在本地讀取所有貢獻值並以固定順序加總：

```
Lamport kernel (all elements, on every GPU):
  every chunk:              r0 + r1 + r2 + r3 + r4 + r5 + r6 + r7
```

兩者都在 FP32 中進行累加。兩者在精確算術下都會產生正確的總和。但 NCCL 中的區塊旋轉意味著隱藏向量的不同元素會看到不同的加法順序，而 Lamport Kernel 對所有內容都使用統一的 r0..r7 順序。由於浮點數加法不具結合律，這些操作會產生不同的結果。

為什麼容易被忽略

輸出看起來是正確的。模型生成了連貫的文字。分歧在 KL 散度上小於 0.001。只有當你逐個 token 對比 logprobs 與參考值時才會發現——而這正是 RLHF 獎勵計算所做的事情。

---

# 陷阱 2：通訊與計算的融合

## 它是什麼

在未融合的路徑中，all-reduce 和 RMSNorm 是兩個獨立的 Kernel 啟動，中間夾著一次 HBM 來回傳輸：

融合路徑將它們合併為單一 Kernel。all-reduce 的結果保留在暫存器中並直接流入正規化（Normalization）——沒有 HBM 來回傳輸，也沒有第二次 Kernel 啟動。

```
full_out = all_reduce(partial_out)              # kernel 1: writes result to HBM
normed, residual = rmsnorm(full_out, residual)  # kernel 2: reads from HBM
```

效能提升非常顯著（每次運算節省約 3 TB/s 的 HBM 頻寬）。但融合後的 Kernel 在計算 RMSNorm 時，其執行緒佈局（Thread layout）與獨立的 Norm Kernel 不同。

要了解為什麼這很重要，請考慮 RMSNorm 如何計算隱藏維度上的平方和。隱藏狀態分佈在各個執行緒中，部分和必須縮減為單一純量。GPU Kernel 使用蝶形縮減（Butterfly reduction）來完成此操作——每個執行緒透過 `__shfl_xor_sync` 與夥伴交換數值，在每一步將參與者的數量減半：

這是在每個 32 執行緒 Warp 內部的 5 步二元樹。之後，Warp 層級的結果透過共享記憶體（Shared memory）進行合併（在 Hopper 架構上，則是透過叢集共享記憶體在區塊間合併）。最終數值被饋送到 `rsqrtf` 以產生正規化比例。

```
Step 1: thread 0 ↔ thread 16, thread 1 ↔ thread 17, ...  (mask=16)
Step 2: thread 0 ↔ thread 8,  thread 1 ↔ thread 9,  ...  (mask=8)
Step 3: thread 0 ↔ thread 4,  thread 1 ↔ thread 5,  ...  (mask=4)
Step 4: thread 0 ↔ thread 2,  thread 1 ↔ thread 3,  ...  (mask=2)
Step 5: thread 0 ↔ thread 1                                (mask=1)
```

關鍵洞察：不同的區塊大小意味著不同的元素會落在不同的執行緒和 Warp 中，因此蝶形運算在每一步配對的部分和也不同。中間的加法順序發生變化，產生了不同的 `rsqrtf` 輸入——這隨後會以不同的方式縮放隱藏狀態的每個元素。融合 Kernel 的區塊大小是由 all-reduce 的需求決定的，而不是由 RMSNorm 單獨的最佳化需求決定的。（實作細節請參見 `trtllm_allreduce_fusion.cuh` 中的 `blockReduceSumV2`。）

對於 DeepSeek V3 / Kimi K2.5，這種融合在注意力路徑的所有 61 層上執行，並且在 MoE 路徑上更加激進（陷阱 3）。

為什麼容易被忽略

與陷阱 1 相同——融合後的 Kernel 執行的是「相同的數學運算」。分歧只會在仔細比對 logprob 時顯現。

---

# 陷阱 3：MoE 中的多重運算融合

## 它是什麼

MoE 層比稠密層有更多的運算需要融合。在每個 MoE 區塊的末尾，FlashInfer 的 MoE 融合 Kernel 將三個運算合併為一個：

1. MoE finalize — 每個 token 的 Top-8 專家輸出加權和

1. All-reduce — 在 GPU 之間加總部分結果（Lamport IPC）

1. 下一個區塊的輸入 RMSNorm — 正規化並加上下一個 Transformer 區塊的殘差（Residual）

在未融合的路徑中，每個運算都是一個擁有獨立執行緒佈局的獨立 Kernel：

```
expert_out = moe_finalize(expert_outputs, weights)    # kernel 1
expert_out += shared_expert(x)                        # kernel 2
full_out = nccl_all_reduce(expert_out)                # kernel 3
normed = rmsnorm(full_out + residual)                 # kernel 4
```

融合後的 Kernel 一次完成所有這些操作。每個運算都使用由整體 Kernel 設計決定的執行緒佈局，而不是每個運算獨立選擇的佈局。

這在 58 個 MoE 層（第 3-60 層）上執行，分歧會不斷累積：第 3 層輸出的微小差異會透過後續所有層的注意力與 MoE 計算傳播。MoE 路由對此特別敏感——隱藏狀態的微小變化可能會改變 256 個專家中的選擇結果，從而產生連鎖反應。

## 測量影響

我們在 Kimi K2.5 上使用 25 個提示詞（每個生成 200 個 token）測量了分歧，比較了參考組（關閉所有融合）與各種配置之間的 logprob 分佈。我們的指標是 k3，這是一種始終非負且穩定的 KL 散度變體：

![](https://pub-75d4fe1e4e80421b9ecb1245a7ae0d1a.r2.dev/curated/1776543719534-iaHGKVRA8a8AAUtSQpng.png)

基準線（k3 = 0.000070）是雜訊底噪。僅 MLP 權重串接融合一項就使 k3 提高了約 2.7 倍。所有配置都通過了 k3 < 0.001 的閾值，但對於 RLHF/GRPO 而言——推論引擎即是參考策略——將 k3 最小化至接近底噪是非常值得的。

---

# 案例研究：Qwen3.5-MoE 影像 Token 漂移

我們在啟動使用 DeepEP 專家並行（Expert parallelism）的 Qwen3.5-MoE 訓練時，再次看到了同類問題。文字 token 的 k3 保持相對較小，但影像 token 的 k3 在 bf16 中出現了劇烈分歧。模型權重沒問題；聚合路徑有問題。

在此對比中，Hugging Face 的參考路徑是官方的 `Qwen/Qwen3.5-397B-A17B` 模型，在 Transformers 中實作為 `Qwen3_5MoeForConditionalGeneration`。

## 我們如何隔離問題

我們將每一層的 MoE 區塊替換為 Hugging Face 的參考實作，同時保持堆疊的其餘部分不變：DeltaNet、GatedAttention、融合編碼器、嵌入（Embeddings）和 Norms 都保留在 Fireworks 的路徑上。

![](https://pub-75d4fe1e4e80421b9ecb1245a7ae0d1a.r2.dev/curated/1776543719526-diaHGKVvY2bwAAAj8png.png)

該替換使兩個指標都歸零，這告訴我們分歧完全存在於 MoE 聚合內部。

我們還進行了逐層縮減測試：將 Hugging Face 的隱藏狀態獨立饋送到我們的每一層中，然後逐層比較輸出。每一層在隔離狀態下實際上都是乾淨的。影像 token 的 k3 僅在經過約 40 層微小的文字 token 誤差透過稠密雙向注意力累積後才出現。

## 精確的數值不匹配

Hugging Face 的參考路徑將每個專家輸出乘以 float32 的路由權重，將每個專家貢獻值轉型（Cast）為 bf16，並透過 `index_add_` 在 bf16 中進行累加。

```
# HF reference path
scored = expert_output_bf16 * score_float32
final.index_add_(0, token_idx, scored.to(torch.bfloat16))
```

我們標準的 Fireworks 路徑則是將所有八個專家貢獻值保留在 float32 中，執行批次加權和，最後才進行一次轉型。

```
# Fireworks standard path
out = torch.bmm(scores_f32, all_expert_outputs.float()).to(torch.bfloat16)
```

DeepEP 再次擴大了差距，因為 `combine_tokens` 在乘法前將路由分數從 float32 轉型為 bf16，並且仍然在 bf16 中進行加總。

![](https://pub-75d4fe1e4e80421b9ecb1245a7ae0d1a.r2.dev/curated/1776543719529-diaHGKWE2FaIAArSQpng.png)

這與本文其餘部分的教訓相同，但在更痛苦的情境中：數學上等價的 Kernel 一旦誤差累積，數值差異就足以產生影響。

## 經驗教訓

1. 「相同的數學」並不代表「相同的位元」。上述每個陷阱在數學上都與參考路徑等價。分歧純粹來自不同的浮點數累加順序——無論是在 all-reduce 拓撲、Warp-shuffle 縮減樹，還是在 cuBLAS 的分塊啟發式演算法中。即使使用 FP32 累加，這點依然成立。

1. MoE 模型特別脆弱。路由器的 Top-k 選擇意味著隱藏狀態的微小變化可能會改變專家指派，從而在後續層中產生連鎖反應。稠密模型沒有這種放大機制。

1. 使用正確的指標進行測量。我們使用 k3（一種 KL 散度變體），閾值為 0.001。如果沒有定量測量，「模型生成合理的文字」是你所能做的最好情況——但這對於 RLHF 來說是不夠的。

1. 給予使用者細粒度的控制。單一的「停用所有優化」旗標太過粗糙。RLHF 使用者需要保真度；推論使用者需要吞吐量。透過每個融合項的旗標，讓每個工作負載都能做出選擇。

1. 複合效應佔主導地位。沒有單一陷阱會導致巨大的分歧。但是 61 層的 all-reduce 拓撲差異 + 58 層的 MoE finalize 融合 + 每個 MLP 中的 cuBLAS 分塊差異——這些微小的層級誤差加總起來就非常可觀。

---

## 參考資料

- FlashInfer `trtllm_allreduce_fusion.cuh` — 融合 all-reduce + RMSNorm Kernel

- FlashInfer `trtllm_moe_allreduce_fusion.cuh` — 融合 MoE finalize + all-reduce + RMSNorm Kernel

- Qwen3.5-MoE 參考實作 — 用於 Qwen/Qwen3.5-397B-A17B 案例研究的 Hugging Face Transformers 參考路徑實作

- 本文最初發表於 Fireworks AI 部落格，你可以在該處找到我們 AI 研究團隊發布的類似技術文章。

## 標籤

LLM, 研究論文, 其他, Kimi, Qwen
