[AMD] Enable DFLASH speculative decoding on ROCm #22342
hnyls2002 merged 3 commits into sgl-project:main
Conversation
Add AMD ROCm support for DFLASH speculative decoding by enabling the Triton attention backend as a draft worker backend.

Changes:
- dflash_worker.py: Add 'triton' to supported_draft_backends, auto-detect ROCm and default to triton (FlashInfer is unavailable on HIP)
- dflash.py: Fix the dummy_q shape in apply_k_rope to match the K tensor shape, satisfying the RoPE kernel's num_heads % num_kv_heads == 0 check
- triton_backend.py: Guard the custom_mask access for DFLASH's non-causal (ENCODER_ONLY) attention, which doesn't use custom masks (see the sketch below)
- qwen3.py: Add set_dflash_layers_to_capture for DFLASH aux hidden capture

Tested on AMD Instinct MI300X (gfx942), ROCm 7.0.2:
- Target: Qwen/Qwen3-8B, Draft: z-lab/Qwen3-8B-DFlash-b16
- attention_backend=triton, TP=1
- CUDA graphs: captured successfully for both target and draft
- Inference: correct output with speculative verification

Signed-off-by: Andy Luo <andyluo7@users.noreply.github.com>
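A minimal sketch of the custom_mask guard mentioned above, assuming the capture path reads spec_info.custom_mask; the class and function names here are illustrative stand-ins, not the actual triton_backend.py code:

```python
# Illustrative sketch only, not the actual triton_backend.py change.
# DFLASH's non-causal ENCODER_ONLY draft attention carries no custom mask,
# so custom_mask may be None during CUDA graph capture; guard the access
# instead of unconditionally reading custom_mask.shape[0].
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SpecInfo:  # hypothetical stand-in for the real spec_info object
    custom_mask: Optional[torch.Tensor] = None


def mask_buffer_len(spec_info: SpecInfo, default_len: int = 0) -> int:
    if spec_info.custom_mask is None:
        # Non-causal (ENCODER_ONLY) path: no mask to size the buffer from.
        return default_len
    return spec_info.custom_mask.shape[0]
```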
Code Review
This pull request enhances DFLASH and Triton backend support by adding safety checks for custom masks, ensuring RoPE kernel compatibility through shape matching, and implementing auxiliary hidden state capture for Qwen3. It also introduces ROCm-aware fallback logic for draft backends in the DFLASH worker. Review feedback recommends consolidating the fallback logic and removing redundant local imports to improve code maintainability.
```diff
             # Use triton on ROCm (no FlashInfer), flashinfer on CUDA
             import torch as _torch
             draft_backend = "triton" if _torch.version.hip else "flashinfer"
         elif draft_backend == "trtllm_mha":
             import torch as _torch
             _fb = "triton" if _torch.version.hip else "flashinfer"
             logger.warning(
                 "DFLASH draft worker does not support 'trtllm_mha' because the "
                 "draft path requires non-causal attention. Falling back to "
-                "'flashinfer'."
+                "'%s'.", _fb
             )
-            draft_backend = "flashinfer"
+            draft_backend = _fb
         elif draft_backend not in supported_draft_backends:
             import torch as _torch
             _fb = "triton" if _torch.version.hip else "flashinfer"
             logger.warning(
                 "DFLASH draft worker only supports attention_backend in %s for now, "
-                "but got %r. Falling back to 'flashinfer'.",
+                "but got %r. Falling back to '%s'.",
                 supported_draft_backends,
                 draft_backend,
+                _fb,
             )
-            draft_backend = "flashinfer"
+            draft_backend = _fb
```
The local imports of torch as _torch are redundant because torch is already imported at the top of the file (line 6). Additionally, the logic to determine the fallback backend ('triton' on ROCm, 'flashinfer' otherwise) is repeated multiple times. Consolidating this logic improves maintainability and readability.
```python
# Use triton on ROCm (no FlashInfer), flashinfer on CUDA
fallback_backend = "triton" if torch.version.hip else "flashinfer"
if draft_backend is None:
    draft_backend = fallback_backend
elif draft_backend == "trtllm_mha":
    logger.warning(
        "DFLASH draft worker does not support 'trtllm_mha' because the "
        "draft path requires non-causal attention. Falling back to "
        "'%s'.", fallback_backend
    )
    draft_backend = fallback_backend
elif draft_backend not in supported_draft_backends:
    logger.warning(
        "DFLASH draft worker only supports attention_backend in %s for now, "
        "but got %r. Falling back to '%s'.",
        supported_draft_backends,
        draft_backend,
        fallback_backend,
    )
    draft_backend = fallback_backend
```
@dcw02 Can you review this?

/rerun-test test_dflash.py

✅
# Conflicts:
#	python/sglang/srt/models/qwen3.py
Signed-off-by: Andy Luo <andyluo7@users.noreply.github.com>
Co-authored-by: Andy Luo <andyluo7@users.noreply.github.com>
Summary
Enables DFLASH speculative decoding on AMD ROCm GPUs by adding Triton attention backend support to the draft worker. DFLASH currently only supports FlashInfer/FA3/FA4 backends, which are unavailable on ROCm.
Changes (4 files, +33/-10 lines)
- dflash_worker.py: add triton to supported backends, auto-detect ROCm
- dflash.py: fix the dummy_q shape in apply_k_rope for RoPE kernel compatibility
- triton_backend.py: guard the custom_mask access for non-causal ENCODER_ONLY attention
- qwen3.py: add set_dflash_layers_to_capture for DFLASH hidden capture

Details
- dflash_worker.py: The draft worker hardcodes ("flashinfer", "fa3", "fa4") as supported backends. This adds "triton" and auto-detects ROCm (torch.version.hip) to default to triton instead of flashinfer.
- dflash.py: apply_k_rope creates a 1-head dummy Q tensor, but ROCm's sgl_kernel.rotary_embedding requires num_heads % num_kv_heads == 0. Fix: match the dummy_q shape to the K shape (a rough sketch follows this list).
- triton_backend.py: DFLASH uses non-causal ENCODER_ONLY attention (no custom mask). During CUDA graph capture, the triton backend unconditionally accesses spec_info.custom_mask.shape[0], crashing when the mask is None. This guards the access.
- qwen3.py: DFLASH requires set_dflash_layers_to_capture on the target model, currently only implemented on LlamaForCausalLM. This adds it to Qwen3ForCausalLM (same pattern as set_eagle3_layers_to_capture).
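A minimal sketch of the dummy Q shape fix from the dflash.py item above; the helper name make_dummy_q and the zero-fill are illustrative assumptions, not the actual dflash.py code:

```python
import torch


def make_dummy_q(k: torch.Tensor) -> torch.Tensor:
    # Illustrative only: the fused RoPE kernel rotates Q and K together, and
    # sgl_kernel.rotary_embedding on ROCm checks num_heads % num_kv_heads == 0.
    # A 1-head dummy Q fails that check when K has several KV heads, so the
    # fix is to give the throwaway Q the same shape as K.
    return torch.zeros_like(k)
```

torch.zeros_like(k) is just one way to get a shape-matched placeholder; the real code may reuse or view an existing buffer instead of allocating.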
Test Results

Hardware: AMD Instinct MI300X (gfx942), ROCm 7.0.2
Image: lmsysorg/sglang-rocm:v0.5.10rc0-rocm700-mi30x-20260407 + pip install sglang@HEAD
accept_length_per_req=[0] logged ✅

Notes
- The triton_backend.py fix (custom_mask guard) benefits all speculative modes that use non-causal attention, not just DFLASH.
- The dflash.py RoPE fix is also needed on CUDA when using sgl_kernel; it's not ROCm-specific.
- The qwen3.py change should ideally be added to all model classes that want DFLASH support (Llama already has it); a rough sketch of the capture-hook pattern follows this list.
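A hedged sketch of what that capture hook might look like on another model class; the class name, attribute names, and signature are assumptions modeled on the set_eagle3_layers_to_capture pattern named above, not the actual sglang API:

```python
from typing import List, Optional


class SomeModelForCausalLM:  # hypothetical model class, not from sglang
    def __init__(self) -> None:
        self.capture_aux_hidden_states = False  # assumed attribute name
        self.dflash_layer_ids: List[int] = []

    def set_dflash_layers_to_capture(self, layer_ids: Optional[List[int]] = None) -> None:
        # Record which decoder layers should expose their hidden states so the
        # DFLASH draft model can consume them as auxiliary inputs.
        self.capture_aux_hidden_states = True
        self.dflash_layer_ids = list(layer_ids) if layer_ids is not None else []
```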