[torch.compile] Enable attention and allreduce fusion without custom ops enabled #24604
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
…g utils, fix DCE bug (vllm-project#23091), fix test (vllm-project#24376), and prep for custom op matching (vllm-project#24604) (vllm-project#24542)
```diff
-STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
+if hasattr(torch.ops._C, "scaled_fp4_quant"):
+    STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
```
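As a hedged sketch of why this guard works (the helper name is illustrative, not vLLM's API): `torch.ops` namespaces resolve op names lazily and raise `AttributeError` when the underlying extension never registered the op, so `hasattr` is a safe probe for optionally-built custom ops.

```python
import torch

def lookup_optional_op(namespace, name):
    """Return the .default overload of an op, or None if it is not registered."""
    # hasattr() returns False when the extension that registers the op
    # (e.g. vLLM's _C library with FP4 support) was not built.
    if hasattr(namespace, name):
        return getattr(namespace, name).default
    return None

# aten ops are always registered; a made-up name is not.
add_op = lookup_optional_op(torch.ops.aten, "add")
missing = lookup_optional_op(torch.ops.aten, "surely_not_a_real_op")
```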
Why did this work before and not now? Should we change how this is registered?
Oh yeah I think the registration might have been fixed.
I think I'll punt this to a follow-up PR; in general, these ops should be cleaned up.
```diff
     )
 else:
-    scale = torch.empty(1, device=input.device, dtype=torch.float32)
+    scale = torch.empty((1, 1), device=input.device, dtype=torch.float32)
```
Why is this needed? AFAIK this tensor is just a scalar to the kernel
Needed for custom matching to work; (1, 1) is still just one element.
Can you elaborate? It's suspicious that it needs this change.
The native implementation returns (1, 1), so this just makes them consistent. I don't remember exactly what I was running into.
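A hedged sketch of the consistency point (the helper is illustrative, not vLLM's actual code): pattern matching compares the traced graph of the custom-op path against the native path, so allocating the scale buffer as (1, 1) mirrors the native implementation's output shape and avoids spurious view/reshape nodes in one of the two graphs. Either way the buffer holds a single element.

```python
import torch

def make_scale(input: torch.Tensor) -> torch.Tensor:
    # (1, 1) instead of (1,): same single element, but shape-identical to
    # what the native implementation produces, so traced graphs line up.
    return torch.empty((1, 1), device=input.device, dtype=torch.float32)

scale = make_scale(torch.randn(4, 8))
```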
zou3519 left a comment:
My current understanding is that when we pattern match against the torch native implementation of a custom operator, we register a pattern in Inductor using that native implementation. I'm worried that this approach might be fragile. When the torch native implementation is passed through torch.compile, various graph passes can transform it, so by the time it reaches the post-grad phase (where vLLM’s pattern matching currently happens), the structure may look different.
For example, with rms_norm, it seems we’d need to modify the implementation in a non-trivial way to make it pattern match. I don't know if this is an issue in practice, but it suggests that this scheme could unintentionally constrain how custom operators need to be authored — in ways we might not fully understand yet.
It might be more robust to preserve the custom operator as-is (i.e., avoid decomposing it into torch native ops) and then perform pattern matching directly on the custom operator itself. That would make the process less sensitive to internal graph transformations.
I did see that you wanted this in for the release. Was there a specific reason? If we are turning on the allreduce+rmsnorm fusion by default, for example, then could the fusion instead imply "+rmsnorm"?
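The decomposition concern can be illustrated with a short trace (a sketch, not vLLM's actual rms_norm): a native RMSNorm written in plain torch traces to several small nodes (pow, mean, rsqrt, mul) rather than one opaque op, and later graph passes are free to rewrite those nodes before post-grad pattern matching runs.

```python
import torch
import torch.fx

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # A typical native RMSNorm formulation used here for illustration.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

# The traced graph contains pow/mean/rsqrt/mul nodes, not a single rms_norm op.
gm = torch.fx.symbolic_trace(rms_norm)
targets = {n.target for n in gm.graph.nodes}
```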
The reason this is needed is that it lets us do fusion without having to enable custom ops (`-O.custom_ops=["+quant_fp8"]`). Enabling custom ops leads to lost performance, as demonstrated in the PR description; that's because there are 4 quant ops per layer, one per matmul. I agree this is a somewhat fragile approach. I would be happy to work on a "lowering" approach where we preserve the high-level structure of ops until later. The downside is that it would require more work (I think), and we might lose access to optimizations that currently happen before our passes. But I think it wouldn't hurt Inductor in general to have a more explicit sense of converting between higher-level and lower-level representations (or we could just move where our custom passes happen). We can tie this work into "autotuning custom op implementations" as done in pytorch/pytorch#164212.
As discussed offline, we are going to proceed by merging this PR. After PTC, we will move our custom op matching passes to
View/slice noop eliminations were upstreamed to PyTorch, so I'm wondering if this is sufficient: pytorch/pytorch#151095, pytorch/pytorch#151175
@BoyuanFeng wouldn't that run after
Purpose
This PR enables matching the torch implementations of the custom ops QuantFP8 and RMSNorm. On main, fusion currently requires enabling custom ops, but they are slower than their torch counterparts, so the benefit of the custom fusion passes is reduced. We add a number of "matcher util" objects which can be called in patterns and automatically get traced to the same fx nodes as the custom op they correspond to, in both enabled and disabled form.
This PR also adds additional debugging utilities and adds E2E fusion tests to verify fusions happen in models end-to-end instead of just in unit tests.
Test Plan
Unit tests, added more fusion E2E tests.
Test Result
All tests pass.
Performance numbers
Below are B200 numbers (with FlashInfer) from `vllm bench serve`. We test the following regimes with corresponding additional arguments:

1. `none`: `-O.custom_ops='["none"]' -O.pass_config={"enable_fi_allreduce_fusion":false,"enable_attn_fusion":false,"enable_noop":true}`
2. `none_fusion_attention`: `-O.custom_ops='["none"]' -O.pass_config={"enable_fi_allreduce_fusion":false,"enable_attn_fusion":true,"enable_noop":true}`
3. `none_fusion_attention_allreduce`: `-O.custom_ops='["none"]' -O.pass_config={"enable_fi_allreduce_fusion":true,"enable_attn_fusion":true,"enable_noop":true}`
4. `rms_quant`: `-O.custom_ops='["none", "+quant_fp8", "+rms_norm"]' -O.pass_config={"enable_fi_allreduce_fusion":false,"enable_attn_fusion":false,"enable_noop":true}`
5. `rms_quant_fusion_attention`: `-O.custom_ops='["none", "+quant_fp8", "+rms_norm"]' -O.pass_config={"enable_fi_allreduce_fusion":false,"enable_attn_fusion":true,"enable_noop":true}`
6. `rms_quant_fusion_attention_allreduce`: `-O.custom_ops='["none", "+quant_fp8", "+rms_norm"]' -O.pass_config={"enable_fi_allreduce_fusion":true,"enable_attn_fusion":true,"enable_noop":true}`

Regimes 2 (`none_fusion_attention`) and 3 (`none_fusion_attention_allreduce`) are newly possible with this PR. On main, results are similar except those two are worse, as fusion cannot happen without custom ops enabled.

redhatai/meta-llama-3.1-70B-Instruct-FP8 (TP=1): Past QPS=10 the server is overloaded, so the latency spikes and becomes much more variable. Also note that allreduce fusion is a no-op for TP=1.
📊 TTFT Median (ms)
📊 TPOT Median (ms)
📊 ITL Median (ms)
redhatai/meta-llama-3.1-70B-Instruct-FP8 (TP=4): Note that allreduce fusion reduces TPOT at low QPS but increases it at high QPS, and increases TTFT across the board; this will be addressed in #24248 and #24252.
📊 TTFT Median (ms)
📊 TPOT Median (ms)
📊 ITL Median (ms)