
[AMD] Enable DFLASH speculative decoding on ROCm #22342

Merged

hnyls2002 merged 3 commits into sgl-project:main from andyluo7:add-dflash-rocm-triton-support on Apr 17, 2026

Conversation

@andyluo7 (Contributor) commented on Apr 8, 2026

Summary

Enables DFLASH speculative decoding on AMD ROCm GPUs by adding Triton attention backend support to the draft worker. DFLASH currently only supports FlashInfer/FA3/FA4 backends, which are unavailable on ROCm.

Changes (4 files, +33/-10 lines)

  • dflash_worker.py: Add triton to supported backends, auto-detect ROCm
  • dflash.py: Fix dummy_q shape in apply_k_rope for RoPE kernel compatibility
  • triton_backend.py: Guard custom_mask access for non-causal ENCODER_ONLY attention
  • qwen3.py: Add set_dflash_layers_to_capture for DFLASH hidden capture

Details

dflash_worker.py: The draft worker hardcodes ("flashinfer", "fa3", "fa4") as supported backends. This adds "triton" and auto-detects ROCm (torch.version.hip) to default to triton instead of flashinfer.
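
As a rough sketch of that selection logic (identifier names here are illustrative, not the exact dflash_worker.py code):

    import torch

    # Sketch only: the supported-backend tuple gains "triton", and ROCm is
    # detected through torch.version.hip, which is a version string on ROCm
    # builds of PyTorch and None on CUDA builds.
    SUPPORTED_DRAFT_BACKENDS = ("flashinfer", "fa3", "fa4", "triton")

    def default_draft_backend() -> str:
        return "triton" if torch.version.hip else "flashinfer"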

dflash.py: apply_k_rope creates a 1-head dummy Q tensor, but ROCm's sgl_kernel.rotary_embedding requires num_heads % num_kv_heads == 0. Fix: match dummy_q shape to K shape.
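
A minimal sketch of the shape fix, assuming a vLLM-style rotary module that takes (positions, q, k) and returns (q, k); this is illustrative rather than the exact dflash.py code:

    import torch

    def apply_k_rope_sketch(rotary_emb, positions: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        # The rotary kernel checks num_heads % num_kv_heads == 0, so a 1-head
        # dummy Q fails whenever K has more than one KV head. Making dummy_q
        # match K's shape turns that ratio into exactly 1 and the check passes.
        dummy_q = torch.zeros_like(k)
        _, k = rotary_emb(positions, dummy_q, k)
        return k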

triton_backend.py: DFLASH uses non-causal ENCODER_ONLY attention (no custom mask). During CUDA graph capture, the triton backend unconditionally accesses spec_info.custom_mask.shape[0], crashing when mask is None. This guards the access.
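
A hedged sketch of the guard (attribute and helper names are assumptions, not the exact triton_backend.py code):

    def custom_mask_len(spec_info) -> int:
        # DFLASH's non-causal ENCODER_ONLY draft attention carries no custom
        # mask, so spec_info.custom_mask can be None during CUDA graph capture;
        # only read its shape when it actually exists.
        mask = getattr(spec_info, "custom_mask", None) if spec_info is not None else None
        return mask.shape[0] if mask is not None else 0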

qwen3.py: DFLASH requires set_dflash_layers_to_capture on the target model. Currently only implemented on LlamaForCausalLM. This adds it to Qwen3ForCausalLM (same pattern as set_eagle3_layers_to_capture).
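
An illustrative sketch of the hook, mirroring the set_eagle3_layers_to_capture pattern; the attribute names below are assumptions rather than the exact qwen3.py implementation:

    import torch.nn as nn

    class Qwen3ForCausalLM(nn.Module):
        ...

        def set_dflash_layers_to_capture(self, layer_ids: list[int]) -> None:
            # Record which decoder layers' hidden states DFLASH should capture
            # as auxiliary inputs for the draft model (assumed attributes).
            self.capture_aux_hidden_states = True
            self.model.layers_to_capture = layer_ids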

Test Results

Hardware: AMD Instinct MI300X (gfx942), ROCm 7.0.2
Image: lmsysorg/sglang-rocm:v0.5.10rc0-rocm700-mi30x-20260407 + pip install sglang@HEAD

python3 -m sglang.launch_server \
  --model-path Qwen/Qwen3-8B \
  --speculative-algorithm DFLASH \
  --speculative-draft-model-path z-lab/Qwen3-8B-DFlash-b16 \
  --attention-backend triton \
  --tp 1 --mem-fraction-static 0.85
  • Target model: Qwen3-8B (15.3 GB) ✅
  • Draft model: DFlash (1.96 GB), attention_backend=triton ✅
  • CUDA graphs: captured for both target and draft ✅
  • Inference: correct output with speculative verify ✅
  • DFLASH verify: accept_length_per_req=[0] logged ✅

Notes

  • The triton_backend.py fix (custom_mask guard) benefits all speculative modes that use non-causal attention, not just DFLASH.
  • The dflash.py RoPE fix is also needed on CUDA when using sgl_kernel — it's not ROCm-specific.
  • The qwen3.py change should ideally be added to all model classes that want DFLASH support (Llama already has it).

…n backend

Add AMD ROCm support for DFLASH speculative decoding by enabling the
Triton attention backend as a draft worker backend.

Changes:
- dflash_worker.py: Add 'triton' to supported_draft_backends, auto-detect
  ROCm and default to triton (FlashInfer unavailable on HIP)
- dflash.py: Fix dummy_q shape in apply_k_rope to match K tensor shape,
  satisfying the RoPE kernel's num_heads % num_kv_heads == 0 check
- triton_backend.py: Guard custom_mask access for DFLASH's non-causal
  (ENCODER_ONLY) attention which doesn't use custom masks
- qwen3.py: Add set_dflash_layers_to_capture for DFLASH aux hidden capture

Tested on AMD Instinct MI300X (gfx942), ROCm 7.0.2:
- Target: Qwen/Qwen3-8B, Draft: z-lab/Qwen3-8B-DFlash-b16
- attention_backend=triton, TP=1
- CUDA graphs: captured successfully for both target and draft
- Inference: correct output with speculative verification

Signed-off-by: Andy Luo <andyluo7@users.noreply.github.com>
@gemini-code-assist (Bot) left a comment


Code Review

This pull request enhances DFLASH and Triton backend support by adding safety checks for custom masks, ensuring RoPE kernel compatibility through shape matching, and implementing auxiliary hidden state capture for Qwen3. It also introduces ROCm-aware fallback logic for draft backends in the DFLASH worker. Review feedback recommends consolidating the fallback logic and removing redundant local imports to improve code maintainability.

Comment on lines +106 to +128 (diff hunk in dflash_worker.py; "-" lines are the previous hard-coded 'flashinfer' fallback, "+" lines are this PR's replacement):

                # Use triton on ROCm (no FlashInfer), flashinfer on CUDA
                import torch as _torch
                draft_backend = "triton" if _torch.version.hip else "flashinfer"
            elif draft_backend == "trtllm_mha":
                import torch as _torch
                _fb = "triton" if _torch.version.hip else "flashinfer"
                logger.warning(
                    "DFLASH draft worker does not support 'trtllm_mha' because the "
                    "draft path requires non-causal attention. Falling back to "
-                   "'flashinfer'."
+                   "'%s'.", _fb
                )
-               draft_backend = "flashinfer"
+               draft_backend = _fb
            elif draft_backend not in supported_draft_backends:
                import torch as _torch
                _fb = "triton" if _torch.version.hip else "flashinfer"
                logger.warning(
                    "DFLASH draft worker only supports attention_backend in %s for now, "
-                   "but got %r. Falling back to 'flashinfer'.",
+                   "but got %r. Falling back to '%s'.",
                    supported_draft_backends,
                    draft_backend,
                    _fb,
                )
-               draft_backend = "flashinfer"
+               draft_backend = _fb

Severity: medium

The local imports of torch as _torch are redundant because torch is already imported at the top of the file (line 6). Additionally, the logic to determine the fallback backend ('triton' on ROCm, 'flashinfer' otherwise) is repeated multiple times. Consolidating this logic improves maintainability and readability.

            # Use triton on ROCm (no FlashInfer), flashinfer on CUDA
            fallback_backend = "triton" if torch.version.hip else "flashinfer"
            if draft_backend is None:
                draft_backend = fallback_backend
            elif draft_backend == "trtllm_mha":
                logger.warning(
                    "DFLASH draft worker does not support 'trtllm_mha' because the "
                    "draft path requires non-causal attention. Falling back to "
                    "'%s'.", fallback_backend
                )
                draft_backend = fallback_backend
            elif draft_backend not in supported_draft_backends:
                logger.warning(
                    "DFLASH draft worker only supports attention_backend in %s for now, "
                    "but got %r. Falling back to '%s'.",
                    supported_draft_backends,
                    draft_backend,
                    fallback_backend,
                )
                draft_backend = fallback_backend

@hnyls2002 (Collaborator)

@dcw02 Can you review this?

@hnyls2002 (Collaborator)

/rerun-test test_dflash.py

@github-actions

1-gpu-5090 (1 test): View workflow run

cd test/ && python3 registered/spec/dflash/test_dflash.py

@hnyls2002 hnyls2002 merged commit 9df6107 into sgl-project:main Apr 17, 2026
37 of 54 checks passed
jmamou pushed a commit to jmamou/sglang that referenced this pull request Apr 20, 2026
Signed-off-by: Andy Luo <andyluo7@users.noreply.github.com>
Co-authored-by: Andy Luo <andyluo7@users.noreply.github.com>
zhangying098 pushed a commit to zhangying098/sglang that referenced this pull request Apr 23, 2026
Signed-off-by: Andy Luo <andyluo7@users.noreply.github.com>
Co-authored-by: Andy Luo <andyluo7@users.noreply.github.com>
kyx1999 pushed a commit to KMSorSMS/sglang that referenced this pull request Apr 27, 2026
Signed-off-by: Andy Luo <andyluo7@users.noreply.github.com>
Co-authored-by: Andy Luo <andyluo7@users.noreply.github.com>
