Skip to content

[Performance] Fuse RoPE + KV cache update for MLA backends#35879

Closed
ElizaWszola wants to merge 34 commits into
vllm-project:mainfrom
neuralmagic:fuse-mla-rope-kv-update
Closed

[Performance] Fuse RoPE + KV cache update for MLA backends#35879
ElizaWszola wants to merge 34 commits into
vllm-project:mainfrom
neuralmagic:fuse-mla-rope-kv-update

Conversation

@ElizaWszola

@ElizaWszola ElizaWszola commented Mar 3, 2026

Copy link
Copy Markdown
Contributor

Follow-up to #32335. Fuse RoPE and MLA KV-cache update into concat_and_cache_mla_rope_fused kernel.

Eval (B200)

lm-eval --model vllm --model_args '{"pretrained": "deepseek-ai/DeepSeek-V2-Lite", "compilation_config": {"use_inductor_graph_partition": "True", "custom_ops": ["+rotary_embedding"]}}' --tasks gsm8k --batch_size auto

with fusion:

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.3768 ± 0.0133
strict-match 5 exact_match 0.3738 ± 0.0133

no fusion:

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.3813 ± 0.0134
strict-match 5 exact_match 0.3783 ± 0.0134

Perf (H100, RoPE FlashInfer)

Obtained with

compilation_config=CompilationConfig(
            use_inductor_graph_partition=True,
            custom_ops=["+rotary_embedding"],
)
input-len prefix-len output-len rr main_ttft pr_ttft main_tpot pr_tpot
256 128 128 1 24.97 24.82 4.90 4.88
512 256 256 1 30.35 30.80 5.71 5.72
1024 512 512 1 88.95 88.55 8.54 8.33
256 128 128 2.5 23.70 23.93 5.99 6.01
512 256 256 2.5 29.02 30.08 8.28 7.91
1024 512 512 2.5 40.18 38.21 11.58 11.13
256 128 128 5 27.62 27.75 7.95 7.59
512 256 256 5 35.11 34.49 10.38 10.07
1024 512 512 5 43.57 41.87 12.79 12.57
256 128 128 10 35.59 35.14 10.43 10.41
512 256 256 10 50.82 46.71 14.30 14.07
1024 512 512 10 57.86 51.05 16.17 15.32
256 128 128 25 43.89 42.95 13.89 13.52
512 256 256 25 85.78 74.02 23.70 21.88
1024 512 512 25 111.77 90.61 40.70 35.09
256 128 128 50 54.36 52.29 19.72 18.91
512 256 256 50 139.89 117.83 47.51 41.59
1024 512 512 50 183.21 156.82 54.22 45.50

Perf (H100, Deepseek Scaling RoPE)

Obtained with

compilation_config=CompilationConfig(
            use_inductor_graph_partition=True,
)
input-len prefix-len output-len rr main_ttft pr_ttft main_tpot pr_tpot
256 128 128 1 23.71 26.19 4.83 4.91
512 256 256 1 29.06 43.61 5.75 5.77
1024 512 512 1 87.51 85.85 8.33 8.75
256 128 128 2.5 22.86 22.89 5.98 5.83
512 256 256 2.5 27.37 28.61 7.76 7.76
1024 512 512 2.5 37.85 37.74 11.23 10.91
256 128 128 5 26.99 26.44 7.58 7.78
512 256 256 5 33.02 33.76 10.46 10.41
1024 512 512 5 40.59 40.02 12.37 12.28
256 128 128 10 34.86 37.12 10.23 10.72
512 256 256 10 46.38 47.56 14.00 14.04
1024 512 512 10 50.35 50.93 15.22 15.41
256 128 128 25 43.02 45.91 13.84 14.72
512 256 256 25 71.63 79.39 22.51 30.69
1024 512 512 25 98.38 109.80 37.21 48.59
256 128 128 50 53.34 57.97 20.26 26.65
512 256 256 50 118.50 134.55 42.78 45.82
1024 512 512 50 155.14 154.33 46.64 48.71

(The perf results could be improved for Deepseek Scaling RoPE)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
@mergify mergify Bot added the v1 label Mar 3, 2026
@mergify

mergify Bot commented Mar 3, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ElizaWszola.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 3, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the RoPE and KV cache update logic for MLA backends to improve performance with torch.compile. The changes introduce a new custom op unified_mla_kv_cache_update and move the cache update logic into a do_kv_cache_update method on the attention implementation classes. This is a good approach to manage side effects and dependencies for torch.compile. However, I've found a critical issue where a None value for slot_mapping is not handled, which could lead to a runtime crash. Please see my comments for details and suggested fixes.

Comment on lines +814 to +834
def do_kv_cache_update(
self,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
) -> None:
if kv_cache.numel() == 0:
return
from vllm import _custom_ops as ops

ops.concat_and_cache_mla(
kv_c_normed,
k_pe.squeeze(1),
kv_cache,
slot_mapping.flatten(),
kv_cache_dtype=kv_cache_dtype,
scale=k_scale,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The slot_mapping parameter can be None when slot_mapping.get(self.layer_name) in MLAAttention.forward returns None. This would cause an AttributeError: 'NoneType' object has no attribute 'flatten' when slot_mapping.flatten() is called. You should add a check to handle the case where slot_mapping is None.

Also, the type hint for slot_mapping should be updated to torch.Tensor | None to reflect this.

    def do_kv_cache_update(
        self,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor | None,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
    ) -> None:
        if kv_cache.numel() == 0 or slot_mapping is None:
            return
        from vllm import _custom_ops as ops

        ops.concat_and_cache_mla(
            kv_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            slot_mapping.flatten(),
            kv_cache_dtype=kv_cache_dtype,
            scale=k_scale,
        )

Comment on lines +881 to +901
def do_kv_cache_update(
self,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
) -> None:
if kv_cache.numel() == 0:
return
from vllm import _custom_ops as ops

ops.concat_and_cache_mla(
kv_c_normed,
k_pe.squeeze(1),
kv_cache,
slot_mapping.flatten(),
kv_cache_dtype=kv_cache_dtype,
scale=k_scale,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Similar to the do_kv_cache_update in MLAAttentionImpl, slot_mapping can be None. This will cause a crash. Please add a None check for slot_mapping and update its type hint.

    def do_kv_cache_update(
        self,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor | None,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
    ) -> None:
        if kv_cache.numel() == 0 or slot_mapping is None:
            return
        from vllm import _custom_ops as ops

        ops.concat_and_cache_mla(
            kv_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            slot_mapping.flatten(),
            kv_cache_dtype=kv_cache_dtype,
            scale=k_scale,
        )

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
@mergify mergify Bot removed the needs-rebase label Mar 11, 2026
…revisit this

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
@mergify mergify Bot removed the needs-rebase label Apr 13, 2026
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
@rbrugaro-amd

Copy link
Copy Markdown
Contributor

@ElizaWszola , create PR neuralmagic#163 against yours that adds one more fusion to eliminate the rematerialization of k_pe

before (PR35789)

image

after (this PR)

image

Accuracy is passing and currently collecting E2E perf on Kimi-k2-FP4
cc: @Rohan138 @attila-dusnoki-htec

kv_cache_scale: torch.Tensor,
layer_name: LayerNameType,
) -> torch.Tensor:
forward_context = get_forward_context()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion:

from vllm.model_executor.layers.attention.attention import get_attention_context

...
    layer_name = _resolve_layer_name(layer_name)
    attn_metadata, _, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
    if layer_slot_mapping is not None:
        ops.concat_and_cache_mla_rope_fused(
            positions,
            q_pe,
            k_pe,
            kv_c,
            cos_sin_cache,
            is_neox,
            layer_slot_mapping,
            kv_cache,
            kv_cache_dtype,
            kv_cache_scale,
        )
    return torch.empty(0, device=kv_c.device, dtype=kv_c.dtype)

Slightly more concise, plus you can remove the has_slot_mapping check from the kernel itself

)


class KVCacheMLARoPEFusionPattern:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two classes (KVCacheMLARoPEFusionPattern and KVCacheMLARoPEDeepseekScalingFusionPattern) can be merged with a single use_deepseek_scaling: bool arg, pending additional cleanup in #39488:

        if use_deepseek_scaling:
            self.rope_matcher = MatcherDeepseekScalingRotaryEmbedding(
                is_neox=self.is_neox,
                head_size=self.qk_rope_head_dim,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
            )
        else:
            self.rope_matcher = MatcherRotaryEmbedding(  # type: ignore
                is_neox=self.is_neox,
                head_size=self.qk_rope_head_dim,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                use_flashinfer=self.use_flashinfer,
            )

cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
if is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the

@Rohan138 Rohan138 Apr 24, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bugfix: I believe this is a leftover comment from vLLM v0. in V1 this should always be positions.shape == (T,). This should just be:

        if is_neox_style:
            cos = cos.repeat(1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 2).unsqueeze(-2)

Without this change, adding a unit test with neox_style=True fails with shape errors.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const int64_t block_idx = slot_idx / block_size;
const int64_t entry_idx = slot_idx % block_size;

// NOTE: slot_idx can be -1 if the token is padded

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can probably move this early return up before the Q RoPE, if this is a padding token due to V1 cudagraphs

return [torch.ops.vllm.fused_concat_and_cache_mla_rope.default]


@pytest.mark.parametrize(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For ROCm support, can we add:

if current_platform.is_cuda():
    MLA_BACKENDS = [
        AttentionBackendEnum.FLASH_ATTN_MLA,
        AttentionBackendEnum.FLASHINFER_MLA,
    ]
elif is_aiter_found_and_supported():
    MLA_BACKENDS = [
        AttentionBackendEnum.TRITON_MLA,
        AttentionBackendEnum.ROCM_AITER_MLA,
    ]
else:
    MLA_BACKENDS = []


@pytest.mark.parametrize("attn_backend", MLA_BACKENDS)

Also, this test should probably be in tests/compile/passes not tests/kernels/core

"Enabled is only supported with use_deepseek_scaling_rope (and vice versa)"
)

if is_neox and use_deepseek_scaling_rope:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comments about merging the two patterns and the neox_style fix in deepseek_scaling_rope.py

fusion_pass = KVCacheMLARoPEFusionPass(vllm_config)
passes = [
NoOpEliminationPass(vllm_config),
SplitCoalescingPass(vllm_config),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't need SplitCoalescing + ScatterSplitReplacement here, those are only for MHA RoPE model graphs; we don't see those patterns in MLA

)


class KVCacheMLARoPEDeepseekScalingFusionPattern:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: we could use the new VllmPatternReplacement + VllmPatternMatcherPass here: https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/vllm_inductor_pass.py

Comment thread vllm/config/vllm.py
def enable_rope_kvcache_mla_fusion(cfg: "VllmConfig") -> bool:
"""Enable if use_inductor_graph_partition is enabled."""

return cfg.compilation_config.use_inductor_graph_partition

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion, same as enable_rope_kvcache_fusion:

    return (
        cfg.compilation_config.use_inductor_graph_partition
        or not cfg.compilation_config.splitting_ops_contain_kv_cache_update()
    )

@Rohan138

Rohan138 commented Apr 24, 2026

Copy link
Copy Markdown
Contributor

@ElizaWszola @ProExpertProg I added some comments with the changes from my branch: #40392

Also, one question: Can we use functional torch.aten ops for pattern matching? This PR uses the torch.aten ops for copy, slice_scatter, etc. in the pattern matching, while I was trying to use a pattern like this:

        def _pattern(
            q: torch.Tensor,
            k_pe: torch.Tensor,
            kv_c_normed: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
            k_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            q = q.view(-1, self.num_heads, self.qk_head_dim)
            k_pe_unsqueezed = k_pe.unsqueeze(1)
            q[..., self.qk_nope_head_dim :], k_pe = self.rope_matcher(
                positions,
                q[..., self.qk_nope_head_dim :],
                k_pe_unsqueezed,
                cos_sin_cache,
            )
            dummy = torch.ops.vllm.unified_mla_kv_cache_update(
                kv_c_normed, k_pe, _ln, self.kv_cache_dtype, k_scale
            )
            return dummy, q, k_pe

but this does not work as is and needs additional complexity/preprocessing on the q_slice_scatter op due to the extra slice+view+copy: https://github.com/neuralmagic/vllm/blob/81f2ce414dd41239296f81db0b92200c92456f94/vllm/compilation/passes/fusion/kv_cache_mla_rope_fusion.py#L330

@ProExpertProg

Copy link
Copy Markdown
Collaborator

@Rohan138 I'm confused about your question: are you saying we shouldn't be using aten ops? Or are you saying regular torch ops don't work? Or both?

@vllm-project vllm-project deleted a comment from aumghelani Apr 25, 2026
@vllm-project vllm-project deleted a comment from aumghelani Apr 25, 2026
@Rohan138

Rohan138 commented Apr 28, 2026

Copy link
Copy Markdown
Contributor

@Rohan138 I'm confused about your question: are you saying we shouldn't be using aten ops? Or are you saying regular torch ops don't work? Or both?

Sorry, missed your comment; but yes to both. I realized why I wasn't sure about matching the entire aten sequence in the pattern here-currently, the copy+slice_scatter is introduced during AOTAutograd functionalization, but the replacement completely removes the functionalized write into q. So the graph after the fusion but before defunctionalization doesn't actually write the q_pe back into the base q tensor at all; which I'm not sure if we're fine with?

I got a slightly different torch pattern to work in #40392, with a custom defunctionalization pass:

Pattern:
...
        def _pattern(
            q_pe: torch.Tensor,
            k_pe: torch.Tensor,
            kv_c_normed: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
            k_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            k_pe_unsqueezed = k_pe.unsqueeze(1)
            q_pe, k_pe = self.rope_matcher(
                positions, q_pe, k_pe_unsqueezed, cos_sin_cache
            )
            dummy = torch.ops.vllm.unified_mla_kv_cache_update(
                kv_c_normed, k_pe, _ln, self.kv_cache_dtype, k_scale
            )
            return dummy, q_pe, k_pe

        def _replacement(
            q_pe: torch.Tensor,
            k_pe: torch.Tensor,
            kv_c_normed: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
            k_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            at = auto_functionalized(
                self.FUSED_OP,
                positions=positions,
                q_pe=q_pe,
                k_pe=k_pe,
                kv_c=kv_c_normed,
                cos_sin_cache=cos_sin_cache,
                is_neox=self.is_neox,
                kv_cache_dtype=self.kv_cache_dtype,
                kv_cache_scale=k_scale,
                layer_name=_ln,
            )
            dummy, q_pe, k_pe_squeezed = at
            k_pe = k_pe_squeezed.unsqueeze(1)
            return dummy, q_pe, k_pe
...

Defunctionalization:

...
            elif (
                hasattr(torch.ops.vllm, "fused_rope_unified_mla_kv_cache_update")
                and at_target
                == torch.ops.vllm.fused_rope_unified_mla_kv_cache_update.default
            ):
                # AOTAutograd functionalizes `q[..., nope_dim:] = rope_result` into
                # a sequence of aten ops on q: view+slice+copy+slice_scatter.
                # Since the fused MLA RoPE op mutates q_pe in-place, we can remove
                # the redundant copy and slice_scatter ops during defunctionalization.
                getitem_nodes = self.getitem_users(node)
                q_pe_out = getitem_nodes[1]

                for user in list(q_pe_out.users):
                    if is_func(user, torch.ops.aten.copy.default):
                        copy_temp = user
                slice_temp = user.args[0]
                for user in list(copy_temp.users):
                    if is_func(user, torch.ops.aten.slice_scatter.default):
                        slice_scatter_temp = user
                view_temp = slice_scatter_temp.args[0]

                view_orig = slice_temp.args[0]
                slice_scatter_temp.replace_all_uses_with(view_orig)
                self._remove(slice_scatter_temp)
                self._remove(copy_temp)
                self._remove(slice_temp)
                self._remove(view_temp)
                self._remove(q_pe_out)

                # defunctionalize k_pe manually; self.replace_users_with_mutated_args
                # does not support only replacing specific kwargs
                k_pe_in = node.kwargs["k_pe"]
                k_pe_out = getitem_nodes[2]
                k_pe_out.replace_all_uses_with(k_pe_in)
                self._remove(k_pe_out)

                self.insert_defunctionalized(graph, node)
                self._remove(node)
...

@ProExpertProg

Copy link
Copy Markdown
Collaborator

@Rohan138 I think that's actually more correct, I know defunctionalization is complex for rope with the views and slice scatters so it makes sense it would be here as well.

@ElizaWszola

Copy link
Copy Markdown
Contributor Author

Implemented in #40392

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants