
[AMD] Add fused all-reduce RMSNorm per-group quant for Qwen3.5 FP8 #24651

Open

hubertlu-tw wants to merge 2 commits into sgl-project:main from hubertlu-tw:fused_ar_rms_per_group_quant

Conversation

@hubertlu-tw (Collaborator)

Motivation

This PR adds a fused aiter path for Qwen3.5 FP8 attention/GDN input normalization on AMD. The baseline path around prepare_attn runs all-reduce, RMSNorm, and per-group activation quant as three separate kernels before the FP8 projection consumes (fp8, scale).

The new path lets LayerCommunicator.prepare_attn request fused all-reduce + RMSNorm + per-group FP8 quant. Standard attention receives (fp8, scale) directly for qkv_proj. GDN receives (bf16, fp8, scale) so in_proj_qkvz can skip its internal activation quant while in_proj_ba still reads the bf16 activation it needs.

The fused-quant path is scoped to ROCm/aiter. The aiter single-kernel dispatch is further gated to gfx950-class GPUs through is_gfx95_supported(), and the existing --enable-aiter-allreduce-fusion AR+RMSNorm behavior remains a separate, unchanged path. If the fused-quant path is unavailable, callers fall back to fused AR+RMSNorm or to the existing generic path.
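The tuple layouts involved, as a minimal illustrative sketch (the function and variable names here are placeholders, not the real SGLang signatures):

```python
def unpack_prepare_attn_output(out):
    """Illustrative only: the tuple layouts described above, not the real
    SGLang API. Returns (bf16_act, fp8_act, scale), with None for whichever
    pieces a given path does not produce."""
    if isinstance(out, tuple) and len(out) == 3:
        bf16_act, fp8_act, scale = out   # GDN keep-bf16 path: fp8+scale feed
        return bf16_act, fp8_act, scale  # in_proj_qkvz, bf16 feeds in_proj_ba
    if isinstance(out, tuple) and len(out) == 2:
        fp8_act, scale = out             # standard attention: feeds qkv_proj
        return None, fp8_act, scale
    return out, None, None               # generic fallback: plain bf16
```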

Modifications

  • Added a tensor-parallel API for fused all-reduce + RMSNorm + per-group FP8
    quant:
    • tensor_model_parallel_fused_allreduce_rmsnorm_quant_per_group(...)
    • GroupCoordinator.fused_allreduce_rmsnorm_quant_per_group(...)
  • Added layernorm helpers that return tuple activations for FP8 consumers:
    • ((fp8, scale), residual) for standard attention.
    • ((bf16, fp8, scale), residual) for GDN, where bf16 is required by in_proj_ba.
  • Wired LayerCommunicator.prepare_attn to prefer the fused-quant helper when _sglang_needs_allreduce_fusion is set and the layer opts in with enable_fused_ar_quant=True (a condensed sketch follows this list).
  • Wired Qwen3.5 standard attention with keep_bf16=False and GDN with keep_bf16=True.
  • Updated Qwen3.5 GDN input projection to consume the (bf16, fp8, scale) tuple without dequantizing fp8 back to bf16.
  • Added SGLANG_DISABLE_FUSED_AR_QUANT=1 as the operator opt-out for this quantized handoff while keeping the existing AR+RMSNorm fusion eligible.
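A condensed sketch of that opt-in wiring: the flag names and SGLANG_DISABLE_FUSED_AR_QUANT come from this PR, but the helper methods and their bodies are illustrative stubs, not the actual SGLang code.

```python
import os

# Sketch only: flag/env-var names are from this PR; helper bodies are stubs.
FUSED_AR_QUANT_DISABLED = os.environ.get("SGLANG_DISABLE_FUSED_AR_QUANT") == "1"

class LayerCommunicator:
    def __init__(self, enable_fused_ar_quant: bool = False,
                 fused_ar_quant_keep_bf16: bool = False):
        # Both default to False, so models that do not opt in are untouched.
        self.enable_fused_ar_quant = enable_fused_ar_quant
        self.fused_ar_quant_keep_bf16 = fused_ar_quant_keep_bf16

    def _try_fused_ar_quant(self, hidden_states, residual, keep_bf16):
        # Stub: would invoke the fused AR+RMSNorm+per-group-quant helper and
        # return ((fp8, scale), residual) or ((bf16, fp8, scale), residual),
        # or None when the fused kernel is unavailable.
        return None

    def _prepare_attn_generic(self, hidden_states, residual):
        # Stub: existing path (AR, RMSNorm, quant as separate kernels).
        return hidden_states, residual

    def prepare_attn(self, hidden_states, residual):
        if self.enable_fused_ar_quant and not FUSED_AR_QUANT_DISABLED:
            fused = self._try_fused_ar_quant(
                hidden_states, residual, keep_bf16=self.fused_ar_quant_keep_bf16
            )
            if fused is not None:
                return fused
        return self._prepare_attn_generic(hidden_states, residual)
```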

Fallback order:

  1. aiter single-kernel AR + RMSNorm + per-group FP8 quant, optionally with bf16 side-output.
  2. aiter fused AR + RMSNorm plus separate aiter per-1x128 quant.
  3. Generic caller fallback.

The single-kernel dispatch is intentionally limited to ROCm + aiter + gfx95.
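A sketch of this staging in Python; every function name below is a placeholder for the aiter/helper calls, not a verified signature:

```python
def _try_single_kernel(x, residual, weight, keep_bf16):
    return None  # stub: aiter fused AR+RMSNorm+per-group quant (gfx95 only)

def _try_fused_ar_rmsnorm(x, residual, weight):
    return None  # stub: existing aiter AR+RMSNorm fusion (returns bf16)

def _per_group_quant_1x128(bf16):
    return bf16, 1.0  # stub: separate aiter per-1x128 activation quant

def fused_ar_rmsnorm_quant(x, residual, weight, keep_bf16):
    out = _try_single_kernel(x, residual, weight, keep_bf16)  # stage 1
    if out is not None:
        return out
    bf16 = _try_fused_ar_rmsnorm(x, residual, weight)         # stage 2
    if bf16 is not None:
        fp8, scale = _per_group_quant_1x128(bf16)
        return (bf16, fp8, scale) if keep_bf16 else (fp8, scale)
    return None  # stage 3: caller drops to the generic unfused path
```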

Accuracy Tests

GSM8K was run on Qwen3.5-397B-A17B-FP8 with TP=8 and
--enable-aiter-allreduce-fusion.

| Config | GSM8K command | Accuracy |
| --- | --- | --- |
| Baseline | `python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319 --parallel 1319 --num-shots 5 --port 9000` | 0.946 |
| Fused AR+RMSNorm+per-group quant | `python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319 --parallel 1319 --num-shots 5 --port 9000` | 0.955 |

The fused path is above the 0.94 accuracy gate used for this model.

Speed Tests and Profiling

Server command:

```bash
SGLANG_USE_AITER_UNIFIED_ATTN=1 \
SGLANG_USE_AITER=1 \
SGLANG_AITER_UNIFIED_VERIFY=1 \
python3 -m sglang.launch_server \
  --model-path /data2/Qwen/Qwen3.5-397B-A17B-FP8/ \
  --tp 8 \
  --attention-backend aiter \
  --trust-remote-code \
  --model-loader-extra-config '{"enable_multithread_load": true}' \
  --watchdog-timeout 1200 \
  --mem-fraction-static 0.9 \
  --host 0.0.0.0 \
  --port 9000 \
  --disable-radix-cache \
  --enable-aiter-allreduce-fusion \
  --max-running-requests 128 \
  --page-size 16
```
Serving benchmark command:

```bash
python3 benchmark_serving.py \
  --model /data2/Qwen/Qwen3.5-397B-A17B-FP8 \
  --base-url http://0.0.0.0:9000 \
  --backend sglang \
  --dataset-name random \
  --random-input-len 8192 \
  --random-output-len 8192 \
  --random-range-ratio 1.0 \
  --num-prompts 32 \
  --max-concurrency 16 \
  --request-rate inf \
  --ignore-eos \
  --num-warmups 16 \
  --percentile-metrics ttft,tpot,itl,e2el
```

End-to-end serving result, 8192 input / 8192 output, concurrency 16, 32 prompts,
ignore-eos:

| Metric | Baseline | Fused path | Improvement |
| --- | --- | --- | --- |
| Total throughput (tok/s) | 2943.67 | 3018.79 | +2.55% |
| Output throughput (tok/s) | 1471.83 | 1509.40 | +2.55% |
| Median TPOT (ms) | 10.69 | 10.42 | +2.53% |

Improvement formula:

  • Throughput: (candidate - baseline) / baseline * 100.
  • TPOT: (baseline - candidate) / baseline * 100.
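Plugging the table's numbers into these formulas as a quick sanity check:

```python
# Sanity-checking the serving table above with the stated formulas.
baseline_tput, fused_tput = 2943.67, 3018.79
baseline_tpot, fused_tpot = 10.69, 10.42

print(f"{(fused_tput - baseline_tput) / baseline_tput * 100:.2f}%")  # 2.55%
print(f"{(baseline_tpot - fused_tpot) / baseline_tpot * 100:.2f}%")  # 2.53%
```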

Kernel-count profile, per decode pass on Qwen3.5-397B-A17B-FP8 / TP=8 with 60
decoder layers:

| Kernel | Baseline | Fused path | Delta |
| --- | --- | --- | --- |
| cross_device_reduce_*stage (plain AR) | 60 | 0 | -60 |
| add_rmsnorm_quant_kernel (RMSNorm+add) | 60 | 0 | -60 |
| dynamic_per_group_scaled_quant_kernel | 60 | 0 | -60 |
| allreduce_fusion_kernel_1stage (AR+RMSNorm) | 0 | 0 | 0 |
| allreduce_fusion_kernel_1stage_per_group | 0 | 60 | +60 |
| Comm-side separate quant for GDN keep-bf16 path | 45 | 0 | -45 |
| Total kernels saved on prepare_attn | | | -165 |

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

Commit message

On ROCm with --enable-aiter-allreduce-fusion, Qwen3.5-FP8 pays two
extra kernel launches per layer around the input layernorm during
decode: a separate dynamic_per_group_scaled_quant and the AR+RMSNorm
that could be fused with the quant. Wire the aiter
fused_allreduce_rmsnorm_quant_per_group kernel into
LayerCommunicator.prepare_attn for the Qwen3.5 standard-attention
and GDN (linear_attention) layers, with a staged fallback chain so
correctness is preserved if the fused kernel is unavailable.

On Qwen3.5-397B-A17B-FP8, MI355X, TP=8, decode 8k/8k @ concurrency
16, this collapses 59 of the 60 per-layer sites into a single kernel
per layer and removes 59 separate per-group quants per decode pass
(total kernel time -234 us). End-to-end throughput +2.55%, median
TPOT -2.53%, GSM8K 0.955 (vs 0.946 baseline, within noise).

Changes are ROCm/aiter-only. GroupCoordinator returns None on
non-HIP, _forward_with_allreduce_fusion_quant_per_group bails on
not _use_aiter, the Qwen3.5 layer gate is ANDed with _use_aiter,
and the LayerCommunicator opt-in flags default to False so other
models are untouched. Opt-out via SGLANG_DISABLE_FUSED_AR_QUANT=1.
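Collapsed into a single predicate for illustration, the gates read roughly as below. is_hip, _use_aiter, and is_gfx95_supported are the names used in this PR; the stub bodies and the combined predicate are assumptions:

```python
import os

def is_hip() -> bool:
    return True  # stub: real ROCm check lives in sglang

def is_gfx95_supported() -> bool:
    return True  # stub: gfx950-class check gating the 1-kernel dispatch

_use_aiter = True  # stub for the module-level aiter flag

def fused_ar_quant_eligible(layer_opted_in: bool) -> bool:
    if os.environ.get("SGLANG_DISABLE_FUSED_AR_QUANT") == "1":
        return False                 # explicit operator opt-out
    if not (is_hip() and _use_aiter):
        return False                 # ROCm + aiter only; others stage down
    if not is_gfx95_supported():
        return False                 # 1-kernel dispatch is gfx95-only
    return layer_opted_in            # defaults to False for other models
```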

* python/sglang/srt/distributed/communication_op.py: new
  tensor_model_parallel_fused_allreduce_rmsnorm_quant_per_group op
  with optional emit_bf16 flag.
* python/sglang/srt/distributed/parallel_state.py: new
  GroupCoordinator method with early is_hip() gate, same shape
  eligibility as fused_allreduce_rmsnorm, returns None on any
  exception so callers stage down gracefully.
* python/sglang/srt/layers/layernorm.py: helper with staged
  fallback chain (1-kernel fused -> 2-kernel fused AR+RMS +
  separate quant -> None). keep_bf16=False returns
  ((fp8, scale), residual); keep_bf16=True returns
  ((bf16, fp8, scale), residual). Exposed on RMSNorm and
  GemmaRMSNorm. Caches aiter per-1x128 quant functor at module
  load so the fallback path does not re-import aiter per forward.
* python/sglang/srt/layers/communicator.py: LayerCommunicator
  takes enable_fused_ar_quant / fused_ar_quant_keep_bf16 kwargs.
  prepare_attn tries the quant helper first under
  (self.enable_fused_ar_quant and _use_aiter), falls back to the
  existing forward_with_allreduce_fusion on None.
* python/sglang/srt/models/qwen3_5.py: Qwen3_5LinearDecoderLayer
  enables with keep_bf16=True (GDN needs bf16 for in_proj_ba);
  Qwen3_5AttentionDecoderLayer enables with keep_bf16=False (qkv
  is the sole consumer). Gate helper is lru_cached.
  _forward_input_proj short-circuits on `_use_aiter and isinstance(
  hidden_states, tuple)` into an AMD-only helper, keeping the
  non-AMD control flow textually unchanged (sketched after this list).
* benchmark/kernels/all_reduce/benchmark_fused_ar_rms_quant_amd.py:
  new microbenchmark comparing split (3-kernel) vs fused-AR+RMS
  then separate quant (2-kernel) vs fully-fused fp8-only (1-kernel)
  vs fully-fused fp8+bf16 (1-kernel) across Qwen3.5-FP8 decode and
  prefill shapes, with correctness checks.
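For the qwen3_5.py change, a sketch of the AMD-only GDN handoff; the projection call signatures are assumed from the description above, not verified against the actual code:

```python
_use_aiter = True  # stub for the module-level flag

def forward_input_proj(layer, hidden_states):
    # AMD-only short-circuit on the fused (bf16, fp8, scale) tuple; the
    # else-branch mirrors the untouched non-AMD control flow.
    if _use_aiter and isinstance(hidden_states, tuple):
        bf16, fp8, scale = hidden_states
        qkvz = layer.in_proj_qkvz((fp8, scale))   # consumes fp8 directly,
                                                  # no dequant back to bf16
        ba = layer.in_proj_ba(bf16)               # still needs the bf16 side-output
    else:
        qkvz = layer.in_proj_qkvz(hidden_states)  # non-AMD flow unchanged
        ba = layer.in_proj_ba(hidden_states)
    return qkvz, ba
```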

Depends on the companion aiter change (aiter branch
hubert/fused_ar_rms_per_group_quant commit d68c79d1a) that plumbs
the bf16_output pointer through the per-group kernel family.