[Aiter][ROCm] RMSNormGated+GroupedQuantFP8 fusion#40710
Conversation
|
This pull request has merge conflicts that must be resolved before it can be |
There was a problem hiding this comment.
Code Review
This pull request introduces fusion support for RMSNormGated followed by FP8 group quantization on ROCm platforms using the aiter library. Key changes include the registration of a new fused custom operator, the implementation of a MatcherRMSNormGated class, and updates to the RocmAiterRMSNormQuantFusionPass to discover and fuse these patterns. Feedback focuses on critical safety issues regarding the global monkey-patching of the pattern matcher's type handling, which could lead to incorrect matches for other operators. Additionally, improvements were suggested to ensure the gated fusion pattern correctly supports both aiter and decomposed quantization variants and strictly validates the supported group size of 128 to prevent numerical errors.
5c39363 to
2c82404
Compare
2c82404 to
d4f1b17
Compare
31da8cb to
7b6683e
Compare
|
Some cleanup has been done and needs higher level feedback and a ready label to allow more complete testing. |
|
@gshtras Can you add the |
5753895 to
3307453
Compare
|
Hi @tpopp, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
fe5b13f to
141e8c5
Compare
The triton vs CK group quant op selection was added speculatively but the approved PR vllm-project#41825 uses only the aiter (CK) group quant op. Align the pattern matching with that decision. Signed-off-by: Tres Popp <tres.popp@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
141e8c5 to
380c7ce
Compare
Revert the fx_view_to_reshape rename since the function already exists upstream with the underscore prefix. Only apply the ignore_types monkey-patch for the pattern matcher when gated norm patterns are actually registered, avoiding interference with existing per-token and per-tensor fusion patterns. Signed-off-by: Tres Popp <tres.popp@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Tres Popp <tres.popp@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
|
This pull request has merge conflicts that must be resolved before it can be |
The gated RMSNorm + group FP8 quant pattern matches when the quant op traces through native code (-quant_fp8) rather than the custom op. Remove ops_in_model_before since the pre-fusion quant op depends on the custom_ops config. Signed-off-by: Tres Popp <tres.popp@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
4ecce31 to
20e4933
Compare
These parameters were unintentionally removed during earlier cleanup. They are needed by the existing non-gated pattern registration logic. Signed-off-by: Tres Popp <tres.popp@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
|
@tjtanaa Do you mind taking a look?
I have a plan for a vllm_ir related form of rms_norm_gated but would like to do that as a follow up, so it's precisely targeted, and so I can separately clarify some details over how the IR and pattern matching behaves when custom ops are enabled. |
|
I also haven't seen a better way to handle the Shape related data gathering. As far as I can tell, the reliance on constructing the same ops forces us to construct patterns with the exact constants derived, but pointers are welcomed if I'm wrong. |
Pass match_aiter_quant through to super().__init__ instead of creating a separate MatcherQuantFP8. The base class already creates the matcher with the correct quant key. Signed-off-by: Tres Popp <tres.popp@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
…-fusion Co-authored-by: Cursor <cursoragent@cursor.com> # Conflicts: # vllm/compilation/passes/fusion/matcher_utils.py # vllm/compilation/passes/fusion/rocm_aiter_fusion.py Signed-off-by: Tres Popp <tres.popp@amd.com>
ProExpertProg
left a comment
There was a problem hiding this comment.
Just a nit, and a follow-up
|
|
||
| # Apply gating before normalization if needed | ||
| if z is not None and not self.norm_before_gate: | ||
| if z is not None and not norm_before_gate: |
There was a problem hiding this comment.
Restore the comments please? Can be done in follow-up
| return result | ||
|
|
||
|
|
||
| class MatcherRMSNormGated(MatcherCustomOp): |
There was a problem hiding this comment.
Can we migrate this to vLLM IR once available (#38798)
There was a problem hiding this comment.
Absolutely. I'll keep an eye on that PR.
Signed-off-by: Tres Popp <tres.popp@amd.com> Signed-off-by: Tres Popp <trespopp@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com> Signed-off-by: Tres Popp <trespopp@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com> Signed-off-by: Tres Popp <trespopp@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com> Signed-off-by: Tres Popp <trespopp@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com> Signed-off-by: Tres Popp <trespopp@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com> Signed-off-by: Tres Popp <trespopp@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Liuweixiong0118 <lwx34158427@gmail.com>
Signed-off-by: Tres Popp <tres.popp@amd.com> Signed-off-by: Tres Popp <trespopp@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Signed-off-by: Tres Popp <tres.popp@amd.com> Signed-off-by: Tres Popp <trespopp@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com> Signed-off-by: Tres Popp <trespopp@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com>
This PR adds a compilation fusion pass (AiterRMSNormGatedFp8GroupQuantPattern) that fuses the decomposed RMSNormGated + reshape + group FP8 quantization sequence into a single AITER Triton kernel call (fused_rms_gated_fp8_group_quant). This pattern appears in GatedDeltaNetAttention layers (e.g., Qwen3-Next) where each attention head's output goes through gated RMS normalization, is reshaped back to the full hidden dimension, and then group-quantized to FP8 before the output projection linear layer.
Results:
a 9us set of 2 kernels can be combined to 4.5us. In the case of Qwen3Next, this can be a 1-3% improvement depending on how small the workload is (concurrency 1 vs 128).
Motivation
In models using GatedDeltaNetAttention (such as Qwen3-Next-80B-A3B-Instruct-FP8), the output path of each attention block performs:
These three operations decompose into many elementwise and reduction kernels when torch.compile lowers them. By matching this pattern in the FX graph and replacing it with a single fused Triton kernel from AITER, we eliminate multiple GPU kernel launches and intermediate memory traffic.
Changes
• Register rocm_aiter_fused_rms_gated_fp8_group_quant custom op wrapping aiter.ops.triton.quant.fused_rms_gated_fp8_group_quant
• Add rocm_aiter_ops.are_gdn_triton_kernels_available() — checks whether the required AITER Triton kernels (causal_conv1d_update_single_token, gated_delta_net) are importable, allowing graceful fallback on older AITER builds that lack the GDN kernels
• rocm_aiter_fusion.py: Add AiterRMSNormGatedFp8GroupQuantPattern that matches the decomposed norm→reshape→quant graph and replaces it with the fused op. Add _fold_consecutive_reshapes pre-processing pass (needed because make_fx faithfully
records chained reshapes that must be folded for the pattern to match). Dynamically infer num_heads/head_dim from GatedDeltaNetAttention layers via static_forward_context. Gate the pattern on are_gdn_triton_kernels_available()
• matcher_utils.py: Add MatcherRMSNormGated pattern tracer that traces RMSNormGated.forward_static for use in pm.register_replacement. Extend MatcherQuantFP8 to support Triton-based quant op matching
• layernorm.py: Extract RMSNormGated.forward_static as a @staticmethod so both forward_native and the matcher can share the same pure-PyTorch implementation. forward_native delegates to it
• test_fusion.py: Add unit tests (TestGatedModel) for the fusion pattern covering positive match cases (aiter quant, non-aiter quant, per-token dynamic) and negative cases (wrong group shape, per-tensor quant)
AITER Dependency
The fused Triton kernel (fused_rms_gated_fp8_group_quant) is provided by ROCm/aiter#2423 (https://github.com/ROCm/aiter/pull/2423) ("[Triton] optimized decode kernels for Qwen3-Next model"). The fusion pass is gated behind rocm_aiter_ops.are_gdn_triton_kernels_available(), so it is a no-op on AITER versions that do not include this PR.
Benchmark Results
Setup:
• Model: Qwen/Qwen3-Next-80B-A3B-Instruct-FP8, TP=1
• GPU: AMD MI355x (gfx950), single GPU
• Base image: vllm/vllm-openai-rocm:nightly (vLLM v0.19.2rc1) with AITER rebuilt from aiter:main + PR #2423
• Attention backend: ROCM_AITER_FA
• Compilation: cudagraph_mode=FULL_AND_PIECEWISE, custom_ops=["-rms_norm", "-silu_and_mul", "+quant_fp8"], pass_config={"fuse_norm_quant": true}
• Benchmark command: vllm bench serve --dataset_name random --random_input_len 1024 --random_output_len 1024 --max_concurrency 4 --num_prompts 32 --num_warmups 4 --seed 1 --temperature 0 --ignore_eos
Pattern matching verification:
• With fusion: RocmAiterRMSNormQuantFusionPass replaced 5 patterns (1+2+2 across repeated-layer subgraphs — the 4 additional matches are from AiterRMSNormGatedFp8GroupQuantPattern)
• Without fusion (pattern commented out): replaced 1 pattern (only the existing non-gated AiterRMSNormDynamicQuantPattern)
Throughput (ISL=1024, OSL=1024, concurrency=4):
┌─────────────────────────────────┬─────────────┬──────────┬───────┐
│ Metric │ With Fusion │ Baseline │ Delta │
├─────────────────────────────────┼─────────────┼──────────┼───────┤
│ Output token throughput (tok/s) │ 467.05 │ 456.52 │ +2.3% │
│ Total token throughput (tok/s) │ 934.11 │ 913.04 │ +2.3% │
│ Mean TPOT (ms) │ 8.44 │ 8.66 │ −2.5% │
│ P99 TPOT (ms) │ 8.67 │ 8.98 │ −3.5% │
│ Mean E2EL (ms) │ 8,769 │ 8,971 │ −2.3% │
└─────────────────────────────────┴─────────────┴──────────┴───────┘
Accuracy (lm_eval, gsm8k, 5-shot):
┌──────────────────┬────────────────┬────────────────┬─────────────────────────────┐
│ Filter │ With Fusion │ Baseline │ Delta │
├──────────────────┼────────────────┼────────────────┼─────────────────────────────┤
│ flexible-extract │ 0.8605 ±0.0095 │ 0.8506 ±0.0098 │ +0.0099 (within error bars) │
│ strict-match │ 0.8089 ±0.0108 │ 0.8097 ±0.0108 │ −0.0008 (within error bars) │
└──────────────────┴────────────────┴────────────────┴─────────────────────────────┘
Accuracy is statistically identical — the fusion is numerically safe.
Test plan
• [x] Unit tests: pytest tests/compile/passes/test_fusion.py -k "gated" — positive and negative pattern match cases
• [x] lm_eval --tasks gsm8k --num_fewshot 5 — accuracy unchanged vs. baseline
• [x] vllm bench serve — throughput improved ~2.3%, TPOT improved ~2.5%
• [x] Verified graceful no-op when AITER lacks GDN kernels (are_gdn_triton_kernels_available() == False)