[AMD] Simplify fused allreduce + RMSNorm and remove hidden_dim allowlist #21986
HaiShaw merged 4 commits into sgl-project:main
Conversation
The activation gate in `apply_aiter_all_reduce_fusion` used strict less-than (`<`) for the byte-size threshold, while AITER's internal `should_custom_ar` uses less-than-or-equal (`<=`). For the common case of hidden_size=4096 with bf16 at 8192 tokens, the total bytes exactly equal the threshold (67,108,864), so `<` rejected it and the fused kernel never activated. Change `<` to `<=` so SGLang's gate matches AITER's boundary, enabling the fused allreduce+RMSNorm kernel for this shape.

Depends on: ROCm/aiter#2586
Made-with: Cursor
…used AR+RMSNorm
- parallel_state.py: Remove hardcoded hidden_dim allowlist {512,1024,2048,4096}
for 1-stage kernel selection; keep 128KB byte threshold. AITER's C++ dispatch
already gates which dims are supported (ROCm/aiter#2453).
- benchmark_fused_ar_rms_amd.py: Add hidden_dim=2880 (GPT-OSS) to default
decode and prefill shapes.
- test_aiter_allreduce_fusion_amd.py: Add multi-hidden-dim correctness test
covering 2880/4096/5120/6144/7168/8192, and bit-exact residual accuracy
regression test for ROCm/aiter#2586.
Made-with: Cursor
080bf23 to 306abe2
/tag-and-rerun-ci

@amd-bot ci-status
CI Status for PR #21986
PR: [AMD][No-Merge] Simplify fused allreduce + RMSNorm and remove hidden_dim allowlist

Details
None of these failures are related to this PR's changes. The PR modifies 4 files, all scoped to AMD ROCm fused allreduce+RMSNorm logic and its tests.
The actual failures fall into three categories:
Verdict: No action needed from this PR. All failures are pre-existing infrastructure issues or cancelled runs.
Generated by amd-bot using Claude Code CLI
…dim allowlist (sgl-project#21986) Co-authored-by: HAI <hixiao@gmail.com>
Two changes to the `--enable-aiter-allreduce-fusion` path:

- `communicator.py` (commit 449f7c293): Fix an off-by-one (`<` → `<=`) in `apply_aiter_all_reduce_fusion` so the fused kernel activates at the exact boundary size, matching AITER's internal `should_custom_ar` (`inp_size <= max_size/2`).
- `parallel_state.py` (this PR): Remove the hardcoded `hidden_dim ∈ {512, 1024, 2048, 4096}` allowlist for 1-stage vs 2-stage selection. Keep the 128 KB byte-size threshold (the 1-stage kernel is capped at 80 tokens via `kMaxBlocks`) but let AITER's C++ dispatch decide which hidden_dims are supported.

Together these eliminate the per-model maintenance burden (no more manual hidden_dim additions for new models) and ensure AITER's own heuristics are respected.
Depends on:

- ROCm/aiter#2586 — fixes accuracy in the 1-stage fused kernel (bf16 round-trip before residual addition).
- ROCm/aiter#2453 — allreduce refactor: the 1-stage kernel accepts any `hidden_dim` divisible by `pack_size` (no more hardcoded template set), and the 2-stage kernel uses `n_packs`-based dispatch (no more `n_bytes % 1024` alignment requirement).

Subsumes: #21947 — that PR adds `2880` to the Python-side allowlist. With this change, the allowlist is removed entirely, so all AITER-supported hidden_dims work automatically.

Motivation
Problem 1: Off-by-one in outer gate
The activation gate in `communicator.py` used `<` for the byte-size threshold while AITER uses `<=`. For `hidden_size=4096` with bf16 at 8192 tokens, `total_bytes` exactly equals the threshold (67,108,864), so `<` rejected it.
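The boundary case is easy to verify by hand; a minimal standalone sketch of the arithmetic, using only the values quoted above:

```python
# Boundary case from the PR description: hidden_size=4096, bf16, 8192 tokens.
tokens, hidden_size, elem_bytes = 8192, 4096, 2  # bf16 = 2 bytes per element
total_bytes = tokens * hidden_size * elem_bytes
threshold = 8 * 1024 * 8192  # 64 MB, i.e. AITER's max_size / 2

assert total_bytes == threshold == 67_108_864
print(total_bytes < threshold)   # False -> the old `<` gate rejected this shape
print(total_bytes <= threshold)  # True  -> the fixed `<=` gate admits it
```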
Problem 2: Redundant hidden_dim allowlist

The 1-stage vs 2-stage decision in `parallel_state.py` gated on `hidden_dim ∈ {512, 1024, 2048, 4096}`. This was redundant at three levels:

| Level | File | Check |
| --- | --- | --- |
| SGLang (Python) | `parallel_state.py` | `hidden_dim ∈ {512, 1024, 2048, 4096}` |
| AITER (pre-#2453) | `custom_all_reduce.cuh` | `n == 4096 \|\| n == 2048 \|\| ...` |
| AITER (post-#2453) | `custom_all_reduce.cuh` | `n % pack_size == 0 && n / pack_size <= 1024` |

AITER's C++ dispatch already enforces which hidden_dims have 1-stage support. Passing `use_1stage=True` for an unsupported dim is safe — AITER overrides it to `false` and falls through to 2-stage. The Python-side allowlist only created a maintenance burden (cf. PR #21947); a sketch of the post-#2453 check follows below.
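For illustration, the post-#2453 condition from the table can be checked directly; a minimal sketch, assuming `pack_size = 8` elements for bf16 (16-byte packs — this pack size is an assumption, not stated in the PR):

```python
# Sketch of AITER's post-#2453 1-stage support check (see table above).
PACK_SIZE = 8  # assumed: 16-byte packs for bf16; not stated in the PR

def aiter_1stage_supported(n: int) -> bool:
    # Mirrors the C++ condition: n % pack_size == 0 && n / pack_size <= 1024
    return n % PACK_SIZE == 0 and n // PACK_SIZE <= 1024

# All hidden_dims exercised by the new multi-hidden-dim test pass this check.
for n in (2880, 4096, 5120, 6144, 7168, 8192):
    print(n, aiter_1stage_supported(n))  # -> True for each
```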
The 128 KB threshold is kept (but the allowlist is not)

The 1-stage fused kernel launches one CTA per token and is hard-capped at `kMaxBlocks = 80` tokens. For `hidden_dim=4096` with bf16: `80 × 4096 × 2 = 655,360` bytes. The 128 KB threshold safely ensures only small decode-like batches take the 1-stage path; larger prefill batches automatically fall through to the 2-stage kernel.
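To see what the 128 KB cutoff means in tokens, a small sketch (assuming the threshold applies to `tokens × hidden_dim × 2` bytes for bf16):

```python
# Tokens admitted to the 1-stage path under the 128 KB threshold (bf16).
ONE_STAGE_MAX_BYTES = 128 * 1024

for hidden_dim in (2880, 4096, 5120, 8192):
    max_tokens = ONE_STAGE_MAX_BYTES // (hidden_dim * 2)  # 2 bytes per bf16
    print(f"hidden_dim={hidden_dim}: up to {max_tokens} tokens")
# hidden_dim=4096 -> 16 tokens: a decode-sized batch, far below kMaxBlocks = 80.
```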
Modifications
File 1:
`python/sglang/srt/layers/communicator.py`

```diff
 def apply_aiter_all_reduce_fusion(input_tensor: torch.Tensor):
     n = input_tensor.shape[-1]
     total_bytes = input_tensor.numel() * input_tensor.element_size()
+    # Aiter's should_custom_ar uses <= max_size/2 (64 MB); match that boundary.
     return (
         _use_aiter
         and total_bytes > 0
         and n <= 16384
-        and total_bytes < 8 * 1024 * 8192
+        and total_bytes <= 8 * 1024 * 8192
         and get_tensor_model_parallel_world_size() != 6
         and not is_dp_attention_enabled()
         and get_global_server_args().enable_aiter_allreduce_fusion
     )
```

File 2:
`python/sglang/srt/distributed/parallel_state.py`
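The diff for this file is not inlined in the description; a minimal sketch of the selection logic it describes (the helper name and exact comparison are assumptions — only the removed allowlist and the kept 128 KB threshold come from the PR):

```python
# Sketch: 1-stage vs 2-stage selection after this PR (hypothetical helper).
ONE_STAGE_MAX_BYTES = 128 * 1024  # the kept byte-size threshold

def should_use_1stage(total_bytes: int) -> bool:
    # Removed: `and hidden_dim in {512, 1024, 2048, 4096}` -- the allowlist.
    # AITER's C++ dispatch now decides per-hidden_dim 1-stage support and
    # silently falls back to the 2-stage kernel when it is unsupported.
    return total_bytes <= ONE_STAGE_MAX_BYTES
```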
Accuracy Tests

Setup
- Environment: `SGLANG_USE_AITER=1`, `--attention-backend aiter`
- AITER: `main` branch (a530b92bf), includes both ROCm/aiter#2586 (1-stage accuracy fix) and ROCm/aiter#2453 (allreduce refactor).
- Full rebuild with `GPU_ARCHS=gfx950 MAX_JOBS=192 PREBUILD_KERNELS=1`.

GSM8K (1319 questions, 5-shot, parallel=1319)
Three rounds per configuration, with full PID/GPU cleanup between cases.
Configurations: baseline vs `--enable-aiter-allreduce-fusion` ON. No accuracy regression: fusion ON is +2.3pp above baseline, confirming the AITER accuracy fix (#2586) is correctly integrated and the fused path produces numerically sound results.
Commands
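The exact commands are collapsed in the original; a representative invocation, assuming SGLang's standard GSM8K benchmark script (model path, TP size, and other flags not stated here are placeholders):

```bash
# Launch the server with the fused path enabled (model path is a placeholder).
python3 -m sglang.launch_server --model-path <model> --tp 8 \
    --attention-backend aiter --enable-aiter-allreduce-fusion

# Run GSM8K: 1319 questions, 5-shot, parallel=1319 (matches the setup above).
python3 benchmark/gsm8k/bench_sglang.py \
    --num-questions 1319 --num-shots 5 --parallel 1319
```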
Benchmarking
Serving Benchmark (random, 8 × 8192-in / 1024-out, concurrency=1)
Kernel Microbenchmark (TP=4, bf16)
The fused kernel is faster across all tested shapes and hidden_dims:
Decode (graph-captured)
With AITER #2453 applied, all hidden_dims work — including GPT-OSS's `hidden_dim=2880`:

Prefill (eager)
All hidden_dims benefit — confirming that the allowlist was unnecessarily restrictive.
GPT-OSS (`hidden_dim=2880`) gets up to 2.10× speedup on decode.

Trace Verification
After the threshold fix, the EXTEND (prefill) trace on TP-0 confirms fused
kernel activation:
| Before (unfused) | After (fused) |
| --- | --- |
| `cross_device_reduce_2stage` (unfused AR) | `reduce_scatter_cross_device_store` (fused step 1) |
| `add_rmsnorm_quant_kernel` (256-wide) | `local_device_load_rmsnorm_naive` (fused step 2) |

119 out of 121 allreduce sites now use the fused path. The remaining 2 are smaller-hidden-dim layers (64-wide) that fall below AITER's internal threshold.
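One way to reproduce a count like this from a trace export — a sketch assuming a Chrome-trace JSON dump (e.g. from `torch.profiler`'s `export_chrome_trace`; the file name is a placeholder):

```python
# Count fused vs. unfused allreduce kernel launches in a Chrome-trace JSON.
import json
from collections import Counter

KERNELS = (
    "cross_device_reduce_2stage",         # unfused AR
    "reduce_scatter_cross_device_store",  # fused step 1
    "local_device_load_rmsnorm_naive",    # fused step 2
)

with open("trace_tp0.json") as f:  # placeholder trace file
    events = json.load(f)["traceEvents"]

counts = Counter()
for ev in events:
    name = ev.get("name", "")
    for kernel in KERNELS:
        if kernel in name:
            counts[kernel] += 1
print(counts)
```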
Summary
CC: @kkHuang-amd @HaiShaw