[Speculative][ROCm/MI355X] Enable DFlash speculative decoding on ROCm/MI355X #23388
Draft
zhentaocc wants to merge 10 commits into sgl-project:main from
Conversation
Force-pushed from 28ef43a to 5076c16
Contributor
Author
[Benchmark results table: Baseline, conc=16]
Collaborator
@zhentaocc why do we see a perf drop with DFlash for larger concurrency compared to the baseline?
Contributor
Author
The draft model is not optimized; the main overhead is in attention + GEMM. One more finding: enabling DFlash seems to disable CUDA graph, which contributes to being host-bound at large concurrency.
DFlash was previously CUDA-only due to three issues:
1. Draft worker attention backends limited to flashinfer/fa3/fa4 (all CUDA-only)
2. Fused KV materialization gated by is_cuda()
3. Sampling verify ops import gated by is_cuda()

Changes:
- Add "aiter" to supported_draft_backends (triton was already added)
- Enable fused KV materialization on ROCm via is_cuda() or is_hip() (the underlying Triton kernel is platform-agnostic)
- Allow sampling verify ops import on ROCm (torch.cuda.is_available() covers HIP; try/except safely falls back if ops are absent)

Tested on 8x AMD Instinct MI355X (gfx950) with ROCm 7.2:
- Model: Qwen3-8B + z-lab/Qwen3-8B-DFlash-b16 draft
- Accuracy: greedy output identical to baseline (temperature=0)
- Performance: 57.5 → 92.9 tok/s (1.62x speedup)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
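As a rough illustration of the gating described in this commit (a minimal sketch: the `ENABLE_FUSED_KV_MATERIALIZATION` flag name is made up here, and the real checks live in the DFlash draft worker and dflash_utils.py; `is_cuda`/`is_hip` are assumed to be sglang's existing platform helpers):

```python
# Minimal sketch of the platform gating, assuming sglang's is_cuda()/is_hip()
# helpers; the fused-KV flag name is illustrative, not the actual variable.
from sglang.srt.utils import is_cuda, is_hip

# The fused KV materialization uses a platform-agnostic Triton kernel,
# so it can be enabled on both CUDA and ROCm.
ENABLE_FUSED_KV_MATERIALIZATION = is_cuda() or is_hip()

# Sampling verify ops: import opportunistically; if the build does not ship
# them, fall back to greedy-only verification.
try:
    from sgl_kernel import tree_speculative_sampling_target_only  # noqa: F401

    _DFLASH_SAMPLING_VERIFY_AVAILABLE = True
except ImportError:
    _DFLASH_SAMPLING_VERIFY_AVAILABLE = False
```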
The previous change broadened the import guard from is_cuda() to torch.cuda.is_available(), which is True on ROCm/HIP. The sgl_kernel Python wrappers import successfully, but the underlying C++ ops (tree_speculative_sampling_target_only, top_p_renorm_probs) are not registered in the ROCm build (speculative_sampling.cu is not compiled). This caused a runtime AttributeError when non-greedy sampling was used with DFlash on ROCm.

Fix: after importing the wrappers, verify the C++ ops are actually registered via hasattr(torch.ops.sgl_kernel, ...) before setting _DFLASH_SAMPLING_VERIFY_AVAILABLE = True.

Tested: DFlash with temperature=0.5/0.8/1.0 + top_p=0.9 on MI355X now correctly falls back to greedy verification without crashing.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
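The registration check could look roughly like this (a sketch: the helper name is hypothetical, and the op names are the ones listed in the commit message above):

```python
import torch
import sgl_kernel  # noqa: F401  (loads the extension library and its registered ops)


def _sampling_ops_registered() -> bool:
    """Return True only if the C++ ops behind the Python wrappers exist.

    On ROCm the wrappers import fine, but speculative_sampling.cu is not
    compiled, so the ops are missing from the torch.ops.sgl_kernel namespace.
    """
    required = ("tree_speculative_sampling_target_only", "top_p_renorm_probs")
    return all(hasattr(torch.ops.sgl_kernel, name) for name in required)


_DFLASH_SAMPLING_VERIFY_AVAILABLE = _sampling_ops_registered()
```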
The sgl-kernel ROCm build does not compile speculative_sampling.cu, so the C++ ops (tree_speculative_sampling_target_only, top_k_renorm_probs, top_p_renorm_probs) are unavailable. Previously, DFlash fell back to greedy-only verification on ROCm, losing sampling diversity.

This commit adds pure PyTorch fallbacks:
- _top_k_renorm_prob_torch: top-k probability renormalization
- _top_p_renorm_prob_torch: nucleus (top-p) probability renormalization
- _dflash_chain_sampling_verify_torch: chain-structured speculative sampling verification (DFlash is topk=1, a linear chain)

When the C++ ops are available (CUDA), the kernel path is used unchanged. When they are absent (ROCm), the PyTorch fallback enables non-greedy sampling verification without requiring an sgl-kernel recompilation.

Tested on MI355X with DFlash + Qwen3-8B:
- temp=0.5/0.8/1.0 + top_p=0.9 + top_k=20: diverse outputs, no crash
- Greedy accuracy unchanged
- Performance: 312.9 tok/s (CUDA graph enabled)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
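For reference, a pure-PyTorch top-p renormalization along the lines of `_top_p_renorm_prob_torch` could look like the sketch below (illustrative, not the exact code from the diff):

```python
import torch


def top_p_renorm_prob_torch(probs: torch.Tensor, top_p: torch.Tensor) -> torch.Tensor:
    """Nucleus (top-p) renormalization: keep the smallest set of highest-probability
    tokens whose cumulative mass reaches top_p, zero the rest, and renormalize.

    probs: [batch, vocab] probabilities; top_p: [batch] thresholds in (0, 1].
    """
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    # A token stays if the mass strictly before it is still below top_p,
    # so the highest-probability token is always kept.
    keep_sorted = (cumsum - sorted_probs) < top_p.unsqueeze(-1)
    keep = torch.zeros_like(probs).scatter(-1, sorted_idx, keep_sorted.to(probs.dtype))
    renormed = probs * keep
    return renormed / renormed.sum(dim=-1, keepdim=True)


# Example: batch of 2 rows over a 4-token vocab.
probs = torch.tensor([[0.5, 0.3, 0.1, 0.1], [0.25, 0.25, 0.25, 0.25]])
print(top_p_renorm_prob_torch(probs, torch.tensor([0.8, 0.5])))
```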
Port the tree_speculative_sampling_target_only CUDA kernel to ROCm using aiter's hipCUB-based sampling utilities (SamplingTempStorage, DeviceSamplingFromProb, vec_t) as drop-in replacements for FlashInfer's CUB-based equivalents.

New files:
- csrc/speculative/speculative_sampling_rocm.cuh: HIP kernel header using the aiter::sampling namespace instead of flashinfer::sampling
- csrc/speculative/speculative_sampling.hip: ROCm entry point

Build changes:
- setup_rocm.py: Add speculative_sampling.hip to sources, auto-detect the aiter sampling include path
- common_extension_rocm.cc: Register tree_speculative_sampling_target_only

Runtime changes:
- dflash_utils.py: Use the kernel tree_speculative_sampling when available, with a PyTorch fallback for top_k/top_p renorm (those ops are still from FlashInfer and not yet ported)

Tested on MI355X (gfx950, ROCm 7.2):
- tree_speculative_sampling_target_only kernel: works correctly
- Non-greedy sampling (temp=0.5/0.8/1.0 + top_p + top_k): diverse outputs
- Greedy accuracy: unchanged
- Performance: 312 tok/s (with CUDA graph)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When available, use aiter's HIP-native TopKRenormProbKernel instead of the PyTorch fallback for top-k probability renormalization during DFlash non-greedy sampling verification.

This completes the kernel-level solution for DFlash on ROCm:
- tree_speculative_sampling_target_only: HIP kernel (ported)
- top_k_renorm: aiter HIP kernel (reused)
- top_p_renorm: PyTorch fallback (no aiter kernel available)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
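For comparison, the top-k renormalization that the HIP kernel accelerates can be expressed in a few lines of PyTorch (a reference sketch of the computation, not the fallback as committed):

```python
import torch


def top_k_renorm_prob_torch(probs: torch.Tensor, top_k: torch.Tensor) -> torch.Tensor:
    """Keep each row's top_k largest probabilities, zero the rest, and renormalize.

    probs: [batch, vocab]; top_k: [batch] integer counts. Ties at the k-th value
    may keep slightly more than k tokens, which is fine for renormalization.
    """
    vocab = probs.shape[-1]
    sorted_probs, _ = probs.sort(dim=-1, descending=True)
    k = top_k.clamp(min=1, max=vocab).long()
    # Per-row pivot: the value of the k-th largest probability.
    pivot = sorted_probs.gather(-1, (k - 1).unsqueeze(-1))
    keep = (probs >= pivot).to(probs.dtype)
    renormed = probs * keep
    return renormed / renormed.sum(dim=-1, keepdim=True)
```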
…ling kernels for DFlash ROCm

Integrate the new AITER sampling kernels into DFlash speculative decoding on ROCm:
1. top_p_renorm_probs: replaces the PyTorch sort+cumsum+scatter fallback (1.5-2.9x speedup at bs>=4)
2. chain_speculative_sampling: replaces the Python loop with .item() calls (7-112x speedup, eliminates GPU->CPU round-trips)

Both kernels fall back to pure PyTorch if AITER is not installed.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
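The verification that chain_speculative_sampling implements is standard chain (top-k=1) speculative sampling. A self-contained PyTorch reference for a single request is sketched below; this is roughly the computation the .item()-based fallback performs, and the per-step branching is why that path forces GPU->CPU syncs (the function name and exact signature are illustrative):

```python
import torch


def chain_verify_reference(
    draft_probs: torch.Tensor,      # [num_draft, vocab] draft distributions
    draft_token_ids: torch.Tensor,  # [num_draft] tokens proposed by the draft
    target_probs: torch.Tensor,     # [num_draft + 1, vocab] target distributions
) -> torch.Tensor:
    """Accept draft token i with probability min(1, p_target / p_draft); on the
    first rejection, resample from the clamped residual max(p_target - p_draft, 0)
    and stop. If the whole chain is accepted, sample one bonus token from the
    last target distribution. Each `if` below synchronizes the GPU, which is
    what the AITER kernel avoids."""
    device = target_probs.device
    out = []
    for i in range(draft_token_ids.shape[0]):
        tok = draft_token_ids[i]
        p_target, p_draft = target_probs[i, tok], draft_probs[i, tok]
        if torch.rand((), device=device) * p_draft <= p_target:
            out.append(tok)
        else:
            residual = (target_probs[i] - draft_probs[i]).clamp(min=0)
            out.append(torch.multinomial(residual / residual.sum(), 1).squeeze(0))
            return torch.stack(out)
    out.append(torch.multinomial(target_probs[-1], 1).squeeze(0))
    return torch.stack(out)
```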
…ampling toggle

1. aiter_backend.py: Guard spec_info.custom_mask with a None check in 3 CUDA graph capture paths. DFlash verify may skip custom_mask when the attention backend handles masking internally, causing an AttributeError during graph capture.
2. dflash_utils.py: Add the SGLANG_DISABLE_AITER_SAMPLING env var to toggle the AITER sampling kernels off for A/B benchmarking.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
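The A/B toggle amounts to an environment-variable check before dispatching to the AITER kernels; a sketch (the helper name is illustrative):

```python
import os


def use_aiter_sampling(aiter_kernels_available: bool) -> bool:
    """Dispatch to the AITER sampling kernels unless explicitly disabled.

    Setting SGLANG_DISABLE_AITER_SAMPLING=1 forces the pure-PyTorch fallbacks,
    so the two paths can be A/B benchmarked on the same build.
    """
    if os.environ.get("SGLANG_DISABLE_AITER_SAMPLING", "0") == "1":
        return False
    return aiter_kernels_available
```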
Profiling DFlash on Qwen3-8B (MI355X, conc=64) showed that Q/K RMSNorm runs as two separate kernels on ROCm (~12.9ms total per batch in the draft model alone), because apply_qk_norm only had a fast path for CUDA.

Add an AITER fused fast path that runs Q-norm + K-norm in a single HIP kernel:
- Imports aiter.ops.fused_qk_norm_rope_cache_quant.fused_qk_rmsnorm when running on ROCm with AITER installed (gracefully falls back if the module is missing)
- Gated on _is_hip + matching epsilons + non-deterministic mode
- Uses reshape (not view) because qkv.split() returns strided views and the ROCm RMSNorm kernels fault on strided inputs (see sgl-project#23159)
- Returns new tensors (the AITER kernel is functional, not in-place)

Benefits both the DFlash draft model (5 layers x N requests per decode step) and the target Qwen3 model on ROCm. The CUDA path is untouched.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
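The strided-view point is easy to reproduce in isolation. The snippet below (illustrative shapes, not the model's real sizes) shows why the split outputs are non-contiguous and why reshape rather than view is used before handing them to a kernel that expects a dense layout:

```python
import torch

# torch.split on the fused QKV projection returns column slices that share
# qkv's storage, so they are strided, non-contiguous views.
num_tokens, q_size, kv_size, head_dim = 4, 1024, 256, 128
qkv = torch.randn(num_tokens, q_size + 2 * kv_size)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

print(q.is_contiguous())          # False: the row stride is q_size + 2 * kv_size
# q.view(-1, head_dim)            # would raise: incompatible size/stride for a view
q_rows = q.reshape(-1, head_dim)  # reshape copies into a dense buffer instead
print(q_rows.is_contiguous())     # True
```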
The AITER fused_qk_rmsnorm kernel assumes Q and K have the same number
of rows; for GQA models (num_heads_q > num_heads_kv) the kernel returns
the wrong output shape and the caller crashes with:
RuntimeError: shape '[N, kv_size]' is invalid for input of size <q_size>
Add a q.shape[-1] == k.shape[-1] guard so we only take the AITER fast
path for MHA models. GQA models (Qwen3-8B, etc.) fall through to the
existing PyTorch path.
Verified by launching Qwen3-8B baseline + DFlash on MI355X — both load
and run successfully where the previous code crashed during CUDA graph
capture.
Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
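A sketch of the guard described in this commit (the helper name is illustrative; in the diff the check sits inline in apply_qk_norm's fast path):

```python
import torch


def can_use_fused_qk_rmsnorm(q: torch.Tensor, k: torch.Tensor) -> bool:
    """Only take the AITER fused QK RMSNorm path when Q and K have matching shapes.

    Q and K here are the flattened [num_tokens, num_heads * head_dim] tensors;
    equal last dims mean equal row counts once both are viewed as
    [num_tokens * num_heads, head_dim], which is what the fused kernel assumes.
    GQA models (e.g. Qwen3-8B with 32 q / 8 kv heads) fail the check and fall
    through to the existing PyTorch path.
    """
    return q.shape[-1] == k.shape[-1]
```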
Force-pushed from f421e1b to 5ab3cd2
…ntion

The DFlash draft model was running 3-4 separate kernels per layer for the pre-attention pipeline on ROCm:
1. qkv split (view, free)
2. Q-norm + K-norm RMSNorms (one fused if MHA, two separate if GQA)
3. RoPE (one rotary kernel)

Mirror the optimization Qwen3 uses on ROCm via forward_prepare_aiter_fused_mrope: call AITER's fused HIP kernel that combines split + Q-RMSNorm + K-RMSNorm + RoPE into a single launch.

Differences from Qwen3:
- Use the non-MRoPE sibling kernel (fused_qk_norm_rope_cache_pts_quant_shuffle from aiter/ops/fused_qk_norm_rope_cache_quant.py), since DFlash uses regular Neox-style RoPE.
- Set return_kv=True so the kernel writes rotated K/V into supplied output buffers instead of a paged KV cache. This preserves DFlash's ENCODER_ONLY local-block semantics: K/V are consumed immediately by RadixAttention rather than persisted via slot_mapping.

Gating: only enabled on ROCm when the AITER kernel is importable, RoPE is not MRoPE, and RoPE is Neox-style. CUDA users hit the unchanged eager path.

Critical for the GQA case: unlike apply_qk_norm's _aiter_fused_qk_rmsnorm fast path (which assumes Q/K have the same row count and is gated off for GQA models), this fused-rope kernel takes num_heads_q != num_heads_k as explicit parameters, so it works correctly for Qwen3-8B (32q/8kv heads).

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
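The gating condition reduces to a simple conjunction; a sketch (the helper name and argument names are illustrative, the real check lives in the DFlash draft attention module):

```python
def use_fused_qk_norm_rope(
    on_rocm: bool,
    aiter_kernel_importable: bool,
    rope_is_mrope: bool,
    rope_is_neox_style: bool,
) -> bool:
    """Only ROCm with the AITER kernel importable and plain Neox-style
    (non-MRoPE) RoPE takes the fused single-launch path; everything else,
    including all CUDA users, keeps the unchanged eager path."""
    return (
        on_rocm
        and aiter_kernel_importable
        and not rope_is_mrope
        and rope_is_neox_style
    )
```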
Force-pushed from 9c073cc to 9057b68
Summary
DFlash ROCm Enablement (python/sglang)
- Add aiter to the DFlash draft worker supported_draft_backends
- top_p_renorm_probs kernel — replaces the PyTorch sort+cumsum+scatter fallback (1.9-2.6x kernel speedup at bs>=4)
- chain_speculative_sampling kernel — replaces the Python-loop .item() fallback (7-114x kernel speedup)
- SGLANG_DISABLE_AITER_SAMPLING=1 env var to toggle the AITER sampling kernels off for A/B benchmarking
Speculative Sampling Kernel Port (sgl-kernel)
- Port the tree_speculative_sampling_target_only CUDA kernel to ROCm/HIP using aiter hipCUB-based sampling utilities as drop-in replacements for the FlashInfer CUB-based equivalents
- New files: speculative_sampling_rocm.cuh, speculative_sampling.hip
- Register the op in common_extension_rocm.cc, add the aiter include path to setup_rocm.py
New AITER Kernels (upstream PRs to ROCm/aiter)
Three new kernels added to AITER's sampling module for DFlash ROCm support:
- top_p_renorm_probs
- chain_speculative_sampling
- top_k_top_p_renorm_probs
aiter_backend.py Fix
- Guard spec_info.custom_mask with a None check in 3 CUDA graph capture paths to support DFlash verify with the aiter attention backend
Test Results (AMD Instinct MI355X, gfx950, ROCm 7.2)
Model: Qwen3-8B + z-lab/Qwen3-8B-DFlash-b16 (block_size=16, 5 draft layers)
GSM8K Accuracy (lm_eval, 5-shot, 1319 questions)
Accuracy is identical between baseline and DFlash.
Overall Throughput (1K input / 1K output, request-rate=inf)
Analysis
AITER Kernel Micro-benchmarks (MI355X)
- top_p_renorm_probs (V=32K)
- chain_speculative_sampling (V=32K, draft=4)
- top_k_top_p_renorm_probs (V=32K)
Test plan
Generated with Claude Code