Skip to content

[Performance][MLA][ROCm] AITER fused QK-RoPE + KV cache + q-absorb + q-cat + q-quant for decode#41839

Draft
xaguilar-amd wants to merge 41 commits into
vllm-project:mainfrom
xaguilar-amd:mla_qk_rope_cache_fusion
Draft

[Performance][MLA][ROCm] AITER fused QK-RoPE + KV cache + q-absorb + q-cat + q-quant for decode#41839
xaguilar-amd wants to merge 41 commits into
vllm-project:mainfrom
xaguilar-amd:mla_qk_rope_cache_fusion

Conversation

@xaguilar-amd

Copy link
Copy Markdown
Contributor

TL;DR

Builds on top of #40392 to additionally fuse the decode-side q path
(q-absorb BMM, q-concat, FP8 q-quant) into AITER's
fused_qk_rope_concat_and_cache_mla kernel — collapsing 4 ops into 1
on every decode step on AMD MI300X / MI355X. Decode-bucket only by
construction; prefill graphs are byte-for-byte identical to #40392.
Disabled by default; opt-in via pass_config.fuse_aiter_qk_rope_kvcache_mla=True.

Purpose

After #40392, the decode hot path on ROCm + AITER looks like:

fused_rope_unified_mla_kv_cache_update(...)   # RoPE + KV cache write   (already fused by #40392)
   ↓
do_decode_q_prep(q)                            # q-absorb BMM + cat + FP8 quant   (NOT fused)
   ↓
unified_mla_attention_with_output(q_prepped, ...)

The q-prep stage is composed of four small ops launched per layer per
decode step (BMM, split, concat, FP8 quant). AITER ships a single
kernel — fused_qk_rope_concat_and_cache_mla — that does all of it,
including the RoPE+KV-cache half that #40392 already fuses. Hooking it up
needs the q-prep ops to be visible to the FX graph (currently they live
inside forward_impl).

This PR:

  1. Lifts q-prep into a custom op (mla_decode_q_prep) above
    unified_mla_attention_with_output, in a new compilation pass
    (MLADecodeQPrepLiftPass).
  2. Folds the pair (fused_rope_unified_mla_kv_cache_update, mla_decode_q_prep) into one AITER call, in a second new pass
    (MLAAiterQkRopeKVCacheFusionPass).
  3. Bounds memory & CUDA-graph safety by gating both passes on
    compile_range.end <= max_num_seqs × (1 + num_speculative_tokens),
    the same formula CudaGraphManager uses to classify decode-mode
    captures.

Net result: one fused decode kernel for RoPE + KV-cache + q-absorb +
q-cat + q-quant, with zero overhead on the prefill / mixed graphs.

Design choices (the parts reviewers will ask about)

1. INVARIANT 1: mla_decode_q_prep_impl does not lie about its shape

A previous attempt at this fusion (closed, by request — was the predecessor of #41568)
declared an mla_decode_q_prep whose fake_impl
shape was q.shape
but whose real impl returned q[:num_decode].
Inductor sized downstream ops to the full T; runtime returned 0 rows
during high-range CUDA-graph warmup; static_per_tensor_quant launched
with grid_dim = T against an empty buffer → null-pointer GPU fault on
the (4682, 16384) compile range. In addition, it also had some design flaws.

The fix: the impl processes every row of
q, never slices on attention metadata
. The fake_impl declares
[q.size(0), num_heads, kv_lora + qk_rope] and the real impl honors it.
The decode-bucket gate (next section) is what makes this allocation
free.

There's an explicit unit test —
test_mla_decode_q_prep_invariant_1 — that asserts
output.size(0) == q.size(0) for T ∈ {1, 16, 64, 256}. There's also a
CUDA-graph capture/replay regression test
(test_mla_aiter_fusion_cuda_graph_capture) that exercises both ends of
the decode bucket end-to-end.

2. Auto-derived decode-bucket threshold

MLADecodeQPrepLiftPass and MLAAiterQkRopeKVCacheFusionPass only fire
for compile ranges with
end <= aiter_qk_rope_kvcache_fusion_max_token_num. The default value
is auto-derived in VllmConfig._set_compile_ranges:

decode_query_len = 1 + num_speculative_tokens
max_token_num = scheduler_config.max_num_seqs * decode_query_len

This is exactly the formula CudaGraphManager._init_candidates already
uses to classify decode-mode CUDA-graph captures. Keeping the pass gate
aligned with that classification eliminates a footgun (you can't tune
one without the other accidentally going stale) and removes the need
for a manual knob in 99% of deployments. An explicit value is still
honored as an override.

3. Building on top of PR #40392's fused-RoPE+KVCache

  • Sequencing: MLARoPEKVCacheCatFusionPass ([Performance][DSR1]: Fused RoPE+KVCache+q_concat for MLA #40392) → MLADecodeQPrepLiftPassMLAAiterQkRopeKVCacheFusionPass. The AITER pass matches the pair (auto_functionalized(fused_rope_unified_mla_kv_cache_update, ...), mla_decode_q_prep) keyed by layer_name and folds them into one call.
  • Auto-enabling [Performance][DSR1]: Fused RoPE+KVCache+q_concat for MLA #40392: turning on fuse_aiter_qk_rope_kvcache_mla auto-enables fuse_rope_kvcache_cat_mla (it's a strict prerequisite). A clear log line is emitted.
  • Cycle-breaking via _unwrap_q_orig: [Performance][DSR1]: Fused RoPE+KVCache+q_concat for MLA #40392 leaves the model's q[..., qk_nope:] = q_pe_rotated write functionalized as slice_scatter(q_orig, copy(slice_dst, getitem(frmkv, 1))). Naively reusing that as q for the new fused node closes a cycle (new_node → slice_scatter → new_q_pe = new_node[1]). We walk back to q_orig (which AITER doesn't need rotated since the kernel does RoPE itself and only consumes q_nope), breaking the cycle. There's a corresponding tweak in FixFunctionalizationPass so view_temp is not erased — it's now a live input to the new fused op.

4. vLLM stores FP8 KV cache as torch.uint8

STR_DTYPE_TO_TORCH_DTYPE["fp8"] -> torch.uint8. AITER's torch→AITER
dtype mapping rejects uint8 for kv_cache (AITER_DTYPE_u8 isn't in the
kernel's whitelist) and crashes at warmup with
[AITER] kv cache data type is not supported. PR #40392's path goes
through vLLM's own _C_cache_ops.concat_and_cache_mla_rope_fused, which
takes an explicit kv_cache_dtype: str and accepts uint8, so the
issue doesn't surface there. We zero-copy-view the kv_cache as
current_platform.fp8_dtype() (float8_e4m3fn on gfx950,
float8_e4m3fnuz on gfx94) before dispatch when
is_quantized_kv_cache(kv_cache_dtype).

Compatibility / no-effect cases

  • Non-ROCm or non-AITER: __post_init__ disables
    fuse_aiter_qk_rope_kvcache_mla with a warning. Pass is never built.
  • Prefill compile ranges: is_applicable_for_range returns False,
    pass is skipped, prefill graphs are unchanged from [Performance][DSR1]: Fused RoPE+KVCache+q_concat for MLA #40392.
  • Default settings: the flag is opt-in. With it off, this PR is a
    pure no-op (modulo refactored q-prep helpers in mla_attention.py,
    which preserve identical behavior).

Rohan138 and others added 30 commits April 20, 2026 14:02
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Rohan138 and others added 11 commits April 29, 2026 15:56
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
…ide q path (q-absorb BMM, q-concat, FP8 q-quant) into AITER's fused_qk_rope_concat_and_cache_mla kernel

Signed-off-by: Xavier Aguilar <xavier.aguilarfruto@amd.com>
@mergify mergify Bot added the rocm Related to AMD ROCm label May 6, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD May 6, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a series of compilation passes designed to optimize Multi-Head Latent Attention (MLA) on ROCm by leveraging AITER fused kernels. The changes include new passes for lifting query preparation and fusing RoPE, KV cache updates, and query concatenation into single operations. Additionally, the PR adds pattern matching support for DeepSeek-style scaling rotary embeddings and updates the MLAAttention layer to support pre-prepared query tensors. Extensive unit tests and parity checks are included to ensure correctness and CUDA graph stability. Feedback from the reviewer identifies potential safety issues in the defunctionalization logic that could lead to runtime errors and suggests ensuring tensor contiguity after transpositions in the query preparation methods to maintain compatibility with downstream kernels.

I am having trouble creating individual review comments. Click here to see my feedback.

vllm/compilation/passes/utility/fix_functionalization.py (193-218)

high

This block lacks safety checks and initialization for copy_temp, slice_temp, and slice_scatter_temp. If the expected aten.copy.default or aten.slice_scatter.default nodes are not found in the graph, this will raise an UnboundLocalError or AttributeError. Additionally, it assumes that getitem_nodes contains indices 1 and 2 without checking, which could lead to a KeyError. Please follow the safer pattern used in the subsequent elif block (lines 242-269).

                getitem_nodes = self.getitem_users(node)
                if 1 in getitem_nodes:
                    q_pe_out = getitem_nodes[1]
                    copy_temp = None
                    for user in list(q_pe_out.users):
                        if is_func(user, torch.ops.aten.copy.default):
                            copy_temp = user
                            break
                    if copy_temp is not None:
                        slice_temp = copy_temp.args[0]
                        slice_scatter_temp = None
                        for user in list(copy_temp.users):
                            if is_func(user, torch.ops.aten.slice_scatter.default):
                                slice_scatter_temp = user
                                break
                        if slice_scatter_temp is not None:
                            view_temp = slice_scatter_temp.args[0]
                            view_orig = slice_temp.args[0]
                            slice_scatter_temp.replace_all_uses_with(view_orig)
                            self._remove(slice_scatter_temp)
                            self._remove(copy_temp)
                            self._remove(slice_temp)
                            self._remove(view_temp)
                    self._remove(q_pe_out)

                # defunctionalize k_pe manually; self.replace_users_with_mutated_args
                # does not support only replacing specific kwargs
                if 2 in getitem_nodes:
                    k_pe_in = node.kwargs["k_pe"]
                    k_pe_out = getitem_nodes[2]
                    k_pe_out.replace_all_uses_with(k_pe_in)
                    self._remove(k_pe_out)

vllm/model_executor/layers/attention/mla_attention.py (564)

high

The fallback path returns a non-contiguous tensor due to the transpose operation. Since AITER kernels and other downstream operations often expect contiguous inputs for performance and correctness, it is safer to ensure the result is contiguous.

            ql_nope = ql_nope.transpose(0, 1).contiguous()

vllm/model_executor/layers/attention/mla_attention.py (629)

high

The fallback path returns a non-contiguous tensor due to the transpose operation. Since AITER kernels and other downstream operations often expect contiguous inputs for performance and correctness, it is safer to ensure the result is contiguous.

            ql_nope = ql_nope.transpose(0, 1).contiguous()

@mergify

mergify Bot commented May 23, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xaguilar-amd.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 23, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

needs-rebase rocm Related to AMD ROCm

Projects

Status: Todo

Development

Successfully merging this pull request may close these issues.

3 participants