
[AMD] Default SGLANG_USE_AITER_AR to false to avoid HIP graph capture invalidation #23581

Open

andyluo7 wants to merge 1 commit into sgl-project:main from andyluo7:fix-aiter-ar-hip-graph

Conversation

@andyluo7
Contributor

Summary

AiterCustomAllreduce (the default custom all-reduce on AMD HIP when SGLANG_USE_AITER_AR=true, which is the current default) launches helper kernels on an internal stream during HIP graph capture, which invalidates the captured graph (hipErrorStreamCaptureInvalidated). The next decode replay then dispatches a broken graph, triggering `HSA_STATUS_ERROR_EXCEPTION: An HSAIL operation resulted in a hardware exception. code: 0x1016` and SIGABRT'ing all TP scheduler subprocesses.

This was first observed with the tencent/Hy3-preview model (PR #23533) on MI300X and MI355X at TP=8 with the standard launcher arguments from the model card. Eager mode (--disable-cuda-graph) works fine, confirming the model code itself is correct on AMD — the crash is specific to CUDA-graph capture/replay.

Filed full diagnosis with reproducer in #23580.
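
For intuition, here is a minimal hedged sketch of the failure class (illustrative only; it uses an in-capture device synchronize rather than AITER's actual helper-kernel path to trigger the invalidation):

```python
# Illustrative sketch only: a disallowed runtime interaction during
# torch.cuda graph capture invalidates the capture. AITER's helper
# kernels on an internal stream hit the same invalidation path
# (hipErrorStreamCaptureInvalidated on ROCm).
import torch

g = torch.cuda.CUDAGraph()
x = torch.ones(1024, device="cuda")

try:
    with torch.cuda.graph(g):
        y = x * 2                 # captured normally
        torch.cuda.synchronize()  # unsafe during capture: invalidates it
except RuntimeError as e:
    print(f"capture invalidated: {e}")
```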

Fix

Change the default of SGLANG_USE_AITER_AR from "true" to "false". With this default, dispatch_custom_allreduce() picks sglang's own CustomAllreduce on HIP, which respects the captured stream and works correctly. AITER's other fast paths (attention, MoE, RMSNorm, fused_qk_norm) remain enabled, so performance is preserved (slightly improved in our benchmarks because we keep AITER attention while using a stream-safe all-reduce).
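
As a sketch of the dispatch this default controls (function and class names come from this PR's description; the real sglang wiring differs in detail):

```python
# Hedged sketch of dispatch_custom_allreduce(); names are taken from the
# PR text, and this is not a copy of the actual sglang implementation.
import os

def _env_true(name: str, default: str) -> bool:
    return os.environ.get(name, default).strip().lower() in ("1", "true", "yes")

def dispatch_custom_allreduce(is_hip: bool) -> str:
    # This patch flips the default below from "true" to "false".
    if is_hip and _env_true("SGLANG_USE_AITER_AR", "false"):
        # AITER path: fast, but its helper kernels run on an internal
        # stream during HIP graph capture, invalidating the capture.
        return "AiterCustomAllreduce"
    # sglang's own all-reduce respects the captured stream.
    return "CustomAllreduce"
```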

Users who know their workload doesn't trigger the issue (e.g., NVIDIA, or AMD without large CUDA-graph capture) can opt back in with SGLANG_USE_AITER_AR=1.

Once AITER's HIP graph capture path is fixed upstream, this default can be flipped back.

Validation

Run on tencent/Hy3-preview with PR #23533 file overlays applied to rocm/sgl-dev:v0.5.10.post1-rocm720-mi30x-20260423 and rocm/sgl-dev:v0.5.10.post1-rocm720-mi35x-20260423, transformers==5.6.1, TP=8, default launcher (no --disable-cuda-graph):

| Hardware    | Workload                 | Before (SGLANG_USE_AITER_AR=true) | After (this patch) |
|-------------|--------------------------|-----------------------------------|--------------------|
| MI300X TP=8 | Hy3 single long, 512 tok | ❌ HSA exception 0x1016, SIGABRT  | ✅ 34.9 tok/s      |
| MI300X TP=8 | Hy3 c=4, 16 reqs         | ❌ crash                          | ✅ 128.8 tok/s     |
| MI300X TP=8 | Hy3 c=8, 32 reqs         | ❌ crash                          | ✅ 250.7 tok/s     |
| MI355X TP=8 | Hy3 single long, 512 tok | ❌ crash                          | ✅ 39.6 tok/s      |
| MI355X TP=8 | Hy3 c=8, 32 reqs         | ❌ crash                          | ✅ 295.7 tok/s     |

For reference, vLLM's piecewise CUDA-graph mode (vllm-project/vllm#40681) tolerates AITER's all-reduce because it captures graphs per layer rather than monolithically; SGLang's default monolithic capture cannot. Hence this default change is the right scope for SGLang.

Refs

#23580 (bug report with full diagnosis and reproducer)
#23533 (Hy3-preview model PR that first triggered the crash)

Future work

The proper upstream fix is to make AiterCustomAllreduce HIP-graph-safe in AITER itself. This PR is an interim default change so that AMD users of the default lmsysorg/sglang and rocm/sgl-dev images stop hitting the crash. Once AITER's HIP graph capture is fixed, the default can be flipped back to true.

[AMD] Default SGLANG_USE_AITER_AR to false to avoid HIP graph capture invalidation

AiterCustomAllreduce launches helper kernels on an internal stream during
HIP graph capture, which invalidates the captured graph
(hipErrorStreamCaptureInvalidated) and triggers HSA_STATUS_ERROR_EXCEPTION
0x1016 on the first decode replay.

This was first observed with the tencent/Hy3-preview model (PR sgl-project#23533) on
MI300X and MI355X with TP=8 and the standard launcher arguments from the
model card. Eager mode (--disable-cuda-graph) works fine, confirming the
model code is correct on AMD; the crash is specific to CUDA-graph replay.

Bisected to AITER's all-reduce: setting SGLANG_USE_AITER_AR=0 (which makes
dispatch_custom_allreduce() pick sglang's own CustomAllreduce instead) is
sufficient to fix the crash. AITER's other fast paths (attention, MoE,
RMSNorm, fused_qk_norm) remain enabled and performance is preserved
(slightly improved in our benchmarks):

| Hardware    | Workload                | Before this patch     | With this patch |
|-------------|-------------------------|-----------------------|-----------------|
| MI300X TP=8 | Hy3 single long 512 tok | crash (HSA exception) | 34.9 tok/s      |
| MI300X TP=8 | Hy3 c=8, 32 reqs        | crash (HSA exception) | 250.7 tok/s     |
| MI355X TP=8 | Hy3 single long 512 tok | crash (HSA exception) | 39.6 tok/s      |
| MI355X TP=8 | Hy3 c=8, 32 reqs        | crash (HSA exception) | 295.7 tok/s     |

Until AITER's all-reduce is fixed for HIP graph capture, default to
sglang's CustomAllreduce on HIP.  Users who know their workload is
unaffected can opt back into AITER's implementation with
SGLANG_USE_AITER_AR=1.

Refs: sgl-project#23580 (bug report), sgl-project#23533 (Hy3-preview model PR that triggered it)

@sunway513

Update from the AITER side: the root cause has been pinned down, and it is not an AITER bug. Per @TennyWang1223 (ROCm/aiter#2941) and @Jacob0226 (ROCm/aiter#2857), the crash is a known ROCm 7.2.0 hipEventQuery runtime bug: a cross-thread call ignores THREAD_LOCAL capture mode, so the NCCL watchdog thread invalidates the in-flight HIP graph capture. AITER's IPCBufferPool change in v0.1.12.post1 only widened the race window.

The runtime fix is in ROCm ≥ 7.2.1.
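
A minimal hedged repro sketch of that race (assumption: behavior as described in ROCm/aiter#2941; on ROCm ≥ 7.2.1 the capture completes cleanly):

```python
# Hedged repro sketch of the ROCm 7.2.0 hipEventQuery bug described above.
# thread_local capture mode should make event queries from other threads
# (here, a stand-in for the NCCL watchdog) harmless to this capture; on
# the buggy runtime they invalidate it instead.
import threading
import torch

ev = torch.cuda.Event()
ev.record()
x = torch.ones(1024, device="cuda")
stop = threading.Event()

def watchdog():
    # Stand-in for the NCCL watchdog thread polling completion events.
    while not stop.is_set():
        ev.query()  # cross-thread hipEventQuery; invalidates the capture on 7.2.0

t = threading.Thread(target=watchdog)
t.start()
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, capture_error_mode="thread_local"):
    y = x * 2
stop.set()
t.join()
```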

Recommended fix for SGLang base image (zero C++ rebuild, no Python change):

- BASE_IMAGE_*_ROCM720: rocm/pytorch:rocm7.2.0_*
+ BASE_IMAGE_*_ROCM720: rocm/pytorch:rocm7.2.2_ubuntu22.04_py3.10_pytorch_release_2.9.1

The default SGLANG_USE_AITER_AR=false change in #23581 is still a good defensive default for users stuck on ROCm 7.2.0, but with the base image bump, AITER custom allreduce is safe to re-enable on ROCm 7.2.1+.
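
If sglang later wants to make the re-enable automatic, one option is a runtime-version gate; a hedged sketch (the helper is hypothetical, not existing sglang code, and assumes the conventional /opt/rocm/.info/version file):

```python
# Hypothetical helper (not existing sglang code): pick the
# SGLANG_USE_AITER_AR default from the installed ROCm version, since the
# hipEventQuery capture bug is fixed in ROCm >= 7.2.1. Assumes the
# conventional /opt/rocm/.info/version file, e.g. "7.2.1-<build>".
from pathlib import Path

def aiter_ar_default(version_file: str = "/opt/rocm/.info/version") -> str:
    try:
        raw = Path(version_file).read_text().strip()  # e.g. "7.2.2-112"
        version = tuple(int(p) for p in raw.split("-")[0].split("."))
    except (OSError, ValueError):
        return "false"  # unknown runtime: keep the defensive default
    return "true" if version >= (7, 2, 1) else "false"
```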

Refs: ROCm/aiter#2941, ROCm/aiter#2857, pytorch/pytorch#176251 (reverted after 7.2.1 shipped).

@andyluo7
Contributor Author

andyluo7 commented May 4, 2026

Thanks @sunway513 for the definitive root-cause investigation 🙏

Filed the base-image bump as #24151, updating `rocm/pytorch:rocm7.2_ubuntu22.04_py3.10_pytorch_release_2.9.1` to `rocm7.2.2_ubuntu22.04_py3.10_pytorch_release_2.9.1` for both `BASE_IMAGE_942_ROCM720` and `BASE_IMAGE_950_ROCM720`, with refs to ROCm/aiter#2941, ROCm/aiter#2857, and pytorch/pytorch#176251 in the description.

Keeping the defensive SGLANG_USE_AITER_AR=false default in this PR for users still on ROCm 7.2.0 environments. Once #24151 lands, AITER custom all-reduce is safe to re-enable on the bumped base image.
