[AMD] Enable Piecewise CUDA Graph for AMD GPUs #22299
Draft
hubertlu-tw wants to merge 4 commits into sgl-project:main from
Conversation
Fix multiple issues preventing PCG from working on AMD GPUs with the aiter backend, and fix a latent Dynamo recompilation crash on all platforms.

Universal fix:
- Replace the hard assert in `CUDAPiecewiseBackend` with a graceful fallback when Dynamo recompiles during serving (triggered when batch size exceeds `piecewise_cuda_graph_max_tokens`, which on MLA models defaults to 2048 vs `chunked_prefill_size`)

AMD-specific fixes (all guarded by `_is_hip`):
- Wrap aiter MXFP4 GEMM/quant ops as custom ops for torch.compile
- Handle DeepSeek MLA dual-attention layer lookup in PCG
- Prevent rocminfo subprocess hang during CUDA graph capture
- Reduce PCG init time 3.6x by single-pass `warmup_compile`
- Wrap eager fallback with PCG context for oversized batches
- Zero attention output and MoE weights for padded tokens
- Use `raw_num_tokens` for EAGLE verify batch zeroing

Tested: DeepSeek-R1-MXFP4 + EAGLE + PCG achieves 0.954 GSM8K accuracy and +63% throughput (2677 vs 1642 tok/s) on 4xMI300X/MI355X.

Made-with: Cursor
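The universal fix in the first bullet can be sketched as follows — a minimal, torch-free illustration of the assert-to-fallback change. `PiecewiseBackend`, `warning_once`, and the callables are illustrative names, not the exact sglang symbols:

```python
import logging

logger = logging.getLogger("pcg")
_warned = False

def warning_once(msg: str) -> None:
    # Emit the warning only the first time (stand-in for logger.warning_once).
    global _warned
    if not _warned:
        logger.warning(msg)
        _warned = True

class PiecewiseBackend:
    """Illustrative stand-in for CUDAPiecewiseBackend."""

    def __init__(self, capture_stream=None):
        # An instance created by a mid-serving Dynamo recompile has no
        # capture stream attached.
        self.capture_stream = capture_stream

    def run(self, run_graph, run_eager, *args):
        if self.capture_stream is None:
            # Old behavior: a hard assert here crashed the server.
            # New behavior: warn once and fall back to the eager path.
            warning_once("no PCG capture stream; falling back to eager forward")
            return run_eager(*args)
        return run_graph(*args)
```

The point of the pattern is that a recompiled backend degrades to eager execution instead of killing the serving process.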
kkHuang-amd
reviewed
Apr 8, 2026
@@ -2755,6 +2678,30 @@ def forward_extend(
    if not skip_attn_backend_init:
        self.attn_backend.init_forward_metadata(forward_batch)
Collaborator
It looks like a batch exceeding the max captured size is a common issue — does this protection only need to be done for HIP?
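For context on the protection under discussion — falling back to an eager forward when the batch exceeds the captured token range — a minimal sketch is below. The function name and signature are illustrative, not the actual sglang code; per the PR description this fallback is universal, while the other fixes are `_is_hip`-guarded:

```python
def forward_with_pcg(num_tokens, max_captured_tokens, run_pcg, run_eager):
    # Universal protection: any platform can receive a batch larger than
    # the captured range once piecewise_cuda_graph_max_tokens is set below
    # the chunked prefill size, so the fallback is not HIP-specific.
    if num_tokens > max_captured_tokens:
        return run_eager()
    return run_pcg()
```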
Motivation
Piecewise CUDA Graph (PCG) has been enabled by default on NVIDIA GPUs since PR #16331, but is auto-disabled on AMD/ROCm (`is_hip()` is one of the auto-disable conditions). This PR fixes multiple issues that prevent PCG from working correctly on AMD GPUs with the aiter attention backend, and also fixes a latent Dynamo recompilation crash that affects all platforms.

Tested configurations
Modifications
Universal fix (all platforms)
`cuda_piecewise_backend.py` — Graceful fallback on Dynamo recompilation

Replace a hard `assert` with a `logger.warning_once` + eager fallback when `get_pcg_capture_stream()` returns `None` during serving.

Root cause: On MLA models, `piecewise_cuda_graph_max_tokens` defaults to 2048 (vs `chunked_prefill_size` for non-MLA). Large prefill batches exceed the captured range, causing Dynamo to recompile and create new `CUDAPiecewiseBackend` instances without a capture stream. On non-MLA models `max_tokens = chunked_prefill_size`, so no batch ever exceeds the range — which is why CUDA has not hit this in practice. This is a latent bug on all platforms: anyone setting `--piecewise-cuda-graph-max-tokens` below `--chunked-prefill-size` would trigger the same crash.

AMD-specific fixes (all guarded by `_is_hip`)

`quark_w4a4_mxfp4.py` — Custom op wrapping for aiter GEMM/quant ops

Wrap aiter MXFP4 GEMM/quant functions (`gemm_afp4wfp4`, `gemm_afp4wfp4_pre_quant`, `dynamic_mxfp4_quant`, `fused_gemm_afp4wfp4_split_cat`) as `torch.ops.sglang.*` custom ops via `direct_register_custom_op`. This prevents `torch.compile(fullgraph=True)` from tracing into aiter Python code that uses `hasattr` patterns Dynamo cannot handle.

`radix_attention.py` + `model_runner.py` — MLA dual-attention layer fix

DeepSeek MLA has two `RadixAttention` instances per layer: `attn_mqa` (writes KV cache) and `attn_mha` (no KV write). PCG's `attention_layers` list only stores `attn_mqa`. Attach `attn_mha` as `_pcg_mha_companion` on `attn_mqa` during layer population, and redirect in `unified_attention_with_output` when the MHA path is active.

`piecewise_cuda_graph_runner.py` — Prevent `rocminfo` subprocess hang

Pre-populate aiter chip info env vars (`CU_NUM`, `GPU_ARCHS`) before CUDA graph capture. Without this, aiter lazily calls `subprocess.run(rocminfo)` during capture while the GPU context is locked, causing a deadlock.

`piecewise_cuda_graph_runner.py` — 3.6x faster PCG initialization

Replace the N-iteration `warmup_compile` loop with a single warmup using the largest capture size. The capture phase still does per-shape JIT warmup. Reduces DeepSeek-R1-MXFP4 PCG init from ~6 min to ~1 min 40 sec.

`model_runner.py` — Eager fallback wrapping for oversized batches

When a batch exceeds `piecewise_cuda_graph_max_tokens`, wrap the eager forward with `enable_piecewise_cuda_graph()` + `set_forward_context()` so PCG-specific code paths (MoE weight masking, attention zeroing) have access to their layer objects.

`radix_attention.py` + `topk.py` — Fix accuracy at high parallelism

During PCG replay, padded positions (from static CUDA graph sizes) contain uninitialized garbage. Two fixes:
- Zero the attention output for padded tokens
- Zero MoE weights for padded tokens (set `topk_ids = -1` like NVIDIA kernels do)

`piecewise_context_manager.py` + `piecewise_cuda_graph_runner.py` + `radix_attention.py` — Fix EAGLE+PCG accuracy

Use `raw_num_tokens` (the pre-padding count from the PCG runner) instead of `extend_num_tokens` for attention output zeroing. `extend_num_tokens` is `None` for EAGLE `TARGET_VERIFY` batches, causing the zeroing to be skipped and garbage to propagate.

Accuracy Tests
GSM8K benchmark (5-shot, 1319 questions, `--parallel 1319`):

Additional models (registered CI tests):
Benchmarking
Baseline (DeepSeek-R1-MXFP4 + EAGLE, TP=4, EP=4, no PCG):
With PCG + EAGLE (`--piecewise-cuda-graph-max-tokens 8192`, recommended):

Output throughput 4222 tok/s, Total 14070 tok/s, TPOT 78.74ms
Without EAGLE (PCG only):
Note: Without EAGLE, the default MLA cap of 2048 means PCG only covers small decode batches. Raising `--piecewise-cuda-graph-max-tokens` to 8192 is recommended for MLA models to also capture prefill/verify batches.
Launch Commands
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci