← 返回首頁

MoonshotAI開源FlashKDA,高性能Kimi Delta Attention核心實現

Kimi.ai
Kimi.ai
@Kimi_Moonshot
1,518🔁 150
𝕏 (Twitter)🔥🔥🔥
AI 中文摘要Claude 生成

MoonshotAI開源FlashKDA,高性能Kimi Delta Attention核心實現,在H20上比flash-linear-attention基準快1.72×–2.22×。

MoonshotAI推出FlashKDA,這是基於CUTLASS的高性能「Kimi Delta Attention (KDA)」核心實現,可作為「flash-linear-attention」的即插即用後端,直接提升prefill階段效能。

效能表現
在H20 GPU上,FlashKDA實現1.72×–2.22×的prefill加速,超越「flash-linear-attention」基準。詳細基準數據見專案的「BENCHMARK_H20.md」文件,提供其在高負載情境下的效能參考。

硬體與軟體需求

  • 需要SM90及以上GPU架構
  • CUDA 12.9及以上版本
  • PyTorch 2.4及以上版本

這些為該核心運作的必要環境條件。

安裝與整合
安裝步驟如下:

git clone https://github.com/MoonshotAI/FlashKDA.git flash-kda
cd flash-kda
git submodule update --init --recursive
pip install -v .

安裝後,即可直接作為「flash-linear-attention」的後端使用,整合細節參見「fla-org/flash-linear-attention#852」PR。無需大幅修改既有程式碼,即能替換後端加速。

核心API:flash_kda.fwd
核心函數為flash_kda.fwd,支援bf16輸入,處理Kimi Delta Attention的前向傳播:

flash_kda.fwd(q, k, v, g, beta, scale, out, A_log, dt_bias, lower_bound,
              initial_state=None, final_state=None, cu_seqlens=None)

關鍵參數包括:

  • qkvgbeta:bf16張量,分別為查詢、金鑰、值、閘門前啟用與beta logits(內部套用sigmoid)
  • scale:float標量,縮放因子
  • out:bf16輸出張量
  • A_log:fp32 log-gate參數,形狀[H]
  • dt_bias:fp32閘門偏差,形狀[H, K]
  • lower_bound:float閘門下界(範圍-5.0至0)
  • initial_state / final_state:可選bf16/fp32張量或None,用於狀態傳遞,形狀[B, H, V, K][N, H, V, K]
  • cu_seqlens:int64累積序列長度,支援變長批次(此時B=1,T為總長)

目前要求K = V = 128,狀態張量dtype需匹配,提供cu_seqlens時視為多序列批次,否則獨立處理每個批次元素。

測試與驗證
執行bash tests/test.sh進行測試:

  • tests/test_fwd.py:正確性測試,與torch參考實現精確匹配,並與「flash-linear-attention」比較,確保數值一致性。

開發支援
為CUDA/C++來源設定IntelliSense (clangd),執行bash setup_clangd.sh,自動產生.clangd文件並安裝全域config.yaml~/.config/clangd/,便利程式開發與除錯。

引用資訊
專案提供BibTeX引用:

@misc{flashkda2026,
      title={FlashKDA: Flash Kimi Delta Attention},
      author={Yutian Chen, Zhiyuan Li, Yucheng Wang, Ming Wei},
      year={2026},
      publisher = {GitHub},
      howpublished = {\url{https://github.com/MoonshotAI/FlashKDA}},
}

GitHub連結:https://github.com/MoonshotAI/FlashKDA