[AMD] Add fused all-reduce RMSNorm per-group quant for Qwen3.5 FP8#24651
Open
hubertlu-tw wants to merge 2 commits into sgl-project:main from
Conversation
On ROCm with --enable-aiter-allreduce-fusion, Qwen3.5-FP8 pays two extra kernel launches per layer around the input layernorm during decode: a separate dynamic_per_group_scaled_quant and the AR+RMSNorm that could be fused with the quant. Wire the aiter fused_allreduce_rmsnorm_quant_per_group kernel into LayerCommunicator.prepare_attn for the Qwen3.5 standard-attention and GDN (linear_attention) layers, with a staged fallback chain so correctness is preserved if the fused kernel is unavailable.

On Qwen3.5-397B-A17B-FP8, MI355X, TP=8, decode 8k/8k @ concurrency 16, this collapses 59 of the 60 per-layer sites to a single kernel per layer and removes 59 separate per-group quants per decode pass (total kernel time -234 us). End-to-end throughput +2.55%, median TPOT -2.53%, GSM8K 0.955 (vs 0.946 baseline, within noise).

Changes are ROCm/aiter-only: GroupCoordinator returns None on non-HIP, _forward_with_allreduce_fusion_quant_per_group bails on not _use_aiter, the Qwen3.5 layer gate is ANDed with _use_aiter, and the LayerCommunicator opt-in flags default to False, so other models are untouched. Opt-out via SGLANG_DISABLE_FUSED_AR_QUANT=1.

* python/sglang/srt/distributed/communication_op.py: new tensor_model_parallel_fused_allreduce_rmsnorm_quant_per_group op with an optional emit_bf16 flag.
* python/sglang/srt/distributed/parallel_state.py: new GroupCoordinator method with an early is_hip() gate and the same shape eligibility as fused_allreduce_rmsnorm; returns None on any exception so callers stage down gracefully.
* python/sglang/srt/layers/layernorm.py: helper with a staged fallback chain (1-kernel fused -> 2-kernel fused AR+RMS + separate quant -> None). keep_bf16=False returns ((fp8, scale), residual); keep_bf16=True returns ((bf16, fp8, scale), residual). Exposed on RMSNorm and GemmaRMSNorm. Caches the aiter per-1x128 quant functor at module load so the fallback path does not re-import aiter per forward.
* python/sglang/srt/layers/communicator.py: LayerCommunicator takes enable_fused_ar_quant / fused_ar_quant_keep_bf16 kwargs. prepare_attn tries the quant helper first under (self.enable_fused_ar_quant and _use_aiter) and falls back to the existing forward_with_allreduce_fusion on None.
* python/sglang/srt/models/qwen3_5.py: Qwen3_5LinearDecoderLayer enables with keep_bf16=True (GDN needs bf16 for in_proj_ba); Qwen3_5AttentionDecoderLayer enables with keep_bf16=False (qkv is the sole consumer). The gate helper is lru_cached. _forward_input_proj short-circuits on `_use_aiter and isinstance(hidden_states, tuple)` into an AMD-only helper, keeping the non-AMD control flow textually unchanged.
* benchmark/kernels/all_reduce/benchmark_fused_ar_rms_quant_amd.py: new microbenchmark comparing split (3-kernel) vs fused-AR+RMS then separate quant (2-kernel) vs fully-fused fp8-only (1-kernel) vs fully-fused fp8+bf16 (1-kernel) across Qwen3.5-FP8 decode and prefill shapes, with correctness checks.

Depends on the companion aiter change (aiter branch hubert/fused_ar_rms_per_group_quant, commit d68c79d1a) that plumbs the bf16_output pointer through the per-group kernel family.
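For orientation, here is a minimal, self-contained sketch of the staged fallback chain described above (1-kernel fused -> 2-kernel fused AR+RMS plus separate quant -> None). It is not the SGLang/aiter source: the `_try_fused_*` stubs stand in for the aiter-backed dispatches, `_per_group_quant` is a naive per-1x128 dynamic quant used only to make the 2-kernel stage concrete, and function names are illustrative. Only the control flow and the return shapes mirror the description.

```python
import torch

def _try_fused_ar_rmsnorm_quant(x, residual, weight, eps, emit_bf16):
    return None  # would launch aiter's 1-kernel AR + RMSNorm + per-group FP8 quant

def _try_fused_ar_rmsnorm(x, residual, weight, eps):
    return None  # would launch the existing fused AR + RMSNorm kernel

def _per_group_quant(x, group=128):
    # Naive per-1x128 dynamic quant, for illustration only.
    g = x.view(*x.shape[:-1], -1, group)
    scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 448.0
    fp8 = (g / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return fp8.view_as(x), scale.squeeze(-1)

def forward_with_allreduce_fusion_quant_per_group(x, residual, weight, eps=1e-6,
                                                  keep_bf16=False):
    # Stage 1: fully fused single kernel.
    out = _try_fused_ar_rmsnorm_quant(x, residual, weight, eps, emit_bf16=keep_bf16)
    if out is not None:
        return out  # ((fp8, scale), residual) or ((bf16, fp8, scale), residual)
    # Stage 2: fused AR + RMSNorm, then a separate per-group quant kernel.
    two = _try_fused_ar_rmsnorm(x, residual, weight, eps)
    if two is not None:
        normed, residual = two
        fp8, scale = _per_group_quant(normed)
        out = (normed, fp8, scale) if keep_bf16 else (fp8, scale)
        return out, residual
    # Stage 3: None tells the caller to stay on the existing unfused path.
    return None
```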
Motivation
This PR adds a fused aiter path for the Qwen3.5 FP8 attention/GDN input normalization path on AMD. The baseline path around `prepare_attn` runs all-reduce, RMSNorm, and per-group activation quant as separate work before the FP8 projection consumes `(fp8, scale)`.

The new path lets `LayerCommunicator.prepare_attn` request fused all-reduce + RMSNorm + per-group FP8 quant. Standard attention receives `(fp8, scale)` directly for `qkv_proj`. GDN receives `(bf16, fp8, scale)` so `in_proj_qkvz` can skip its internal activation quant while `in_proj_ba` still reads the bf16 activation it needs.

The fused-quant path is scoped to ROCm/aiter. The aiter single-kernel dispatch is gated to gfx95 / gfx950-class GPUs through `is_gfx95_supported()`, and plain `--enable-aiter-allreduce-fusion` AR+RMSNorm behavior remains separate. If the fused-quant path is unavailable, callers fall back to fused AR+RMSNorm or the existing generic path.
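As a rough illustration of the GDN handoff above, a tuple-aware input projection could look like the sketch below. This is not the Qwen3.5 module code; `apply_fp8` and the argument names are hypothetical stand-ins, and only the split between the `(fp8, scale)` consumer and the bf16 consumer follows the description.

```python
# Hypothetical sketch of a tuple-aware GDN input projection (not the Qwen3.5 code).
def forward_input_proj(hidden_states, in_proj_qkvz, in_proj_ba):
    if isinstance(hidden_states, tuple):
        # Fused AR + RMSNorm + per-group quant already produced (bf16, fp8, scale).
        bf16, fp8, scale = hidden_states
        qkvz = in_proj_qkvz.apply_fp8(fp8, scale)  # hypothetical: skip internal act quant
        ba = in_proj_ba(bf16)                      # in_proj_ba still consumes bf16
    else:
        # Unfused path: hidden_states is a plain bf16 tensor.
        qkvz = in_proj_qkvz(hidden_states)
        ba = in_proj_ba(hidden_states)
    return qkvz, ba
```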
Modifications

- New distributed entry points for the fused all-reduce + RMSNorm + per-group quant: `tensor_model_parallel_fused_allreduce_rmsnorm_quant_per_group(...)`, backed by `GroupCoordinator.fused_allreduce_rmsnorm_quant_per_group(...)`.
- The helper returns `((fp8, scale), residual)` for standard attention and `((bf16, fp8, scale), residual)` for GDN, where bf16 is required by `in_proj_ba`.
- Update `LayerCommunicator.prepare_attn` to prefer the fused-quant helper when `_sglang_needs_allreduce_fusion` is set and the layer opts in with `enable_fused_ar_quant=True`.
- Standard attention opts in with `keep_bf16=False` and GDN with `keep_bf16=True`.
- GDN consumes the `(bf16, fp8, scale)` tuple without dequantizing fp8 back to bf16.
- `SGLANG_DISABLE_FUSED_AR_QUANT=1` acts as the operator opt-out for this quantized handoff while keeping the existing AR+RMSNorm fusion eligible.

Fallback order:

1. Single-kernel fused AR + RMSNorm + per-group FP8 quant.
2. Fused AR + RMSNorm, then a separate per-group quant kernel.
3. The existing unfused path.

The single-kernel dispatch is intentionally limited to ROCm + aiter + gfx95.
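A condensed sketch of the layer-side gate in `prepare_attn`, using the flag and helper names from this description. The surrounding control flow is simplified, and the placement of the `SGLANG_DISABLE_FUSED_AR_QUANT` check here is an assumption for illustration.

```python
import os

# Simplified sketch of the opt-in gate (not the actual LayerCommunicator code).
def prepare_attn_norm(layernorm, hidden_states, residual,
                      enable_fused_ar_quant, keep_bf16, use_aiter):
    fused_quant_ok = (
        enable_fused_ar_quant
        and use_aiter
        and os.environ.get("SGLANG_DISABLE_FUSED_AR_QUANT") != "1"
    )
    if fused_quant_ok:
        out = layernorm.forward_with_allreduce_fusion_quant_per_group(
            hidden_states, residual, keep_bf16=keep_bf16
        )
        if out is not None:  # ((fp8, scale) | (bf16, fp8, scale), residual)
            return out
    # Stage down to the existing fused AR + RMSNorm path.
    return layernorm.forward_with_allreduce_fusion(hidden_states, residual)
```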
Accuracy Tests
GSM8K was run on Qwen3.5-397B-A17B-FP8 with TP=8 and `--enable-aiter-allreduce-fusion`:

```
python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319 --parallel 1319 --num-shots 5 --port 9000
```

The fused path is above the 0.94 accuracy gate used for this model.
Speed Tests and Profiling
Server command
```
SGLANG_USE_AITER_UNIFIED_ATTN=1 \
SGLANG_USE_AITER=1 \
SGLANG_AITER_UNIFIED_VERIFY=1 \
python3 -m sglang.launch_server \
  --model-path /data2/Qwen/Qwen3.5-397B-A17B-FP8/ \
  --tp 8 \
  --attention-backend aiter \
  --trust-remote-code \
  --model-loader-extra-config '{"enable_multithread_load": true}' \
  --watchdog-timeout 1200 \
  --mem-fraction-static 0.9 \
  --host 0.0.0.0 \
  --port 9000 \
  --disable-radix-cache \
  --enable-aiter-allreduce-fusion \
  --max-running-requests 128 \
  --page-size 16
```

Serving benchmark command
End-to-end serving result, 8192 input / 8192 output, concurrency 16, 32 prompts, ignore-eos:
Improvement formula: `(candidate - baseline) / baseline * 100` for throughput-style metrics (higher is better) and `(baseline - candidate) / baseline * 100` for latency-style metrics (lower is better).
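A tiny illustration of the two conventions; the numbers below are made up, not the benchmark results reported here.

```python
# Illustrative only: how the two improvement conventions above are applied.
def improvement(candidate: float, baseline: float, higher_is_better: bool = True) -> float:
    if higher_is_better:  # e.g. output token throughput
        return (candidate - baseline) / baseline * 100.0
    return (baseline - candidate) / baseline * 100.0  # e.g. median TPOT

print(improvement(105.0, 100.0))                         # 5.0 (throughput-style)
print(improvement(95.0, 100.0, higher_is_better=False))  # 5.0 (latency-style)
```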
Kernel-count profile, per decode pass on Qwen3.5-397B-A17B-FP8 / TP=8 with 60 decoder layers, covering the kernels launched at each `prepare_attn` site:

- `cross_device_reduce_*stage` (plain AR)
- `add_rmsnorm_quant_kernel` (RMSNorm + add)
- `dynamic_per_group_scaled_quant_kernel`
- `allreduce_fusion_kernel_1stage` (AR + RMSNorm)
- `allreduce_fusion_kernel_1stage_per_group`

Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci