
[AMD] Optimize MiniMax-M2.5 - enable fused Triton kernel for FP8 KV cache write in aiter decode path#23620

Merged
HaiShaw merged 1 commit into sgl-project:main from yctseng0211:reshape_and_cache_flash
Apr 25, 2026

Conversation

@yctseng0211
Collaborator

Motivation

  • On AMD GPUs with FP8 KV cache (--kv-cache-dtype fp8_e4m3) and unified
    attention enabled, the decode KV cache write previously required two
    separate kernel launches: a bf16→fp8 dtype cast (float8_copy_kernel)
    followed by a paged store (store_kvcache).
  • This PR adds a branch in AiterAttnBackend.forward_decode that uses
    launch_reshape_and_cache_flash (an existing Triton kernel already used
    for SWA models) to fuse the cast and store into a single kernel launch.
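The payoff of the fusion can be sketched in plain Python (an illustrative model only: the function and variable names below are hypothetical stand-ins, and `toy_quant` is a crude fake of the real on-GPU bf16→fp8 cast). The unfused path materializes a casted temporary and walks the tokens twice; the fused path quantizes and scatters into the paged cache in a single pass, saving one kernel launch per decode step.

```python
# Illustrative sketch only: plain Python standing in for the Triton kernels.
# toy_quant, unfused_write, and fused_write are hypothetical names; the real
# code casts to float8_e4m3 on the GPU.

FP8_MAX = 448.0  # e4m3 finite max, used here as a saturation bound

def toy_quant(x, scale=1.0):
    """Crude stand-in for a bf16->fp8 cast: scale, saturate, round."""
    v = max(-FP8_MAX, min(FP8_MAX, x / scale))
    return round(v * 8) / 8  # pretend fp8 grid with 1/8 spacing

def unfused_write(kv, slot_mapping, cache, scale):
    # Pass 1: dtype cast into a temporary (analogue of float8_copy_kernel).
    tmp = [toy_quant(x, scale) for x in kv]
    # Pass 2: paged scatter (analogue of store_kvcache).
    for val, slot in zip(tmp, slot_mapping):
        cache[slot] = val

def fused_write(kv, slot_mapping, cache, scale):
    # One pass: quantize and scatter together
    # (analogue of launch_reshape_and_cache_flash).
    for x, slot in zip(kv, slot_mapping):
        cache[slot] = toy_quant(x, scale)

kv = [0.3, -1.7, 512.0]
slots = [5, 2, 7]
a, b = [0.0] * 8, [0.0] * 8
unfused_write(kv, slots, a, scale=1.0)
fused_write(kv, slots, b, scale=1.0)
assert a == b  # identical result, half the "kernel launches"
```

Both paths write the same bytes; the fused version simply avoids the intermediate buffer and the second launch, which is where the decode-throughput gain comes from.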

Modifications

Accuracy Tests

  • GSM8K accuracy: 93.3% (unchanged from baseline).

Speed Tests and Profiling

  • Benchmarked on MI355X with MiniMax-M2.5 FP8 (TP=4, ISL=8192, OSL=1024):
    +2.5% output throughput at conc=64, +2.3% at conc=32, up to +5.9% at
    conc=4. No regression at conc=128 (+0.4%).

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.


@yctseng0211 yctseng0211 changed the title [AMD] Optimize enable fused Triton kernel for FP8 KV cache write in aiter decode path [AMD] Optimize MiniMax-M2.5 - enable fused Triton kernel for FP8 KV cache write in aiter decode path Apr 24, 2026
@yctseng0211 yctseng0211 marked this pull request as ready for review April 24, 2026 08:22
@HaiShaw HaiShaw merged commit adc5932 into sgl-project:main Apr 25, 2026
57 of 65 checks passed
sogalin added a commit to sogalin/sglang that referenced this pull request Apr 28, 2026
…l only

The fused Triton kernel introduced in PR sgl-project#23620 (commit adc5932) is
correct enough for non-speculative target-model decode (its original
target, MiniMax-M2.5) but its bf16->fp8 implicit cast through tl.store
does not match PyTorch .to(torch.float8_e4m3fn) bit-exactly. PyTorch
casts with round-to-nearest-even + saturation; the Triton path on
ROCm/HIP rounds differently and may not saturate, even when the
per-tensor k_scale / v_scale are 1.0 (verified for Kimi-K2.5 Quark
MXFP4 with kv_cache_dtype=fp8 by direct probe).

Non-speculative inference tolerates this small numerical drift, but
EAGLE3 draft decode reads back its own freshly written K/V cache on
every subsequent draft step, so any drift in the draft cache compounds
across draft steps and collapses the accept length:

  Kimi-K2.5-MXFP4 + EAGLE3 (8xMI300, in/out 8192/1024, conc 4):
    pr-23146 baseline                : accept=3.26  out=675 tok/s
    + seqused_k fix (2bee3c3)        : accept=3.46  out=706 tok/s
    + this commit (target-only gate) : accept=3.97  out=807 tok/s
    pr-23461 baseline reference      : accept=3.97  out=798 tok/s

Restrict the fast path to target-model backends by checking
model_runner.is_draft_worker. The SWA path is unchanged (it already
works because SWA models did not exercise the corrupted draft cache).
The Triton kernel itself can be revisited later to match PyTorch fp8
cast semantics; until then, draft model writes route through the
legacy MHATokenToKVPool.set_kv_buffer path.
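To make the cast-semantics gap concrete, here is a small pure-Python reference model of the behavior the commit message attributes to PyTorch's `.to(torch.float8_e4m3fn)`: round-to-nearest-even on the e4m3 grid plus saturation at the finite max (448). This is an illustrative sketch, not SGLang or PyTorch source; a backend that rounds half-away-from-zero or skips saturation would disagree with it on exactly the boundary cases asserted below.

```python
import math

E4M3_MAX = 448.0   # largest finite e4m3fn value (1.75 * 2^8)
MIN_EXP = -6       # smallest normal exponent; below this, subnormals

def quant_e4m3_rne(x: float) -> float:
    """Reference e4m3fn cast: round-to-nearest-even, saturating.

    Illustrative model of the semantics described in the commit
    message; not taken from SGLang/PyTorch source.
    """
    if x != x:          # NaN passes through
        return x
    sign = -1.0 if x < 0 else 1.0
    a = abs(x)
    if a == 0.0:
        return 0.0
    # Spacing between representable values at this magnitude:
    # 2^(e-3) for 3 mantissa bits (clamped into the subnormal range).
    e = max(math.floor(math.log2(a)), MIN_EXP)
    step = 2.0 ** (e - 3)
    q = round(a / step)            # Python round() is round-half-to-even
    return sign * min(q * step, E4M3_MAX)  # saturate instead of overflowing

# Saturation: anything past 448 clamps to 448 rather than overflowing.
assert quant_e4m3_rne(10_000.0) == 448.0
# Round-to-nearest-even on an exact tie: 25/256 sits halfway between
# 12/128 and 13/128; RNE picks the even mantissa (12/128 = 0.09375).
assert quant_e4m3_rne(25 / 256) == 0.09375
```

A path that rounds ties away from zero would return 13/128 for the second case, and a non-saturating path would not clamp the first; one-ulp drift of exactly that kind is what compounds across EAGLE3 draft steps when the draft model re-reads its own cache.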
vguduruTT pushed a commit to vguduruTT/sglang that referenced this pull request May 2, 2026
