Skip to content

[AMD] Simplify fused allreduce + RMSNorm and remove hidden_dim allowlist#21986

Merged
HaiShaw merged 4 commits intosgl-project:mainfrom
hubertlu-tw:fused_ar_activation
Apr 12, 2026
Merged

[AMD] Simplify fused allreduce + RMSNorm and remove hidden_dim allowlist#21986
HaiShaw merged 4 commits intosgl-project:mainfrom
hubertlu-tw:fused_ar_activation

Conversation

@hubertlu-tw
Copy link
Copy Markdown
Collaborator

Two changes to the --enable-aiter-allreduce-fusion path:

  1. communicator.py (commit 449f7c293): Fix an off-by-one (<<=) in apply_aiter_all_reduce_fusion so the fused kernel activates at the exact boundary size, matching AITER's internal should_custom_ar (inp_size <= max_size/2).

  2. parallel_state.py (this PR): Remove the hardcoded hidden_dim ∈ {512, 1024, 2048, 4096} allowlist for 1-stage vs 2-stage selection. Keep the 128 KB byte-size threshold (the 1-stage kernel is capped at 80 tokens via kMaxBlocks) but let AITER's C++ dispatch decide which hidden_dims are supported.

Together these eliminate the per-model maintenance burden (no more manual hidden_dim additions for new models) and ensure AITER's own heuristics are respected.

Depends on:

  • ROCm/aiter#2586 — fixes numerical
    accuracy in the 1-stage fused kernel (bf16 round-trip before residual addition).
  • ROCm/aiter#2453 — refactors the fused allreduce dispatch so the 1-stage kernel accepts any hidden_dim divisible by pack_size (no more hardcoded template set), and the 2-stage kernel uses n_packs-based dispatch (no more n_bytes % 1024 alignment requirement).

Subsumes: #21947 — that PR adds 2880 to the Python-side allowlist. With this change, the allowlist is removed entirely, so all AITER-supported hidden_dims work automatically.

Motivation

Problem 1: Off-by-one in outer gate

The activation gate in communicator.py used < for the byte-size threshold while AITER uses <=. For hidden_size=4096 with bf16 at 8192 tokens, total_bytes exactly equals the threshold (67,108,864), so < rejected it.

Problem 2: Redundant hidden_dim allowlist

The 1-stage vs 2-stage decision in parallel_state.py gated on hidden_dim ∈ {512, 1024, 2048, 4096}. This was redundant at three levels:

Layer Where What it checked
Python (SGLang) parallel_state.py hidden_dim ∈ {512, 1024, 2048, 4096}
C++ (old AITER) custom_all_reduce.cuh n == 4096 || n == 2048 || ...
C++ (new AITER, #2453) custom_all_reduce.cuh n % pack_size == 0 && n/pack_size <= 1024

AITER's C++ dispatch already enforces which hidden_dims have 1-stage support. Passing use_1stage=True for an unsupported dim is safe — AITER overrides it to false and falls through to 2-stage. The Python-side allowlist only created a maintenance burden (cf. PR #21947).

The 128 KB threshold is kept (but the allowlist is not)

The 1-stage fused kernel launches one CTA per token and is hard-capped at
kMaxBlocks = 80 tokens. For hidden_dim=4096 with bf16:
80 × 4096 × 2 = 655,360 bytes. The 128 KB threshold safely ensures only
small decode-like batches take the 1-stage path; larger prefill batches
automatically fall through to the 2-stage kernel.

Modifications

File 1: python/sglang/srt/layers/communicator.py

 def apply_aiter_all_reduce_fusion(input_tensor: torch.Tensor):
     n = input_tensor.shape[-1]
     total_bytes = input_tensor.numel() * input_tensor.element_size()
+    # Aiter's should_custom_ar uses <= max_size/2 (64 MB); match that boundary.
     return (
         _use_aiter
         and total_bytes > 0
         and n <= 16384
-        and total_bytes < 8 * 1024 * 8192
+        and total_bytes <= 8 * 1024 * 8192
         and get_tensor_model_parallel_world_size() != 6
         and not is_dp_attention_enabled()
         and get_global_server_args().enable_aiter_allreduce_fusion
     )

File 2: python/sglang/srt/distributed/parallel_state.py

-        # 1-stage policy for fused AR+RMSNorm:
-        # 1) Explicit env override wins.
-        # 2) Deterministic inference forces 1-stage for reproducibility.
-        # 3) Otherwise follow AITER's heuristic (small payloads only).
+        # 1-stage vs 2-stage selection for fused AR+RMSNorm:
+        # The 1-stage kernel launches one block per token and is capped at
+        # 80 tokens (kMaxBlocks).  Guard with a byte threshold so large
+        # prefill batches fall through to the 2-stage kernel instead of
+        # hitting a runtime error.  AITER's C++ dispatch already gates
+        # which hidden_dims have valid 1-stage support.
         if envs.SGLANG_USE_1STAGE_ALLREDUCE.is_set():
             use_1stage_ar = envs.SGLANG_USE_1STAGE_ALLREDUCE.get()
-        elif envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get():
-            use_1stage_ar = True
         else:
             total_bytes = input_.numel() * input_.element_size()
-            hidden_dim = input_.shape[-1]
-            use_1stage_ar = total_bytes <= 128 * 1024 and hidden_dim in {
-                512,
-                1024,
-                2048,
-                4096,
-            }
+            use_1stage_ar = total_bytes <= 128 * 1024

Accuracy Tests

Setup

  • Model: Qwen3.5-397B-A17B-FP8 (60 layers, hidden=4096, 512 experts, MoE)
  • Hardware: 4× AMD MI355X / gfx950 (TP=4, XGMI fully connected)
  • Backend: SGLANG_USE_AITER=1, --attention-backend aiter
  • AITER version: main branch (a530b92bf), includes both
    ROCm/aiter#2586 (1-stage accuracy fix) and
    ROCm/aiter#2453 (allreduce refactor).
    Full rebuild with GPU_ARCHS=gfx950 MAX_JOBS=192 PREBUILD_KERNELS=1.

GSM8K (1319 questions, 5-shot, parallel=1319)

Three rounds per configuration, with full PID/GPU cleanup between cases.

Round A: --enable-aiter-allreduce-fusion B: Baseline (no flag)
1 95.0% 92.8%
2 94.2% 91.7%
3 94.9% 92.6%
Mean 94.7% 92.4%

No accuracy regression. Fusion ON is +2.3pp above baseline, confirming the AITER accuracy fix (#2586) is correctly integrated and the fused path produces numerically sound results.

Commands

# Server (Case A — fusion ON)
SGLANG_USE_AITER=1 python3 -m sglang.launch_server \
  --model-path /data2/Qwen/Qwen3.5-397B-A17B-FP8/ \
  --tp 4 --attention-backend aiter --trust-remote-code \
  --watchdog-timeout 1200 --mem-fraction-static 0.9 \
  --host 0.0.0.0 --port 9000 \
  --enable-aiter-allreduce-fusion

# Server (Case B — fusion OFF)
SGLANG_USE_AITER=1 python3 -m sglang.launch_server \
  --model-path /data2/Qwen/Qwen3.5-397B-A17B-FP8/ \
  --tp 4 --attention-backend aiter --trust-remote-code \
  --watchdog-timeout 1200 --mem-fraction-static 0.9 \
  --host 0.0.0.0 --port 9000

# GSM8K benchmark (run 3 rounds per case)
python3 benchmark/gsm8k/bench_sglang.py \
  --num-questions 1319 --parallel 1319 --num-shots 5 --port 9000

# Serving benchmark
python3 -m sglang.bench_serving \
  --host localhost --port 9000 \
  --model /data2/Qwen/Qwen3.5-397B-A17B-FP8/ \
  --dataset-name random \
  --random-input 8192 --random-output 1024 \
  --random-range-ratio 1.0 \
  --max-concurrency 1 --num-prompt 8

Benchmarking

Serving Benchmark (random, 8 × 8192-in / 1024-out, concurrency=1)

Metric A: Fusion ON B: Fusion OFF Delta
Total token throughput (tok/s) 766.4 739.1 +3.7%
Output token throughput (tok/s) 85.2 82.1 +3.7%
Peak output throughput (tok/s) 88.0 85.0 +3.5%
Median E2E Latency (ms) 12,019 12,476 −3.7%
Median TTFT (ms) 297.5 295.2 +0.8%
Median TPOT (ms) 11.46 11.90 −3.7%
Median ITL (ms) 11.46 11.89 −3.6%

Kernel Microbenchmark (TP=4, bf16)

The fused kernel is faster across all tested shapes and hidden_dims:

torchrun --nproc_per_node=4 \
  benchmark/kernels/all_reduce/benchmark_fused_ar_rms_amd.py \
  --dtype bf16 \
  --decode-shapes 1x4096,4x4096,16x4096,32x4096,1x5120,16x5120,1x7168,16x7168,1x8192,1x14336,4x16384 \
  --prefill-shapes 128x4096,128x5120,128x7168,128x16384,8192x4096 \
  --warmup 10 --iters 30 --repeats 5

Decode (graph-captured)

With AITER #2453 applied, all hidden_dims
work — including GPT-OSS's hidden_dim=2880:

Shape Unfused p50 (µs) Fused p50 (µs) Speedup
1×2880 20.4 10.4 1.96×
4×2880 21.7 10.5 2.07×
16×2880 22.8 10.8 2.10×
1×4096 20.6 11.7 1.76×
4×4096 21.4 10.3 2.08×
16×4096 23.2 11.6 2.00×
1×5120 22.2 10.4 2.13×
16×5120 22.8 14.2 1.61×
1×6144 21.9 10.9 2.00×
16×6144 23.9 14.5 1.65×
1×7168 23.7 10.6 2.24×
16×7168 24.4 14.8 1.64×
1×8192 23.2 10.8 2.14×
16×8192 25.4 14.8 1.72×

Prefill (eager)

Shape Unfused p50 (µs) Fused p50 (µs) Speedup
128×2880 36.0 25.4 1.42×
128×4096 36.8 25.5 1.44×
128×5120 37.8 25.5 1.49×
128×6144 37.9 28.2 1.34×
128×7168 38.8 30.3 1.28×
128×8192 41.3 32.5 1.27×

All hidden_dims benefit — confirming that the allowlist was unnecessarily restrictive.
GPT-OSS (hidden_dim=2880) gets up to 2.10× speedup on decode.

Trace Verification

After the threshold fix, the EXTEND (prefill) trace on TP-0 confirms fused
kernel activation:

Kernel Before fix After fix Change
cross_device_reduce_2stage (unfused AR) 121 2 Replaced by fused
add_rmsnorm_quant_kernel (256-wide) 120 1 Replaced by fused
reduce_scatter_cross_device_store (fused step 1) 0 119 NEW
local_device_load_rmsnorm_naive (fused step 2) 0 119 NEW

119 out of 121 allreduce sites now use the fused path. The remaining 2 are
smaller-hidden-dim layers (64-wide) that fall below AITER's internal threshold.

Summary

Aspect Result
Accuracy (GSM8K) No regression (94.7% vs 92.4% baseline, +2.3pp)
Throughput +3.7% total tok/s
Decode latency (TPOT) −3.7% (11.46ms vs 11.90ms)
Prefill latency (TTFT) Neutral
Kernel speedup (decode) 1.3–2.2× across all decode shapes and hidden_dims
Maintenance No more per-model hidden_dim additions needed

CC: @kkHuang-amd @HaiShaw

The activation gate in `apply_aiter_all_reduce_fusion` used strict
less-than (`<`) for the byte-size threshold, while AITER's internal
`should_custom_ar` uses less-than-or-equal (`<=`). For the common
case of hidden_size=4096 with bf16 at 8192 tokens, the total bytes
exactly equal the threshold (67,108,864), so `<` rejected it and
the fused kernel never activated.

Change `<` to `<=` so SGLang's gate matches AITER's boundary,
enabling the fused allreduce+RMSNorm kernel for this shape.

Depends on: ROCm/aiter#2586

Made-with: Cursor
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@github-actions github-actions Bot added documentation Improvements or additions to documentation amd labels Apr 3, 2026
…used AR+RMSNorm

- parallel_state.py: Remove hardcoded hidden_dim allowlist {512,1024,2048,4096}
  for 1-stage kernel selection; keep 128KB byte threshold. AITER's C++ dispatch
  already gates which dims are supported (ROCm/aiter#2453).
- benchmark_fused_ar_rms_amd.py: Add hidden_dim=2880 (GPT-OSS) to default
  decode and prefill shapes.
- test_aiter_allreduce_fusion_amd.py: Add multi-hidden-dim correctness test
  covering 2880/4096/5120/6144/7168/8192, and bit-exact residual accuracy
  regression test for ROCm/aiter#2586.

Made-with: Cursor
@hubertlu-tw hubertlu-tw force-pushed the fused_ar_activation branch from 080bf23 to 306abe2 Compare April 3, 2026 02:19
@HaiShaw
Copy link
Copy Markdown
Collaborator

HaiShaw commented Apr 5, 2026

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label Apr 5, 2026
@HaiShaw HaiShaw changed the title [AMD] Simplify fused allreduce + RMSNorm and remove hidden_dim allowlist [AMD][No-Merge] Simplify fused allreduce + RMSNorm and remove hidden_dim allowlist Apr 5, 2026
@HaiShaw
Copy link
Copy Markdown
Collaborator

HaiShaw commented Apr 12, 2026

@amd-bot ci-status

@amd-bot
Copy link
Copy Markdown

amd-bot commented Apr 12, 2026

@HaiShaw

CI Status for PR #21986

PR: [AMD][No-Merge] Simplify fused allreduce + RMSNorm and remove hidden_dim allowlist
Changed files: benchmark/kernels/all_reduce/benchmark_fused_ar_rms_amd.py (+2/-2), python/sglang/srt/distributed/parallel_state.py (+7/-14), python/sglang/srt/layers/communicator.py (+2/-1), test/registered/ops/test_aiter_allreduce_fusion_amd.py (+251/-38)

Job Error Related? Explanation Log
build-and-test (XPU) ValueError: Cannot find model module 'DeepseekOCRForCausalLM' — missing addict/matplotlib packages 🟢 Unlikely XPU/DeepSeek-OCR test missing pip dependencies. PR only changes AMD allreduce code. Log
build-test (all) (Xeon) KeyError: 'sglang' in ATTENTION_CLASSES[config._attn_implementation] 🟢 Unlikely DeepSeek model's custom modeling_deepseek.py doesn't recognize 'sglang' attention backend. CPU/Xeon test infrastructure issue unrelated to AMD allreduce changes. Log
stage-b-test-2-npu-a2 (1) OSError: ... does not appear to have a file named modeling_deepseek.py 🟢 Unlikely Missing dynamic module file in NPU cached model directory. NPU infrastructure issue. Log
stage-b-test-16-npu-a3 KeyError: 'sglang' in ATTENTION_CLASSES 🟢 Unlikely Same sglang attention backend key error on NPU. Infrastructure issue unrelated to PR. Log
notebook-finish Upstream run-all-notebooks was externally cancelled 🟢 Unlikely Workflow was cancelled mid-execution (likely by a newer push or manual cancel). Not a code failure. Log
pr-test-finish Upstream stage-b/c jobs all cancelled 🟢 Unlikely Gate job failed because all NVIDIA test stages were cancelled, not because of test failures. Log
pr-test-amd-finish Upstream stage-b/c jobs all cancelled 🟢 Unlikely Gate job failed because all AMD test stages were cancelled, not because of test failures. Log
pr-test-npu-finish Upstream NPU test failures 🟢 Unlikely Gate job failed due to NPU test failures listed above. Log
finish (XPU) Upstream build-and-test failure 🟢 Unlikely Gate job failed due to XPU build-and-test failure listed above. Log

Details

None of these failures are related to this PR's changes. The PR modifies 4 files — all scoped to AMD ROCm fused allreduce+RMSNorm logic and its tests:

  • parallel_state.py: Removes the hidden_dim allowlist and the deterministic-inference 1-stage override from fused_allreduce_rmsnorm, keeping only the byte-size threshold.
  • communicator.py: Fixes a boundary comparison (<<=) in apply_aiter_all_reduce_fusion.
  • benchmark_fused_ar_rms_amd.py: Adds non-allowlisted shapes (e.g., 2880) to benchmark defaults.
  • test_aiter_allreduce_fusion_amd.py: Adds multi-hidden-dim and residual accuracy test cases.

The actual failures fall into three categories:

  1. XPU/Xeon infra issues — Missing pip packages (addict, matplotlib) and KeyError: 'sglang' in DeepSeek custom model code. These are pre-existing environment/model-loading problems on XPU/CPU runners.

  2. NPU infra issues — Missing cached model files and the same KeyError: 'sglang' attention backend issue. Pre-existing NPU runner problems.

  3. Cancelled workflows — The NVIDIA and AMD GPU test stages were all cancelled (likely by a superseding workflow run), causing their finish gates to fail. No actual test ran and failed.

Verdict: No action needed from this PR. All failures are pre-existing infrastructure issues or cancelled runs.

Generated by amd-bot using Claude Code CLI

@HaiShaw HaiShaw merged commit edaa597 into sgl-project:main Apr 12, 2026
28 of 44 checks passed
pyc96 pushed a commit to pyc96/sglang that referenced this pull request Apr 14, 2026
@hubertlu-tw hubertlu-tw changed the title [AMD][No-Merge] Simplify fused allreduce + RMSNorm and remove hidden_dim allowlist [AMD] Simplify fused allreduce + RMSNorm and remove hidden_dim allowlist Apr 15, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

amd documentation Improvements or additions to documentation run-ci

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants