← 返回首頁

深度層通訊從累加升級為檢索,MoDA透過硬體優化實現高效深度注意力

Lianghui Zhu
Lianghui Zhu
@lianghui_zhu
1,159🔁 134
𝕏 (Twitter)🔥🔥🔥🔥

AI 語音朗讀 · Edge TTS

AI 中文摘要Claude 生成

深度層通訊從累加升級為檢索,MoDA透過硬體優化實現高效深度注意力。

過去十年,人工智慧模型架構的第一階段專注擴大層內運算規模,但層間通訊機制幾乎停滯於2015年ResNet的「x + F(x)」殘差連接,導致訊號稀釋與許多層「學會沉默」。作者Lianghui Zhu主張進入第二階段:將層間通訊從累加轉為檢索,並推出「Flash Depth Attention (FDA)」與「Mixture-of-Depths Attention (MoDA)」,使深度擴展真正高效。

深度擴展的瓶頸

模型在紙面上看來層數眾多(如152層),但淺層形成的關鍵特徵在反覆殘差更新中逐漸稀釋,深層難以恢復原始訊號。許多層選擇「學會沉默」,僅貢獻少量新資訊以避免掩埋前層內容,導致網路名義上深但實質淺薄。

  • 瓶頸不在層內運算,而在層間通訊:如同CPU早年遇記憶體頻寬限制,深度學習需升級「互聯」機制。
  • 比喻為「傳話遊戲」:殘差連接讓每人重複累積先前訊息,但到第152人時,原始訊息淹沒在152聲合唱中,第152層無法輕易「聽清」第3層所說。

既有方案的類別錯誤

先前嘗試如「DenseNet」(CVPR 2017最佳論文)、「DenseFormer」、「Hyper-Connections」、「MUDDFormer」皆試圖改善層輸出「混合」:更好係數、更多通道、自適應權重,但皆維持累加框架,無法讓深層直接存取特定淺層內容。

  • 這些方法假設層間通訊為「累加」(以學習或生成係數組合訊號),忽略檢索本質:查詢端(「我需要什麼」)與鍵端(「我有什麼」)雙方均有發聲權。
  • 作者批判這是「類別錯誤」:如同預測有用層而非直接檢查內容,無法解決訊號稀釋根本問題。

從深度注意力到Flash Depth Attention (FDA)

層間通訊應為「檢索」:深層直接「拍肩」詢問淺層「你說了什麼?」。但朴素深度注意力實作,前向+反向傳遞需44,924 ms,極度緩慢。

  • FDA為硬體高效核心,加速深度注意力逾40,000倍,实现全表達力深度檢索,可大規模訓練。
  • 傳統Transformer流程:殘差連接 → 序列注意力 → 殘差連接 → FFN。
  • FDA流程:深度注意力 → 序列注意力 → 深度注意力 → FFN。

Mixture-of-Depths Attention (MoDA)的統一檢索

MoDA進一步融合深度與序列檢索至單一softmax,每注意力頭同時關注當層序列KV對及所有前層深度KV對,一操作實現兩維度檢索。

  • 解決非連續記憶體存取,達成64K序列長度下FlashAttention-2效率的97.3%。
  • 論文於2026年3月16日發佈於arXiv:2603.15619,部落格《The Second Half of Model Architecture》於2026年4月11日上線,程式碼於GitHub hustvl/MoDA開源。
  • 視覺化顯示:注意力熱圖中,深度KV區塊持續獲大量注意力質量,證明模型積極使用跨層檢索,attention-sink現象消失。

實證成果與基準比較

在1.5B參數模型上,MoDA全面超越OLMo2基準,僅3.7% FLOPs額外開銷。結合post-norm優於pre-norm,證明其為深度擴展的有力原語。

  • 下游任務平均性能提升(400B token訓練)

    模型 PIQA HellaSwag WinoGrande OpenBookQA BoolQ SciQ ARC-E ARC-C COPA MMLU 平均
    OLMo2-700M 73.72 58.77 55.33 35.60 56.24 89.50 66.84 33.44 77.00 24.69 57.11
    MoDA-700M 73.39 59.19 60.22 37.20 59.33 89.60 67.37 34.78 82.00 25.61 58.87
    OLMo2-1.5B 76.55 65.86 63.22 38.80 63.61 90.60 72.98 42.47 81.00 27.73 62.28
    MoDA-1.5B 76.82 66.24 65.59 41.60 67.34 92.10 72.81 46.82 85.00 29.59 64.39

    MoDA-1.5B平均提升2.11%。

  • 驗證困惑度平均改善(越低越好)

    模型 C4 ICE m2d2-s2orc Pile Wiki-text Books CC peS2o Reddit Stack 平均
    OLMo2-700M 18.32 17.43 24.37 9.53 12.26 16.78 20.53 9.17 23.84 3.93 15.61
    MoDA-700M 18.29 17.24 23.64 9.48 12.06 16.58 20.52 9.14 23.75 3.90 15.46
    OLMo2-1.5B 16.16 15.37 21.10 8.45 10.41 14.19 18.13 8.19 21.21 3.57 13.67
    MoDA-1.5B 15.97 15.08 20.92 8.33 10.16 13.95 17.88 8.09 20.85 3.52 13.47

    10個驗證基準平均困惑度降0.2。

硬體效率數據(A100, bf16, 前向+反向, B=1, d=64, C=64)

MoDA-Triton核心在序列長度擴大時額外時間持續下降,證明可擴展性。

  • 序列長度T擴展(G=8, Hq=64, Hk=8, L=64)

    T FA2-Triton (ms) MoDA-Triton (ms) Depth Utilization Extra Time
    4096 7.970 10.750 12.50% 25.86%
    8192 28.700 35.427 12.50% 18.99%
    16384 116.700 127.661 12.50% 8.59%
    32768 459.854 480.914 12.50% 4.38%
    65536 1831.668 1883.026 12.50% 2.73%
  • GQA群組大小G擴展(T=16384, Hk=8, L=64)

    G Hq FA2-Triton (ms) MoDA-Triton (ms) Depth Utilization Extra Time
    2 16 28.982 39.741 3.12% 27.07%
    4 32 58.071 68.939 6.25% 15.76%
    8 64 116.700 127.661 12.50% 8.59%
    16 128 233.700 244.900 25.00% 4.57%
    32 256 467.107 480.767 50.00% 2.84%
  • 模型深度L擴展(T=16384, G=8, Hq=64, Hk=8)

    L FA2-Triton (ms) MoDA-Triton (ms) Depth Utilization Extra Time
    64 116.700 127.661 12.50% 8.59%
    128 116.700 138.224 12.50% 15.57%
    256 116.700 167.958 12.50% 30.52%

架構設計的反思與未來

第一階段聚焦「擴大組件」(序列長度、資料、參數),第二階段轉向「擴大通訊」:從序列維度(FlashAttention等)擴至深度維度。作者預見此原則泛化至模態間、時序步等靜態通道,取代累加為檢索。

  • 傳話遊戲升級:不再透過合唱雜音,直接對話取代中介。
  • 獨立驗證如Google的DCA、Huawei的MRLA等,證概念正確,但FDA/MoDA首度解決工程障礙。
  • 程式庫支援PyTorch >=2.5、Triton >=3.0等,已釋出Triton核心(fda_v12.py、moda_v14.py等),視覺任務如ImageNet分類訓練腳本可用。
  • 作者呼籲:殘差「+」運作輝煌十年,現在該升級樓梯,歡迎第二階段。

論文作者包括Lianghui Zhu、Yuxin Fang、Bencheng Liao等,來自華中科技大學與ByteDance Seed。部落格強烈推薦先讀,程式碼涵蓋OLMo2、DeiT、Flash Linear Attention基礎。