[Performance] Fuse RoPE + KV cache update for MLA backends #35879
ElizaWszola wants to merge 34 commits
Signed-off-by: ElizaWszola <ewszola@redhat.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request refactors the RoPE and KV cache update logic for MLA backends to improve performance with torch.compile. The changes introduce a new custom op unified_mla_kv_cache_update and move the cache update logic into a do_kv_cache_update method on the attention implementation classes. This is a good approach to manage side effects and dependencies for torch.compile. However, I've found a critical issue where a None value for slot_mapping is not handled, which could lead to a runtime crash. Please see my comments for details and suggested fixes.
def do_kv_cache_update(
    self,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
) -> None:
    if kv_cache.numel() == 0:
        return
    from vllm import _custom_ops as ops

    ops.concat_and_cache_mla(
        kv_c_normed,
        k_pe.squeeze(1),
        kv_cache,
        slot_mapping.flatten(),
        kv_cache_dtype=kv_cache_dtype,
        scale=k_scale,
    )
The slot_mapping parameter can be None when slot_mapping.get(self.layer_name) in MLAAttention.forward returns None. This would cause an AttributeError: 'NoneType' object has no attribute 'flatten' when slot_mapping.flatten() is called. You should add a check to handle the case where slot_mapping is None.
Also, the type hint for slot_mapping should be updated to torch.Tensor | None to reflect this.
def do_kv_cache_update(
    self,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor | None,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
) -> None:
    if kv_cache.numel() == 0 or slot_mapping is None:
        return
    from vllm import _custom_ops as ops
    ops.concat_and_cache_mla(
        kv_c_normed,
        k_pe.squeeze(1),
        kv_cache,
        slot_mapping.flatten(),
        kv_cache_dtype=kv_cache_dtype,
        scale=k_scale,
    )

def do_kv_cache_update(
    self,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
) -> None:
    if kv_cache.numel() == 0:
        return
    from vllm import _custom_ops as ops

    ops.concat_and_cache_mla(
        kv_c_normed,
        k_pe.squeeze(1),
        kv_cache,
        slot_mapping.flatten(),
        kv_cache_dtype=kv_cache_dtype,
        scale=k_scale,
    )
Similar to the do_kv_cache_update in MLAAttentionImpl, slot_mapping can be None. This will cause a crash. Please add a None check for slot_mapping and update its type hint.
def do_kv_cache_update(
    self,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor | None,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
) -> None:
    if kv_cache.numel() == 0 or slot_mapping is None:
        return
    from vllm import _custom_ops as ops
    ops.concat_and_cache_mla(
        kv_c_normed,
        k_pe.squeeze(1),
        kv_cache,
        slot_mapping.flatten(),
        kv_cache_dtype=kv_cache_dtype,
        scale=k_scale,
    )
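The effect of the suggested guard can be sketched without vLLM or torch. Everything named here is a hypothetical stand-in (`FakeSlotMapping`, `write_cache`, and the dict-based cache are illustrative, not vLLM APIs); only the idea of returning early when `slot_mapping` is `None` comes from the suggestion above.

```python
from typing import Optional


class FakeSlotMapping:
    """Hypothetical stand-in for a slot-mapping tensor; only flatten() is modeled."""

    def __init__(self, slots: list[int]) -> None:
        self.slots = slots

    def flatten(self) -> list[int]:
        return list(self.slots)


def write_cache(
    cache: dict[int, str],
    values: list[str],
    slot_mapping: Optional[FakeSlotMapping],
) -> None:
    # Mirror the suggested guard: skip the update when there is nothing to
    # write or when no slot mapping was provided for this layer.
    if not values or slot_mapping is None:
        return
    for slot, value in zip(slot_mapping.flatten(), values):
        cache[slot] = value


cache: dict[int, str] = {}
write_cache(cache, ["a", "b"], None)  # returns early instead of raising
write_cache(cache, ["a", "b"], FakeSlotMapping([3, 7]))
```

Without the `slot_mapping is None` check, the first call would fail with the same `AttributeError` on `flatten` that the review comment describes.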
@ElizaWszola, I created PR neuralmagic#163 against yours that adds one more fusion to eliminate a rematerialization.
before (PR35789)
after (this PR)
Accuracy is passing, and I am currently collecting E2E perf on Kimi-k2-FP4.
    kv_cache_scale: torch.Tensor,
    layer_name: LayerNameType,
) -> torch.Tensor:
    forward_context = get_forward_context()
suggestion:
from vllm.model_executor.layers.attention.attention import get_attention_context
...
layer_name = _resolve_layer_name(layer_name)
attn_metadata, _, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
if layer_slot_mapping is not None:
    ops.concat_and_cache_mla_rope_fused(
        positions,
        q_pe,
        k_pe,
        kv_c,
        cos_sin_cache,
        is_neox,
        layer_slot_mapping,
        kv_cache,
        kv_cache_dtype,
        kv_cache_scale,
    )
return torch.empty(0, device=kv_c.device, dtype=kv_c.dtype)
Slightly more concise, plus you can remove the has_slot_mapping check from the kernel itself
)


class KVCacheMLARoPEFusionPattern:
These two classes (KVCacheMLARoPEFusionPattern and KVCacheMLARoPEDeepseekScalingFusionPattern) can be merged with a single use_deepseek_scaling: bool arg, pending additional cleanup in #39488:
if use_deepseek_scaling:
    self.rope_matcher = MatcherDeepseekScalingRotaryEmbedding(
        is_neox=self.is_neox,
        head_size=self.qk_rope_head_dim,
        num_heads=self.num_heads,
        num_kv_heads=self.num_kv_heads,
    )
else:
    self.rope_matcher = MatcherRotaryEmbedding(  # type: ignore
        is_neox=self.is_neox,
        head_size=self.qk_rope_head_dim,
        num_heads=self.num_heads,
        num_kv_heads=self.num_kv_heads,
        use_flashinfer=self.use_flashinfer,
    )
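As a rough illustration of the merge proposed above, the sketch below collapses two pattern classes into one that takes a `use_deepseek_scaling` flag. The matcher classes here are hypothetical stubs (`StubRotaryMatcher`, `StubDeepseekScalingMatcher`), and `MergedMLARoPEFusionPattern` and `build_patterns` are illustrative names; the real matchers and their constructor arguments live in vLLM and may differ.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class StubRotaryMatcher:
    """Hypothetical stand-in for the plain RoPE matcher."""

    is_neox: bool
    head_size: int
    use_flashinfer: bool = False


@dataclass
class StubDeepseekScalingMatcher:
    """Hypothetical stand-in for the Deepseek Scaling RoPE matcher."""

    is_neox: bool
    head_size: int


class MergedMLARoPEFusionPattern:
    """One pattern class; the flag selects which matcher it carries."""

    def __init__(
        self,
        is_neox: bool,
        qk_rope_head_dim: int,
        use_deepseek_scaling: bool,
        use_flashinfer: bool = False,
    ) -> None:
        self.use_deepseek_scaling = use_deepseek_scaling
        self.rope_matcher: Union[StubRotaryMatcher, StubDeepseekScalingMatcher]
        if use_deepseek_scaling:
            self.rope_matcher = StubDeepseekScalingMatcher(is_neox, qk_rope_head_dim)
        else:
            self.rope_matcher = StubRotaryMatcher(
                is_neox, qk_rope_head_dim, use_flashinfer
            )


def build_patterns(qk_rope_head_dim: int) -> list[MergedMLARoPEFusionPattern]:
    # Register every (is_neox, use_deepseek_scaling) combination from one class.
    return [
        MergedMLARoPEFusionPattern(is_neox, qk_rope_head_dim, scaling)
        for scaling in (False, True)
        for is_neox in (False, True)
    ]


patterns = build_patterns(qk_rope_head_dim=64)
```

The design choice is that the registration loop iterates over a flag instead of over two class names, so any shared pattern and replacement logic only needs to be written once.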
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
if is_neox_style:
    # NOTE(woosuk): Here we assume that the positions tensor has the
bugfix: I believe this is a leftover comment from vLLM v0. In V1 this should always be positions.shape == (T,). This should just be:
if is_neox_style:
    cos = cos.repeat(1, 2).unsqueeze(-2)
    sin = sin.repeat(1, 2).unsqueeze(-2)
Without this change, adding a unit test with neox_style=True fails with shape errors.
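To make the layout difference concrete, here is a pure-Python sketch for a single token and a single head. It assumes the usual convention that NeoX style pairs element i with element i + d/2 (so the half-width cos/sin values are tiled, like `repeat(1, 2)`), while the interleaved GPT-J style pairs adjacent elements (like `repeat_interleave(2)`). The helper names are hypothetical, and the interleaved branch is an assumption not shown in this thread.

```python
import math


def expand(half: list[float], is_neox_style: bool) -> list[float]:
    if is_neox_style:
        return half + half  # [c0, c1, c0, c1], like tensor.repeat(1, 2)
    return [v for v in half for _ in range(2)]  # [c0, c0, c1, c1]


def rotate(x: list[float], is_neox_style: bool) -> list[float]:
    if is_neox_style:
        h = len(x) // 2
        return [-v for v in x[h:]] + x[:h]  # (-x2, x1) over the two halves
    out: list[float] = []
    for i in range(0, len(x), 2):
        out += [-x[i + 1], x[i]]  # (-odd, even) within each adjacent pair
    return out


def apply_rope(x: list[float], angles: list[float], is_neox_style: bool) -> list[float]:
    cos = expand([math.cos(a) for a in angles], is_neox_style)
    sin = expand([math.sin(a) for a in angles], is_neox_style)
    rot = rotate(x, is_neox_style)
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rot, sin)]


x = [1.0, 0.0, 0.0, 0.0]
angles = [math.pi / 2, 0.0]  # rotate the first pair by 90 degrees
neox_out = apply_rope(x, angles, is_neox_style=True)   # pair is (x[0], x[2])
gptj_out = apply_rope(x, angles, is_neox_style=False)  # pair is (x[0], x[1])
```

In the tensor version, the trailing `unsqueeze(-2)` only adds a head axis so that one (T, rot_dim) table broadcasts across all heads.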
const int64_t block_idx = slot_idx / block_size;
const int64_t entry_idx = slot_idx % block_size;

// NOTE: slot_idx can be -1 if the token is padded
We can probably move this early return up before the Q RoPE, if this is a padding token due to V1 cudagraphs
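The indexing and the early return can be sketched on the host side in Python. This is only an illustration of the slot arithmetic, not the kernel: `write_tokens` and the nested-list cache are hypothetical, and the padding check is placed first, as the comment above proposes.

```python
from typing import Optional


def write_tokens(
    slot_mapping: list[int],
    tokens: list[str],
    num_blocks: int,
    block_size: int,
) -> list[list[Optional[str]]]:
    cache: list[list[Optional[str]]] = [
        [None] * block_size for _ in range(num_blocks)
    ]
    for slot_idx, token in zip(slot_mapping, tokens):
        # Padded tokens carry slot_idx == -1; bail out before any per-token
        # work (in the kernel this would also skip the RoPE computation).
        if slot_idx < 0:
            continue
        block_idx, entry_idx = divmod(slot_idx, block_size)
        cache[block_idx][entry_idx] = token
    return cache


paged = write_tokens([0, 5, -1, 17], ["a", "b", "pad", "c"], num_blocks=5, block_size=4)
```

Checking for a negative slot before the division also sidesteps a language difference: C++ integer division truncates toward zero while Python floors, so `-1 / block_size` and `-1 % block_size` give different results in the two languages.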
    return [torch.ops.vllm.fused_concat_and_cache_mla_rope.default]


@pytest.mark.parametrize(
For ROCm support, can we add:
if current_platform.is_cuda():
    MLA_BACKENDS = [
        AttentionBackendEnum.FLASH_ATTN_MLA,
        AttentionBackendEnum.FLASHINFER_MLA,
    ]
elif is_aiter_found_and_supported():
    MLA_BACKENDS = [
        AttentionBackendEnum.TRITON_MLA,
        AttentionBackendEnum.ROCM_AITER_MLA,
    ]
else:
    MLA_BACKENDS = []

@pytest.mark.parametrize("attn_backend", MLA_BACKENDS)
Also, this test should probably be in tests/compile/passes not tests/kernels/core
        "Enabled is only supported with use_deepseek_scaling_rope (and vice versa)"
    )

    if is_neox and use_deepseek_scaling_rope:
See comments about merging the two patterns and the neox_style fix in deepseek_scaling_rope.py
fusion_pass = KVCacheMLARoPEFusionPass(vllm_config)
passes = [
    NoOpEliminationPass(vllm_config),
    SplitCoalescingPass(vllm_config),
don't need SplitCoalescing + ScatterSplitReplacement here, those are only for MHA RoPE model graphs; we don't see those patterns in MLA
)


class KVCacheMLARoPEDeepseekScalingFusionPattern:
suggestion: we could use the new VllmPatternReplacement + VllmPatternMatcherPass here: https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/vllm_inductor_pass.py
def enable_rope_kvcache_mla_fusion(cfg: "VllmConfig") -> bool:
    """Enable if use_inductor_graph_partition is enabled."""

    return cfg.compilation_config.use_inductor_graph_partition
suggestion, same as enable_rope_kvcache_fusion:
return (
    cfg.compilation_config.use_inductor_graph_partition
    or not cfg.compilation_config.splitting_ops_contain_kv_cache_update()
)
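The suggested predicate is easy to check as a truth table. The sketch below uses a hypothetical stub config: `StubCompilationConfig`, its `splitting_ops` field, and the substring test inside `splitting_ops_contain_kv_cache_update` are assumptions made for illustration, not the real vLLM implementation. The presumed rationale is that the fusion can only be matched when the KV cache update is not a graph-splitting boundary, unless Inductor graph partitioning is in use.

```python
from dataclasses import dataclass, field


@dataclass
class StubCompilationConfig:
    """Hypothetical stand-in for the compilation config."""

    use_inductor_graph_partition: bool = False
    splitting_ops: list[str] = field(default_factory=list)

    def splitting_ops_contain_kv_cache_update(self) -> bool:
        # Assumed behavior: true if any splitting op is a KV cache update op.
        return any("kv_cache_update" in op for op in self.splitting_ops)


def enable_rope_kvcache_mla_fusion(cc: StubCompilationConfig) -> bool:
    return (
        cc.use_inductor_graph_partition
        or not cc.splitting_ops_contain_kv_cache_update()
    )


split_op = "vllm::unified_mla_kv_cache_update"
table = {
    "partition, split on update": enable_rope_kvcache_mla_fusion(
        StubCompilationConfig(True, [split_op])
    ),
    "no partition, split on update": enable_rope_kvcache_mla_fusion(
        StubCompilationConfig(False, [split_op])
    ),
    "no partition, no split": enable_rope_kvcache_mla_fusion(
        StubCompilationConfig(False, [])
    ),
}
```

Compared with the version in the diff, the only new case is the last row: the fusion is also enabled without graph partitioning when the cache update op is not among the splitting ops.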
@ElizaWszola @ProExpertProg I added some comments with the changes from my branch: #40392
Also, one question: Can we use functional but this does not work as is and needs additional complexity/preprocessing on the

@Rohan138 I'm confused about your question: are you saying we shouldn't be using aten ops? Or are you saying regular torch ops don't work? Or both?

Sorry, missed your comment; but yes to both. I realized why I wasn't sure about matching the entire aten sequence in the pattern here: currently, the copy+slice_scatter is introduced during AOTAutograd functionalization, but the replacement completely removes the functionalized write into q. So the graph after the fusion but before defunctionalization doesn't actually write the
I got a slightly different torch pattern to work in #40392, with a custom defunctionalization pass:
Defunctionalization:

@Rohan138 I think that's actually more correct. I know defunctionalization is complex for rope with the views and slice scatters, so it makes sense it would be here as well.

Implemented in #40392


Follow-up to #32335. Fuse RoPE and MLA KV-cache update into the concat_and_cache_mla_rope_fused kernel.
Eval (B200)
with fusion:
no fusion:
Perf (H100, RoPE FlashInfer)
Obtained with
Perf (H100, Deepseek Scaling RoPE)
Obtained with
(The perf results could be improved for Deepseek Scaling RoPE)
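As a reference for what the fusion described in this PR computes, the pure-Python sketch below compares a two-pass version (RoPE, then cache write) with a single-pass version. It is a simplified model under stated assumptions: NeoX-style rotation only, one head per token, a dict keyed by slot index instead of a paged cache, no quantization or scaling, and a hypothetical `base` parameter. The function names are illustrative and are not the kernel's API.

```python
import math


def rope_neox(x: list[float], pos: int, base: float = 10000.0) -> list[float]:
    # NeoX-style rotation: element i is paired with element i + d/2.
    half = len(x) // 2
    out = list(x)
    for i in range(half):
        angle = pos * base ** (-2.0 * i / len(x))
        c, s = math.cos(angle), math.sin(angle)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i + half] * c + x[i] * s
    return out


def unfused(positions, q_pe, k_pe, kv_c, slot_mapping, cache):
    # Pass 1: rotate q and k. Pass 2: concatenate kv_c with rotated k and store.
    q_rot = [rope_neox(q, p) for q, p in zip(q_pe, positions)]
    k_rot = [rope_neox(k, p) for k, p in zip(k_pe, positions)]
    for slot, c, k in zip(slot_mapping, kv_c, k_rot):
        if slot >= 0:  # slot -1 marks a padded token
            cache[slot] = c + k
    return q_rot


def fused(positions, q_pe, k_pe, kv_c, slot_mapping, cache):
    # Single pass per token: rotate q, rotate k, and write the cache entry.
    q_rot = []
    for pos, q, k, c, slot in zip(positions, q_pe, k_pe, kv_c, slot_mapping):
        q_rot.append(rope_neox(q, pos))
        if slot >= 0:
            cache[slot] = c + rope_neox(k, pos)
    return q_rot


positions = [0, 1, 2]
q_pe = [[1.0, 2.0, 3.0, 4.0]] * 3
k_pe = [[0.5, -1.0, 2.0, 0.0]] * 3
kv_c = [[9.0, 8.0], [7.0, 6.0], [5.0, 4.0]]
slots = [4, -1, 6]

cache_a: dict[int, list[float]] = {}
cache_b: dict[int, list[float]] = {}
q_a = unfused(positions, q_pe, k_pe, kv_c, slots, cache_a)
q_b = fused(positions, q_pe, k_pe, kv_c, slots, cache_b)
```

Both versions produce the same rotated queries and the same cache contents; the fused form simply avoids materializing the rotated keys between two separate passes.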