
[PERF] Allreduce fusion. Support torch native matching. Tuning of the thresholds#24248

Merged
ProExpertProg merged 107 commits into vllm-project:main from neuralmagic:imarkov/fused_allreduce_torch_native
Nov 10, 2025

Conversation

@ilmarkov
Contributor

@ilmarkov ilmarkov commented Sep 4, 2025

First part of the fused-allreduce improvements.

Purpose

Adds tuning of the thresholds for the FlashInfer allreduce fusion.

Adds a benchmark for allreduce fusion to determine input-size thresholds for FlashInfer allreduce.
Updates the thresholds for FlashInfer allreduce (and adds two-shot algorithm usage where it performs better) on Hopper and Blackwell devices.

Moves allreduce out of the moe_forward custom op so that fusion can be matched for MoE models.

Test Plan

Added tests for non-custom-op fusion.
Added an e2e test for Qwen3 MoE.

Based on #24604
Second part: #24252 Introduce compile ranges

@mergify

mergify bot commented Sep 4, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request is a significant enhancement to the all-reduce fusion capabilities, adding support for matching native PyTorch operations in addition to custom ops. This greatly improves usability and performance flexibility. The introduction of a comprehensive benchmark for tuning fusion thresholds is also a valuable addition. The changes are extensive, particularly with the large number of new fusion patterns in vllm/compilation/collective_fusion.py. While the overall approach is sound, I've identified several critical issues in the implementation of these new patterns. Specifically, the return values from some pattern and replacement functions appear to be incorrect, which could lead to fusion failures or incorrect model outputs. I've provided detailed comments and suggestions for these issues. The configuration updates and the new benchmark script are well-implemented and welcome improvements.

Contributor


critical

The return values from the replacement function are incorrect. The pattern returns (rms_output, allreduce_output), which correspond to the normalized output and the all-reduced tensor. The replacement function should return the same structure.

auto_functionalized(flashinfer_trtllm_fused_allreduce_norm, ...) returns a tuple of 5 mutated arguments: (allreduce_in, residual, norm_out, quant_out, scale_out).

The rms_result corresponds to norm_out, which is allreduce[2].
The allreduce_in (which is input to the replacement function) corresponds to allreduce[0].

Therefore, the return statement should be return allreduce[2], allreduce[0].

The current code returns allreduce[3], allreduce[1], which corresponds to (quant_out, residual). This is incorrect and will lead to fusion failures or wrong results.

Suggested change:
-    return allreduce[3], allreduce[1]
+    return allreduce[2], allreduce[0]
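The tuple-index bookkeeping above can be made concrete with a small pure-Python sketch. The `MUTATED_ARGS` ordering and the `pick` helper below are illustrative stand-ins (placeholder strings instead of tensors), not vLLM code; they only encode the mutated-argument layout the comment describes.

```python
# Sketch of the mutated-argument ordering that
# auto_functionalized(flashinfer_trtllm_fused_allreduce_norm, ...) is said to
# return, per the review comment above (illustrative, not the actual vLLM code).
MUTATED_ARGS = ("allreduce_in", "residual", "norm_out", "quant_out", "scale_out")

def pick(allreduce: tuple, *names: str) -> tuple:
    """Select mutated outputs from the auto_functionalized result by name."""
    index = {name: i for i, name in enumerate(MUTATED_ARGS)}
    return tuple(allreduce[index[n]] for n in names)

# Placeholder values standing in for tensors:
allreduce = ("AR_IN", "RESIDUAL", "NORM_OUT", "QUANT_OUT", "SCALE_OUT")
# The pattern returns (rms_output, allreduce_output) -> (norm_out, allreduce_in),
# i.e. indices 2 and 0, not 3 and 1:
assert pick(allreduce, "norm_out", "allreduce_in") == ("NORM_OUT", "AR_IN")
```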

Contributor


critical

The return values from the replacement function are incorrect. The pattern returns (rms_output, rms_residual), which are the normalized output and the residual output. The replacement function should return the same structure.

When norm_out=None is passed to flashinfer_trtllm_fused_allreduce_norm, the allreduce_in tensor is used as the output buffer for the normalization result and is mutated. auto_functionalized will return a tuple where the first element (allreduce[0]) is the mutated allreduce_in (i.e., norm_out), and the second element (allreduce[1]) is the mutated residual.

Therefore, the correct return should be return allreduce[0], allreduce[1].

The current code returns allreduce[1], allreduce[2], which corresponds to (residual, norm_out). Since norm_out is None in the call, this is incorrect.

Suggested change:
-    return allreduce[1], allreduce[2]
+    return allreduce[0], allreduce[1]
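The in-place `norm_out=None` convention described above can be sketched with a toy simulation. `fused_allreduce_norm_sim` is a hypothetical stand-in (Python lists instead of tensors, negation standing in for normalization); it only demonstrates which mutated slots hold the outputs when `allreduce_in` doubles as the output buffer.

```python
# Toy stand-in for the norm_out=None convention: when no separate norm_out
# buffer is given, allreduce_in is reused as the output buffer, so mutated
# slot 0 already holds the normalized result and slot 1 the updated residual.
def fused_allreduce_norm_sim(allreduce_in: list, residual: list, norm_out=None):
    """'allreduce' = elementwise sum with residual; 'norm' = negate (toy)."""
    summed = [a + r for a, r in zip(allreduce_in, residual)]
    out = allreduce_in if norm_out is None else norm_out
    out[:] = [-s for s in summed]   # normalization result (in place)
    residual[:] = summed            # updated residual (in place)
    return allreduce_in, residual, norm_out

x, res = [1.0, 2.0], [0.5, 0.5]
mutated = fused_allreduce_norm_sim(x, res)  # norm_out=None
# mutated[0] is the normalized output, mutated[1] the new residual:
assert mutated[0] == [-1.5, -2.5] and mutated[1] == [1.5, 2.5]
```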

Contributor


Just curious: why is the threshold still so low for TP8? I think AR+Norm should have pretty good perf up to some larger message sizes for TP8?

Contributor


why is it 1MB for TP8?

Contributor Author


@nvpohanh Here are the results for TP=8 on Blackwell with torch symm mem enabled (VLLM_ALLREDUCE_USE_SYMM_MEM=1); see the set of results below. I used the best-performing alternative to fused allreduce as the baseline. We could condition the thresholds on whether symm mem is available and enabled, but in my opinion that would overcomplicate the configuration. Compared to the default allreduce, the flashinfer fused alternative is not significantly better in the 4-16 MB region (see the results at the end).

Symm mem enabled

World Size: 8
Hidden Dimension: 8192
Warmup Iterations: 5
Benchmark Trials: 20
Quantization Mode: none


Configuration: seq_len=32, dtype=bfloat16, no residual

Input Size: 0.50 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.029 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.030 0.99x
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.012 2.39x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.086 0.34x

Configuration: seq_len=64, dtype=bfloat16, no residual

Input Size: 1.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.030 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.030 0.99x
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.018 1.62x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.056 0.54x

Configuration: seq_len=128, dtype=bfloat16, no residual

Input Size: 2.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.023 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.024 0.99x
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.033 0.71x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.052 0.45x

Configuration: seq_len=256, dtype=bfloat16, no residual

Input Size: 4.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.031 0.97x
Standard Allreduce Rmsnorm Native Compiled 0.030 baseline
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.064 0.47x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.050 0.60x

Configuration: seq_len=256, dtype=bfloat16, no residual

Input Size: 4.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.031 0.97x
Standard Allreduce Rmsnorm Native Compiled 0.030 baseline
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.049 0.61x

Configuration: seq_len=512, dtype=bfloat16, no residual

Input Size: 8.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.044 0.98x
Standard Allreduce Rmsnorm Native Compiled 0.043 baseline
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.297 0.15x

Configuration: seq_len=1024, dtype=bfloat16, no residual

Input Size: 16.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.071 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.077 0.93x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.109 0.66x

Configuration: seq_len=2048, dtype=bfloat16, no residual

Input Size: 32.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.135 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.143 0.94x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.205 0.66x

Default allreduce

Configuration: seq_len=32, dtype=bfloat16, no residual

Input Size: 0.50 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.029 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.030 0.99x
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.012 2.44x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.087 0.34x

Configuration: seq_len=64, dtype=bfloat16, no residual

Input Size: 1.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.030 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.030 1.00x
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.019 1.63x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.056 0.54x

Configuration: seq_len=128, dtype=bfloat16, no residual

Input Size: 2.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.032 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.032 1.00x
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.033 0.97x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.052 0.62x

Configuration: seq_len=256, dtype=bfloat16, no residual

Input Size: 4.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.051 0.98x
Standard Allreduce Rmsnorm Native Compiled 0.050 baseline
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.064 0.77x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.050 1.00x

Configuration: seq_len=512, dtype=bfloat16, no residual

Input Size: 8.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.079 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.081 0.97x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.068 1.17x

Configuration: seq_len=1024, dtype=bfloat16, no residual

Input Size: 16.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.119 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.125 0.95x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.109 1.09x

Configuration: seq_len=2048, dtype=bfloat16, no residual

Input Size: 32.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.195 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.211 0.93x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.204 0.96x
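For reference, the Input Size values in the benchmark configurations above follow directly from seq_len × hidden_dim × bytes-per-element; a minimal sketch (assuming bfloat16 at 2 bytes per element, hidden dimension 8192 as in the runs above):

```python
def input_size_mb(seq_len: int, hidden_dim: int, dtype_bytes: int = 2) -> float:
    """Message size in MiB for one allreduce input (bfloat16 = 2 bytes)."""
    return seq_len * hidden_dim * dtype_bytes / 2**20

# Matches the benchmark configurations above (hidden_dim = 8192):
assert input_size_mb(32, 8192) == 0.5
assert input_size_mb(64, 8192) == 1.0
assert input_size_mb(2048, 8192) == 32.0
```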

Contributor


@ilmarkov Is VLLM_ALLREDUCE_USE_SYMM_MEM=1 something that normal vLLM users would set by default? If it's good for performance, why can't we enable it by default? Does it require special environment or special builds? cc @ProExpertProg

@nvjullin Could you check if @ilmarkov 's measurements above match our understanding? Also, could you try if VLLM_ALLREDUCE_USE_SYMM_MEM=1 works in our case? Thanks!

Contributor Author

@ilmarkov ilmarkov Sep 5, 2025


Yes, it can be enabled by default. There is a PR for it. It works on Hopper and Blackwell.

Contributor


Got it! we will try both your PRs and run some experiments on our side.

Contributor


@ilmarkov Just to clarify: the PyTorch SYMM_MEM implementation does not support AR+Norm fusion, right? So only the AR part uses SYMM_MEM while Norm part is based on native PyT?

Contributor Author


Yes, symm mem is only used for the allreduce part; the norm and quantization parts are in native PyTorch.

@nvpohanh
Contributor

nvpohanh commented Sep 5, 2025

cc @nvjullin @elvischenv for vis

@ilmarkov ilmarkov force-pushed the imarkov/fused_allreduce_torch_native branch from e808818 to 61ebc95 on September 8, 2025 12:02
@mergify mergify bot removed the needs-rebase label Sep 8, 2025
@mergify

mergify bot commented Sep 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 10, 2025
@nvpohanh
Contributor

nvpohanh commented Oct 9, 2025

Hi @ilmarkov , is there any progress and ETA for this change? Thanks!

@ilmarkov ilmarkov marked this pull request as draft October 9, 2025 15:03
@ilmarkov
Contributor Author

ilmarkov commented Oct 9, 2025

Hi, @nvpohanh. @ProExpertProg is working on general custom-op matching in #24604, so we will apply the allreduce-related pattern matching after his PR lands. I'm marking the current PR as a draft for now.

Collaborator

@ProExpertProg ProExpertProg left a comment


LGTM, a few minor notes and then we can merge!


@staticmethod
def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]:
    from vllm.compilation.collective_fusion import FI_ALLREDUCE_FUSION_MAX_SIZE_MB
Collaborator


@ilmarkov if this is still blocking, we can just move this computation into collective_fusion.py. We can always move it back here later. I am also concerned that the head process (which shouldn't initialize CUDA) might initialize CUDA when querying the device capability during device config (not sure if that happens in the head process or just the workers).

But if things are working, feel free to leave it as is.

Signed-off-by: ilmarkov <markovilya197@gmail.com>
@ilmarkov
Contributor Author

We need to move dispatch and combine back under the custom op, as they conflict with torch.compile. In this PR we only need to move the (main experts) allreduce outside of the custom op.

Signed-off-by: ilmarkov <markovilya197@gmail.com>
@ilmarkov
Contributor Author

ilmarkov commented Nov 10, 2025

Validation.

vllm serve deepseek-ai/DeepSeek-V2-Lite --disable-log-requests --no-enable-prefix-caching -tp ${tp} -dp ${dp} --max-num-seqs 256 ${enable_expert_parallel} --port 8000 --compilation-config '{"pass_config":{"enable_fusion":false,"enable_attn_fusion":false,"enable_noop":true,"enable_sequence_parallelism":false,"enable_async_tp":false,"enable_fi_allreduce_fusion":true,"fi_allreduce_fusion_max_size_mb":1}}'
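The fi_allreduce_fusion_max_size_mb=1 setting above translates into a per-model token threshold. A hedged sketch (the helper name and the hidden size are illustrative, not vLLM's actual code):

```python
def max_token_num(max_size_mb: float, hidden_dim: int, dtype_bytes: int = 2) -> int:
    """Largest num_tokens whose allreduce input fits under the MiB cap."""
    return int(max_size_mb * 2**20) // (hidden_dim * dtype_bytes)

# With a 1 MiB cap and a hypothetical hidden_dim of 8192 in bfloat16,
# the fused path would apply up to 64 tokens:
assert max_token_num(1, 8192) == 64
```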

TP=4

(model=deepseek-ai/DeepSeek-V2-Lite,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=50,max_retries=3,tokenized_requests=False), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3950|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.3927|±  |0.0135|

(DP+EP)=4

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3806|±  |0.0134|
|     |       |strict-match    |     5|exact_match|↑  |0.3783|±  |0.0134|

(TP+EP)=4

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3798|±  |0.0134|
|     |       |strict-match    |     5|exact_match|↑  |0.3776|±  |0.0134|

Signed-off-by: ilmarkov <markovilya197@gmail.com>
@ProExpertProg ProExpertProg merged commit d17ecc6 into vllm-project:main Nov 10, 2025
54 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in torch.compile integration Nov 10, 2025
@nvpohanh
Contributor

nvpohanh commented Nov 11, 2025

@ilmarkov Thanks for merging this! Just want to check: have we re-enabled symmetric memory communication by default? It was disabled in #26925 . Thanks!

Update: never mind. I just saw: #27671

Collaborator

@zou3519 zou3519 left a comment


looks reasonable to me

)
if use_flashinfer:
    if num_tokens <= max_token_num:
Collaborator


@ilmarkov @ProExpertProg

@laithsakka and I went over this a bit more offline, and we're a bit worried this line might cause unintended specialization. Do you have instructions on how to trigger this line of code? (And if so, could you provide a tlparse of it? We want to check the symbolic shape constraints in the tlparse to see whether this introduced anything negative.)

Collaborator


@ilmarkov could you take a look? tlparse instructions here: https://docs.vllm.ai/en/latest/design/debug_vllm_compile

Collaborator


Btw, @zou3519 how do you want us to send you tlparse results, in an archive?

Collaborator

@zou3519 zou3519 Nov 19, 2025


archive would work, maybe need a better way to share these...

Contributor Author

@ilmarkov ilmarkov Nov 19, 2025


Here are the tlparse results. But I'm not sure you will see the specialization on this line, given that it is inside the custom op flashinfer_trtllm_fused_allreduce_norm.
tl_out_vllm_fi_ar.tar.gz

khluu pushed a commit that referenced this pull request Nov 17, 2025
… thresholds (#24248)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
(cherry picked from commit d17ecc6)
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
… thresholds (vllm-project#24248)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
@ilmarkov ilmarkov deleted the imarkov/fused_allreduce_torch_native branch December 15, 2025 13:25

Labels

ci/build performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed torch.compile

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

7 participants