
[AMD] Enable Piecewise CUDA Graph for AMD GPUs #22299

Draft
hubertlu-tw wants to merge 4 commits into sgl-project:main from hubertlu-tw:pcg_enablement

Conversation

@hubertlu-tw (Collaborator)

Motivation

Piecewise CUDA Graph (PCG) has been enabled by default on NVIDIA GPUs since PR #16331, but it is auto-disabled on AMD/ROCm (is_hip() is one of the auto-disable conditions). This PR fixes multiple issues that prevented PCG from working correctly on AMD GPUs with the aiter attention backend, and also fixes a latent Dynamo recompilation crash that affects all platforms.

Tested configurations

  • DeepSeek-R1-MXFP4-Preview — TP=4, EP=4, EAGLE speculative decoding, FP8 KV cache, aiter backend
  • Qwen3.5-397B-A17B-MXFP4 — TP=4, aiter backend
  • 1-GPU models — Qwen2.5-VL-7B, InternVL2.5-8B, Qwen2.5-VL-3B

Modifications

Universal fix (all platforms)

  1. cuda_piecewise_backend.py — Graceful fallback on Dynamo recompilation
    Replace a hard assert with logger.warning_once + an eager fallback when get_pcg_capture_stream() returns None during serving (a minimal sketch follows).
    Root cause: on MLA models, piecewise_cuda_graph_max_tokens defaults to 2048 (vs chunked_prefill_size for non-MLA). Large prefill batches exceed the captured range, causing Dynamo to recompile and create new CUDAPiecewiseBackend instances without a capture stream. On non-MLA models max_tokens = chunked_prefill_size, so no batch ever exceeds the range, which is why the NVIDIA path has not hit this in practice. The bug is latent on all platforms: anyone setting --piecewise-cuda-graph-max-tokens below --chunked-prefill-size would trigger the same crash.
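
A minimal sketch of the fallback, assuming the vLLM-style CUDAPiecewiseBackend structure that cuda_piecewise_backend.py follows; this is a fragment from inside the backend's __call__, and the names around the guard (self.compiled_graph_for_general_shape in particular) are illustrative:

# Inside CUDAPiecewiseBackend.__call__ (illustrative surroundings)
stream = get_pcg_capture_stream()
if stream is None:
    # A Dynamo recompile during serving created this backend instance
    # after the capture phase ended; run eagerly instead of asserting.
    logger.warning_once(
        "No PCG capture stream (Dynamo recompiled after capture); "
        "falling back to eager execution for this piecewise graph."
    )
    return self.compiled_graph_for_general_shape(*args)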

AMD-specific fixes (all guarded by _is_hip)

  1. quark_w4a4_mxfp4.py — Custom op wrapping for aiter GEMM/quant ops
    Wrap aiter MXFP4 GEMM/quant functions (gemm_afp4wfp4, gemm_afp4wfp4_pre_quant, dynamic_mxfp4_quant, fused_gemm_afp4wfp4_split_cat) as torch.ops.sglang.* custom ops via direct_register_custom_op. This prevents torch.compile(fullgraph=True) from tracing into aiter Python code that uses hasattr patterns Dynamo cannot handle (sketched after this list).

  2. radix_attention.py + model_runner.py — MLA dual-attention layer fix
    DeepSeek MLA has two RadixAttention instances per layer: attn_mqa (writes KV cache) and attn_mha (no KV write), but PCG's attention_layers list only stores attn_mqa. Attach attn_mha as _pcg_mha_companion on attn_mqa during layer population, and redirect in unified_attention_with_output when the MHA path is active (sketched after this list).

  3. piecewise_cuda_graph_runner.py — Prevent rocminfo subprocess hang
    Pre-populate aiter chip-info env vars (CU_NUM, GPU_ARCHS) before CUDA graph capture. Without this, aiter lazily calls subprocess.run(rocminfo) during capture while the GPU context is locked, causing a deadlock (sketched after this list).

  4. piecewise_cuda_graph_runner.py — 3.6x faster PCG initialization
    Replace the N-iteration warmup_compile loop with a single warmup at the largest capture size; the capture phase still does per-shape JIT warmup. This reduces DeepSeek-R1-MXFP4 PCG init from ~6 min to ~1 min 40 s (sketched after this list).

  5. model_runner.py — Eager fallback wrapping for oversized batches
    When a batch exceeds piecewise_cuda_graph_max_tokens, wrap the eager forward in enable_piecewise_cuda_graph() + set_forward_context() so PCG-specific code paths (MoE weight masking, attention zeroing) still have access to their layer objects (sketched after this list).

  6. radix_attention.py + topk.py — Fix accuracy at high parallelism
    During PCG replay, padded positions (from static CUDA graph sizes) contain uninitialized garbage. Two fixes (a combined sketch for fixes 6 and 7 follows the list):

    • Zero attention output for padded positions after varlen attention kernels
    • Zero MoE routing weights for padded tokens (aiter kernels don't handle topk_ids = -1 the way NVIDIA kernels do)
  7. piecewise_context_manager.py + piecewise_cuda_graph_runner.py + radix_attention.py — Fix EAGLE+PCG accuracy
    Use raw_num_tokens (the pre-padding count from the PCG runner) instead of extend_num_tokens for attention output zeroing. extend_num_tokens is None for EAGLE TARGET_VERIFY batches, which caused the zeroing to be skipped and garbage to propagate (covered by the combined sketch for fixes 6 and 7 below).
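
A sketch of fix 1, assuming direct_register_custom_op from sglang.srt.utils (vLLM-style signature) and an illustrative parameter list and import path for aiter's gemm_afp4wfp4; the real kernel signature may differ:

import torch
from sglang.srt.utils import direct_register_custom_op

def _gemm_afp4wfp4(x: torch.Tensor, w: torch.Tensor,
                   x_scale: torch.Tensor, w_scale: torch.Tensor,
                   out: torch.Tensor) -> None:
    # Importing inside the op keeps Dynamo from tracing aiter's Python
    # code, which uses hasattr patterns it cannot handle.
    from aiter import gemm_afp4wfp4
    gemm_afp4wfp4(x, w, x_scale, w_scale, out)

def _gemm_afp4wfp4_fake(x, w, x_scale, w_scale, out) -> None:
    pass  # out is written in place; nothing to compute at trace time

direct_register_custom_op(
    op_name="gemm_afp4wfp4",
    op_func=_gemm_afp4wfp4,
    mutates_args=["out"],
    fake_impl=_gemm_afp4wfp4_fake,
)

# Call sites then use torch.ops.sglang.gemm_afp4wfp4(...), which
# torch.compile(fullgraph=True) treats as a single opaque node.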
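
A sketch of fix 2; _pcg_mha_companion is the attribute introduced by this PR, while the helper names and the path flag are illustrative:

def attach_mha_companion(attn_mqa, attn_mha) -> None:
    # model_runner.py, during layer population: PCG's attention_layers
    # list only stores attn_mqa, so hang the MHA twin off it.
    attn_mqa._pcg_mha_companion = attn_mha

def resolve_attention_layer(layer, mha_path_active: bool):
    # radix_attention.py, unified_attention_with_output: redirect to
    # the companion when the MHA (no-KV-write) path is active.
    companion = getattr(layer, "_pcg_mha_companion", None)
    if mha_path_active and companion is not None:
        return companion
    return layer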
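
A sketch of fix 3. CU_NUM and GPU_ARCHS are the variables named above; deriving them from torch.cuda.get_device_properties is one plausible source (gcnArchName is the ROCm-specific field), not necessarily what the PR does:

import os
import torch

def prime_aiter_chip_info() -> None:
    # Fill in what aiter would otherwise discover by lazily running
    # subprocess.run(rocminfo), which deadlocks if it happens while the
    # GPU context is locked for graph capture.
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    os.environ.setdefault("CU_NUM", str(props.multi_processor_count))
    os.environ.setdefault("GPU_ARCHS", props.gcnArchName.split(":")[0])

prime_aiter_chip_info()  # must run before CUDA graph capture starts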
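
Fix 4 in sketch form; capture_sizes and warmup_compile stand in for the runner's capture-size list and compile-warmup helper:

def warmup_once(capture_sizes: list, warmup_compile) -> None:
    # Before: warmup_compile(n) for every n in capture_sizes (N passes).
    # After: one pass at the largest size is enough, because the capture
    # phase still performs per-shape JIT warmup for each recorded size.
    warmup_compile(max(capture_sizes))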
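
A sketch of fix 5, using the context managers named above (imported from the PCG context-manager module; exact signatures assumed):

def forward_oversized(model, forward_batch, num_tokens: int,
                      piecewise_cuda_graph_max_tokens: int):
    # Batches beyond the captured range run eagerly, but inside the PCG
    # contexts so MoE weight masking and attention zeroing can still
    # find their layer objects.
    if num_tokens > piecewise_cuda_graph_max_tokens:
        with enable_piecewise_cuda_graph(), set_forward_context(forward_batch):
            return model.forward(forward_batch)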
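
Fixes 6 and 7 reduce to zeroing everything past the real token count during replay, keyed on the pre-padding raw_num_tokens rather than extend_num_tokens; a combined sketch with illustrative tensor names:

import torch

def zero_padded_outputs(attn_output: torch.Tensor,
                        topk_weights: torch.Tensor,
                        raw_num_tokens: int) -> None:
    # raw_num_tokens is recorded by the PCG runner before padding, so it
    # is valid even when extend_num_tokens is None (EAGLE TARGET_VERIFY).
    # Rows past the real tokens are static-graph padding with garbage.
    attn_output[raw_num_tokens:].zero_()
    # aiter MoE kernels do not skip topk_ids == -1 the way the NVIDIA
    # kernels do, so padded routing weights must be zeroed as well.
    topk_weights[raw_num_tokens:].zero_()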

Accuracy Tests

GSM8K benchmark (5-shot, 1319 questions, --parallel 1319):

Model                     | Config                         | Accuracy | Threshold
DeepSeek-R1-MXFP4         | TP=4, EP=4, PCG, fp8 KV        | 0.941    | >= 0.94
DeepSeek-R1-MXFP4 + EAGLE | TP=4, EP=4, PCG, EAGLE, fp8 KV | 0.954    | >= 0.94

Additional models (registered CI tests):

Model                     | Config    | Metric     | Result
Qwen3.5-397B-A17B-MXFP4   | TP=4, PCG | Serving OK | Pass
Qwen2.5-VL-7B-Instruct    | TP=1, PCG | MGSM       | 0.872 (>= 0.70)
InternVL2.5-8B            | TP=1, PCG | MGSM       | 0.704 (>= 0.70)
Qwen2.5-VL-3B (embedding) | TP=1, PCG | diff       | 0.0 (atol=1e-2)

Benchmarking

Baseline (DeepSeek-R1-MXFP4 + EAGLE, TP=4, EP=4, no PCG):

  • Accuracy: 0.950, Output throughput: 1642 tok/s

With PCG + EAGLE (--piecewise-cuda-graph-max-tokens 8192, recommended):

  • Accuracy: 0.954, Output throughput: 2657-2677 tok/s (+63% throughput)
  • bench_serving (random 3500in/1500out, 768 prompts, 448 concurrency):
    Output throughput 4222 tok/s, Total 14070 tok/s, TPOT 78.74ms

Without EAGLE (PCG only):

  • No PCG: Accuracy 0.947, Throughput 2019 tok/s
  • With PCG (max-tokens 2048): Accuracy 0.945, Throughput 1898 tok/s

Note: Without EAGLE, the default MLA cap of 2048 means PCG only covers small
decode batches. Raising --piecewise-cuda-graph-max-tokens to 8192 is
recommended for MLA models to also capture prefill/verify batches.

Launch Commands

# DeepSeek-R1-MXFP4 + EAGLE + PCG (recommended)
SGLANG_USE_AITER=1 SGLANG_MOE_PADDING=1 \
python3 -m sglang.launch_server \
  --model-path /path/to/DeepSeek-R1-MXFP4-Preview \
  --tp-size 4 --ep-size 4 \
  --speculative-algorithm EAGLE \
  --speculative-num-steps 3 --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 4 --speculative-attention-mode decode \
  --speculative-draft-model-path /path/to/DeepSeek-R1-NextN \
  --trust-remote-code --mem-fraction-static 0.95 \
  --attention-backend aiter --chunked-prefill-size 131072 \
  --disable-radix-cache --kv-cache-dtype fp8_e4m3 \
  --max-running-requests 448 \
  --enforce-piecewise-cuda-graph --piecewise-cuda-graph-compiler eager \
  --piecewise-cuda-graph-max-tokens 8192

# GSM8K accuracy test
python3 benchmark/gsm8k/bench_sglang.py \
  --num-questions 1319 --parallel 1319 --num-shots 5 --port 9001

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

Fix multiple issues preventing PCG from working on AMD GPUs with the
aiter backend, and fix a latent Dynamo recompilation crash on all
platforms.

Universal fix:
- Replace hard assert in CUDAPiecewiseBackend with graceful fallback
  when Dynamo recompiles during serving (triggered when batch size
  exceeds piecewise_cuda_graph_max_tokens, which on MLA models defaults
  to 2048 vs chunked_prefill_size)

AMD-specific fixes (all guarded by _is_hip):
- Wrap aiter MXFP4 GEMM/quant ops as custom ops for torch.compile
- Handle DeepSeek MLA dual-attention layer lookup in PCG
- Prevent rocminfo subprocess hang during CUDA graph capture
- Reduce PCG init time 3.6x by single-pass warmup_compile
- Wrap eager fallback with PCG context for oversized batches
- Zero attention output and MoE weights for padded tokens
- Use raw_num_tokens for EAGLE verify batch zeroing

Tested: DeepSeek-R1-MXFP4 + EAGLE + PCG achieves 0.954 GSM8K accuracy
and +63% throughput (2677 vs 1642 tok/s) on 4xMI300X/MI355X.

Made-with: Cursor

Comment thread on python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py (outdated)
@@ -2755,6 +2678,30 @@ def forward_extend(
if not skip_attn_backend_init:
self.attn_backend.init_forward_metadata(forward_batch)

Collaborator: Looks like a batch exceeding the max captured size is a common issue; does this protection only need to be applied for HIP?
