
[Speculative][ROCm/MI355X] Enable DFlash speculative decoding on ROCm/MI355X#23388

Draft
zhentaocc wants to merge 10 commits into sgl-project:main from zhentaocc:dflash-rocm-support

Conversation

@zhentaocc
Contributor

@zhentaocc zhentaocc commented Apr 21, 2026

Summary

DFlash ROCm Enablement (python/sglang)

  • Add aiter to DFlash draft worker supported_draft_backends
  • Enable fused KV materialization on ROCm (Triton kernel, platform-agnostic)
  • Add PyTorch fallback for top_k/top_p renorm + runtime C++ op detection
  • Integrate AITER top_p_renorm_probs kernel — replaces PyTorch sort+cumsum+scatter fallback (1.9-2.6x kernel speedup at bs>=4)
  • Integrate AITER chain_speculative_sampling kernel — replaces Python-loop .item() fallback (7-114x kernel speedup)
  • Graceful fallback: all AITER sampling kernels are optional and can be disabled via SGLANG_DISABLE_AITER_SAMPLING=1 (see the sketch after this list)
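
A minimal sketch of the runtime op-detection and toggle pattern described above, using names that appear in the commit messages below (`_DFLASH_SAMPLING_VERIFY_AVAILABLE`, `SGLANG_DISABLE_AITER_SAMPLING`); the exact module layout in this PR may differ:

```python
import os

import torch

_DFLASH_SAMPLING_VERIFY_AVAILABLE = False
try:
    import sgl_kernel  # noqa: F401  (registers torch.ops.sgl_kernel.* when compiled)

    # The Python wrappers can import successfully on ROCm even when the
    # C++ ops were not compiled, so verify registration explicitly.
    _DFLASH_SAMPLING_VERIFY_AVAILABLE = hasattr(
        torch.ops.sgl_kernel, "tree_speculative_sampling_target_only"
    ) and hasattr(torch.ops.sgl_kernel, "top_p_renorm_probs")
except ImportError:
    pass

# AITER sampling kernels are optional; this env var toggles them off
# (useful for A/B benchmarking against the PyTorch fallbacks).
_USE_AITER_SAMPLING = os.environ.get("SGLANG_DISABLE_AITER_SAMPLING", "0") != "1"
```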

Speculative Sampling Kernel Port (sgl-kernel)

  • Port tree_speculative_sampling_target_only CUDA kernel to ROCm/HIP using aiter hipCUB-based sampling utilities as drop-in replacements for FlashInfer CUB-based equivalents
  • New files: speculative_sampling_rocm.cuh, speculative_sampling.hip
  • Register op in common_extension_rocm.cc, add aiter include path to setup_rocm.py

New AITER Kernels (upstream PRs to ROCm/aiter)

Three new kernels added to AITER's sampling module for DFlash ROCm support:

| Kernel | Description | Tests | Kernel Speedup vs PyTorch |
| --- | --- | --- | --- |
| top_p_renorm_probs | Nucleus probability renormalization (binary search) | 48/48 | 1.9-2.6x (bs>=4) |
| chain_speculative_sampling | DFlash chain verification + bonus token sampling | 36/36 | 7-114x |
| top_k_top_p_renorm_probs | Fused top-k + top-p renorm in a single kernel | 54/54 | 1.8-2.0x vs sequential |

aiter_backend.py Fix

  • Guard spec_info.custom_mask with a None check in 3 CUDA graph capture paths to support DFlash verify with the aiter attention backend (sketched below)
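
A hedged sketch of the guard; the attribute access pattern is from the PR summary, while the buffer name is illustrative and the three real capture paths in aiter_backend.py differ in detail:

```python
# DFlash verify may leave spec_info.custom_mask unset when the attention
# backend handles masking internally, so guard before touching it.
custom_mask = getattr(spec_info, "custom_mask", None)
if custom_mask is not None:
    # mask_buffer is a stand-in for the pre-allocated CUDA-graph buffer.
    mask_buffer[: custom_mask.numel()].copy_(custom_mask)
```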

Test Results (AMD Instinct MI355X, gfx950, ROCm 7.2)

Model: Qwen3-8B + z-lab/Qwen3-8B-DFlash-b16 (block_size=16, 5 draft layers)

GSM8K Accuracy (lm_eval, 5-shot, 1319 questions)

| Metric | Baseline | DFlash |
| --- | --- | --- |
| exact_match (flexible) | 0.8961 | 0.8961 |
| exact_match (strict) | 0.8961 | 0.8969 |

Accuracy is effectively identical between baseline and DFlash (flexible exact_match matches exactly; strict is marginally higher with DFlash).

Overall Throughput (1K input / 1K output, request-rate=inf)

| Conc | Baseline Output (tok/s) | DFlash Output (tok/s) | Delta | Baseline TPOT (ms) | DFlash TPOT (ms) | Delta | Baseline TTFT (ms) | DFlash TTFT (ms) | Delta | Accept Len |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 4 | 749 | 1,327 | +77.2% | 5.04 | 2.77 | -45.0% | 28.6 | 29.8 | +4.2% | 3.38 |
| 64 | 6,034 | 4,268 | -29.3% | 9.79 | 10.81 | +10.4% | 113.6 | 1,790.8 | - | 3.72 |
| 256 | 10,771 | 4,216 | -60.9% | 22.72 | 11.12 | -51.1% | 197.0 | 23,243.2 | - | 3.77 |

Analysis

  • Low concurrency (conc=4): DFlash delivers +77% output throughput and -45% TPOT. This is the sweet spot — speculative decoding generates ~3.4 tokens per step, and the draft model overhead is amortized across fewer concurrent requests.
  • Medium concurrency (conc=64): DFlash throughput drops 29% vs baseline. The draft model's forward passes compete with the target model for GPU compute, reducing throughput. TPOT is comparable.
  • High concurrency (conc=256): DFlash throughput drops 61% vs baseline, but TPOT improves 51%. The throughput loss occurs because DFlash caps max running requests at 48 (vs. unlimited for the baseline), creating a queue bottleneck that inflates TTFT. The per-token latency benefit remains strong.
  • Accept length is consistently 3.4-3.8 across all concurrency levels.
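
As a rough sanity check (a back-of-envelope estimate, not from the benchmark data): an accept length of ~3.4 bounds the ideal speedup at ~3.4x per target forward, so the observed 1.77x at conc=4 means roughly half of that headroom survives the draft-model forwards and verification overhead.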

AITER Kernel Micro-benchmarks (MI355X)

| Kernel | bs=1 | bs=4 | bs=16 |
| --- | --- | --- | --- |
| top_p_renorm_probs (V=32K) | 0.82x | 1.91x | 2.58x |
| chain_speculative_sampling (V=32K, draft=4) | 7.68x | 29.15x | 112.72x |
| top_k_top_p_renorm_probs (V=32K) | 1.90x | 1.91x | 1.96x |

Test plan

  • GSM8K accuracy: baseline=0.8961, DFlash=0.8961 (identical)
  • Greedy determinism (3 identical runs)
  • Non-greedy sampling works (temp=0.8, top_p=0.9, top_k=20)
  • Overall throughput at conc=4 (+77%), conc=64, conc=256
  • CUDA graph enabled
  • tree_speculative_sampling_target_only kernel on ROCm
  • AITER kernel unit tests: 138/138 passed
  • aiter_backend.py custom_mask None guard
  • CI tests on CUDA (no regression)

Generated with Claude Code


@zhentaocc zhentaocc marked this pull request as draft April 21, 2026 16:39
@zhentaocc zhentaocc changed the title from [Speculative] Enable DFlash speculative decoding on ROCm/MI355X to [Speculative][ROCm/MI355X] Enable DFlash speculative decoding on ROCm/MI355X Apr 22, 2026
@zhentaocc zhentaocc force-pushed the dflash-rocm-support branch from 28ef43a to 5076c16 Compare April 24, 2026 04:08
@zhentaocc
Contributor Author

Baseline

conc=16


```
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 16
Successful requests:                     160
Benchmark duration (s):                  56.86
Total input tokens:                      81570
Total input text tokens:                 81570
Total generated tokens:                  80375
Total generated tokens (retokenized):    80372
Request throughput (req/s):              2.81
Input token throughput (tok/s):          1434.59
Output token throughput (tok/s):         1413.58
Peak output token throughput (tok/s):    2892.00
Peak concurrent requests:                25
Total token throughput (tok/s):          2848.17
Concurrency:                             15.25
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   5420.02
Median E2E Latency (ms):                 3566.39
P90 E2E Latency (ms):                    8990.48
P99 E2E Latency (ms):                    27386.11
---------------Time to First Token----------------
Mean TTFT (ms):                          2287.44
Median TTFT (ms):                        28.91
P99 TTFT (ms):                           22384.61
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.23
Median TPOT (ms):                        5.83
P99 TPOT (ms):                           8.57
---------------Inter-Token Latency----------------
Mean ITL (ms):                           6.25
Median ITL (ms):                         5.35
P95 ITL (ms):                            6.87
P99 ITL (ms):                            17.29
Max ITL (ms):                            598.94
==================================================
```

@hubertlu-tw
Collaborator

@zhentaocc why do we see a perf drop with DFlash at larger concurrency compared to the baseline?

@zhentaocc
Contributor Author

zhentaocc commented May 9, 2026

> @zhentaocc why do we see a perf drop with DFlash at larger concurrency compared to the baseline?

The draft model is not optimized; the main overhead is in attention + GEMM. One more finding: enabling DFlash seems to disable CUDA graph, which contributes to the host-bound behavior at large concurrency.

Chen, Todd and others added 9 commits May 9, 2026 05:05
DFlash was previously CUDA-only due to three issues:
1. Draft worker attention backends limited to flashinfer/fa3/fa4 (all CUDA-only)
2. Fused KV materialization gated by is_cuda()
3. Sampling verify ops import gated by is_cuda()

Changes:
- Add "aiter" to supported_draft_backends (triton was already added)
- Enable fused KV materialization on ROCm via is_cuda() or is_hip()
  (the underlying Triton kernel is platform-agnostic)
- Allow sampling verify ops import on ROCm (torch.cuda.is_available()
  covers HIP; try/except safely falls back if ops are absent)

Tested on 8x AMD Instinct MI355X (gfx950) with ROCm 7.2:
- Model: Qwen3-8B + z-lab/Qwen3-8B-DFlash-b16 draft
- Accuracy: greedy output identical to baseline (temperature=0)
- Performance: 57.5 → 92.9 tok/s (1.62x speedup)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
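
A sketch of the gating this commit describes, using sglang's platform helpers (the import path and variable names here are assumptions for illustration):

```python
from sglang.srt.utils import is_cuda, is_hip  # assumed helper location

# Draft worker attention backends: triton was already allowed; this PR adds aiter.
SUPPORTED_DRAFT_BACKENDS = ["flashinfer", "fa3", "fa4", "triton", "aiter"]

# The fused KV materialization Triton kernel is platform-agnostic,
# so enable it on both CUDA and ROCm.
ENABLE_FUSED_KV_MATERIALIZATION = is_cuda() or is_hip()
```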
The previous change broadened the import guard from is_cuda() to
torch.cuda.is_available(), which is True on ROCm/HIP. The sgl_kernel
Python wrappers import successfully but the underlying C++ ops
(tree_speculative_sampling_target_only, top_p_renorm_probs) are not
registered in the ROCm build (speculative_sampling.cu is not compiled).

This caused a runtime AttributeError when non-greedy sampling was used
with DFlash on ROCm.

Fix: After importing the wrappers, verify the C++ ops are actually
registered via hasattr(torch.ops.sgl_kernel, ...) before setting
_DFLASH_SAMPLING_VERIFY_AVAILABLE = True.

Tested: DFlash with temperature=0.5/0.8/1.0 + top_p=0.9 on MI355X
now correctly falls back to greedy verification without crashing.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The sgl-kernel ROCm build does not compile speculative_sampling.cu, so
the C++ ops (tree_speculative_sampling_target_only, top_k_renorm_probs,
top_p_renorm_probs) are unavailable. Previously, DFlash fell back to
greedy-only verification on ROCm, losing sampling diversity.

This commit adds pure PyTorch fallbacks:
- _top_k_renorm_prob_torch: top-k probability renormalization
- _top_p_renorm_prob_torch: nucleus (top-p) probability renormalization
- _dflash_chain_sampling_verify_torch: chain-structured speculative
  sampling verification (DFlash is topk=1, linear chain)

When C++ ops are available (CUDA), the kernel path is used unchanged.
When absent (ROCm), the PyTorch fallback enables non-greedy sampling
verification without requiring sgl-kernel recompilation.

Tested on MI355X with DFlash + Qwen3-8B:
- temp=0.5/0.8/1.0 + top_p=0.9 + top_k=20: diverse outputs, no crash
- Greedy accuracy unchanged
- Performance: 312.9 tok/s (CUDA graph enabled)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
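
A minimal sketch of the nucleus (top-p) renormalization fallback named in this commit (`_top_p_renorm_prob_torch`); this is the sort+cumsum+scatter pattern that the AITER kernel later replaces. Assumes `probs` is `[batch, vocab]` and already normalized; the in-tree version may differ in edge-case handling:

```python
import torch

def _top_p_renorm_prob_torch(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    # Keep each token whose cumulative mass *excluding itself* is below
    # top_p, i.e. the smallest prefix whose mass reaches top_p.
    keep = (cumsum - sorted_probs) < top_p
    sorted_probs = sorted_probs * keep
    # Scatter the kept probabilities back to their original vocab slots.
    renorm = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
    return renorm / renorm.sum(dim=-1, keepdim=True)
```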
Port the tree_speculative_sampling_target_only CUDA kernel to ROCm
using aiter's hipCUB-based sampling utilities (SamplingTempStorage,
DeviceSamplingFromProb, vec_t) as drop-in replacements for FlashInfer's
CUB-based equivalents.

New files:
- csrc/speculative/speculative_sampling_rocm.cuh: HIP kernel header
  using aiter::sampling namespace instead of flashinfer::sampling
- csrc/speculative/speculative_sampling.hip: ROCm entry point

Build changes:
- setup_rocm.py: Add speculative_sampling.hip to sources, auto-detect
  aiter sampling include path
- common_extension_rocm.cc: Register tree_speculative_sampling_target_only

Runtime changes:
- dflash_utils.py: Use kernel tree_speculative_sampling when available,
  with PyTorch fallback for top_k/top_p renorm (those ops are still
  from FlashInfer and not yet ported)

Tested on MI355X (gfx950, ROCm 7.2):
- tree_speculative_sampling_target_only kernel: works correctly
- Non-greedy sampling (temp=0.5/0.8/1.0 + top_p + top_k): diverse outputs
- Greedy accuracy: unchanged
- Performance: 312 tok/s (with CUDA graph)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When available, use aiter's HIP-native TopKRenormProbKernel instead of
the PyTorch fallback for top-k probability renormalization during
DFlash non-greedy sampling verification.

This completes the kernel-level solution for DFlash on ROCm:
- tree_speculative_sampling_target_only: HIP kernel (ported)
- top_k_renorm: aiter HIP kernel (reused)
- top_p_renorm: PyTorch fallback (no aiter kernel available)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ling kernels for DFlash ROCm

Integrate the new AITER sampling kernels into DFlash speculative decoding
on ROCm:

1. top_p_renorm_probs: replaces PyTorch sort+cumsum+scatter fallback
   (1.5-2.9x speedup at bs>=4)
2. chain_speculative_sampling: replaces Python loop with .item() calls
   (7-112x speedup, eliminates GPU->CPU round-trips)

Both kernels fall back to pure PyTorch if AITER is not installed.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
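
For reference, a hedged sketch of the chain-structured verification that the AITER chain_speculative_sampling kernel replaces: DFlash drafts a linear chain (topk=1), each draft token is accepted with probability min(1, p_target/p_draft), the first rejection resamples from the clamped residual, and a bonus token is drawn if every draft is accepted. Tensor shapes and names are illustrative; the fallback in dflash_utils.py may differ:

```python
import torch

def chain_verify_torch(draft_ids, draft_probs, target_probs, uniforms):
    # draft_ids:    [n]            proposed token ids (linear chain)
    # draft_probs:  [n, vocab]     draft distribution at each step
    # target_probs: [n + 1, vocab] target distribution at each step
    # uniforms:     [n]            pre-drawn U(0, 1) samples
    out = []
    for i, tok in enumerate(draft_ids.tolist()):
        accept_prob = (target_probs[i, tok] / draft_probs[i, tok]).clamp(max=1.0)
        if uniforms[i] <= accept_prob:
            out.append(tok)
            continue
        # First rejection: resample from the residual and stop the chain.
        residual = (target_probs[i] - draft_probs[i]).clamp_(min=0)
        out.append(torch.multinomial(residual / residual.sum(), 1).item())
        return out
    # All drafts accepted: sample the bonus token from the last target dist.
    out.append(torch.multinomial(target_probs[-1], 1).item())
    return out
```

Note the `.item()` calls: these are exactly the GPU-to-CPU round-trips the kernel eliminates.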
…ampling toggle

1. aiter_backend.py: Guard spec_info.custom_mask with None check in 3
   CUDA graph capture paths. DFlash verify may skip custom_mask when the
   attention backend handles masking internally, causing AttributeError
   during graph capture.

2. dflash_utils.py: Add SGLANG_DISABLE_AITER_SAMPLING env var to toggle
   AITER sampling kernels off for A/B benchmarking.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Profile of DFlash on Qwen3-8B (MI355X, conc=64) showed Q/K RMSNorm runs
as two separate kernels on ROCm (~12.9ms total per batch in the draft
model alone), because apply_qk_norm only had a fast path for CUDA.

Add an AITER fused fast path that runs Q-norm + K-norm in a single HIP
kernel:
- Imports aiter.ops.fused_qk_norm_rope_cache_quant.fused_qk_rmsnorm
  when running on ROCm with AITER installed (gracefully falls back if
  the module is missing)
- Gated on _is_hip + matching epsilons + non-deterministic mode
- Uses reshape (not view) because qkv.split() returns strided views and
  ROCm RMSNorm kernels fault on strided inputs (see sgl-project#23159)
- Returns new tensors (AITER kernel is functional, not in-place)

Benefits both the DFlash draft model (5 layers x N requests per decode
step) and the target Qwen3 model on ROCm. CUDA path is untouched.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
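
A sketch of the fast path this commit adds; the module path comes from the commit message, but the kernel's signature is not shown in this PR, so the call below is an assumption:

```python
try:
    # Module path taken from the commit message above.
    from aiter.ops.fused_qk_norm_rope_cache_quant import fused_qk_rmsnorm
    _HAS_AITER_FUSED_QK_RMSNORM = True
except ImportError:
    _HAS_AITER_FUSED_QK_RMSNORM = False

def _aiter_fused_qk_rmsnorm(q, k, q_weight, k_weight, head_dim, eps):
    # reshape (not view): qkv.split() returns strided views, and ROCm
    # RMSNorm kernels fault on strided inputs (see sgl-project#23159).
    q2 = q.reshape(-1, head_dim)
    k2 = k.reshape(-1, head_dim)
    # Hypothetical signature; the AITER kernel is functional
    # (returns new tensors), not in-place.
    q2, k2 = fused_qk_rmsnorm(q2, k2, q_weight, k_weight, eps)
    return q2.reshape(q.shape), k2.reshape(k.shape)
```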
The AITER fused_qk_rmsnorm kernel assumes Q and K have the same number
of rows; for GQA models (num_heads_q > num_heads_kv) the kernel returns
the wrong output shape and the caller crashes with:

    RuntimeError: shape '[N, kv_size]' is invalid for input of size <q_size>

Add a q.shape[-1] == k.shape[-1] guard so we only take the AITER fast
path for MHA models. GQA models (Qwen3-8B, etc.) fall through to the
existing PyTorch path.

Verified by launching Qwen3-8B baseline + DFlash on MI355X — both load
and run successfully where the previous code crashed during CUDA graph
capture.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
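
Continuing the sketch above, the MHA-only guard from this commit (names illustrative):

```python
def _can_use_fused_qk_rmsnorm(q, k) -> bool:
    # The fused kernel assumes Q and K have the same row count. For GQA
    # (num_heads_q > num_heads_kv) the flattened Q/K widths differ, so
    # fall through to the existing PyTorch path.
    return _HAS_AITER_FUSED_QK_RMSNORM and q.shape[-1] == k.shape[-1]
```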
@zhentaocc zhentaocc force-pushed the dflash-rocm-support branch from f421e1b to 5ab3cd2 Compare May 9, 2026 10:09
…ntion

The DFlash draft model was running 3-4 separate kernels per layer for the pre-attention pipeline on ROCm:
  1. qkv split (view, free)
  2. Q-norm + K-norm RMSNorms (one fused if MHA, two separate if GQA)
  3. RoPE (one rotary kernel)

Mirror the optimization Qwen3 uses on ROCm via
forward_prepare_aiter_fused_mrope: call AITER's fused HIP kernel that
combines split + Q-RMSNorm + K-RMSNorm + RoPE into a single launch.

Differences from Qwen3:
- Use the non-MRoPE sibling kernel
  (fused_qk_norm_rope_cache_pts_quant_shuffle from
   aiter/ops/fused_qk_norm_rope_cache_quant.py) since DFlash uses regular
  Neox-style RoPE.
- Set return_kv=True so the kernel writes rotated K/V into supplied output
  buffers instead of a paged KV cache. This preserves DFlash's ENCODER_ONLY
  local-block semantics: K/V are consumed immediately by RadixAttention
  rather than persisted via slot_mapping.

Gating: only enabled when on ROCm + AITER kernel is importable + RoPE is
not MRoPE + RoPE is Neox-style. CUDA users hit the unchanged eager path.

Critical for the GQA case: unlike apply_qk_norm's _aiter_fused_qk_rmsnorm
fast path (which assumes Q/K have the same row count and is gated off for
GQA models), this fused-rope kernel takes num_heads_q != num_heads_k as
explicit parameters, so it works correctly for Qwen3-8B (32q/8kv heads).

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
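
A small sketch of the gating described above, written as a standalone predicate (flag and parameter names are illustrative):

```python
def use_fused_pre_attention(on_rocm: bool, aiter_kernel_available: bool,
                            rope_is_mrope: bool, rope_is_neox_style: bool) -> bool:
    # Only ROCm + importable AITER kernel + regular Neox-style RoPE
    # (not MRoPE); CUDA users keep the unchanged eager path.
    return (
        on_rocm
        and aiter_kernel_available
        and not rope_is_mrope
        and rope_is_neox_style
    )
```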
@zhentaocc zhentaocc force-pushed the dflash-rocm-support branch from 9c073cc to 9057b68 Compare May 9, 2026 17:29
