[Performance][MLA][ROCm] AITER fused QK-RoPE + KV cache + q-absorb + q-cat + q-quant for decode #41839
xaguilar-amd wants to merge 41 commits into
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com>
…ide q path (q-absorb BMM, q-concat, FP8 q-quant) into AITER's fused_qk_rope_concat_and_cache_mla kernel Signed-off-by: Xavier Aguilar <xavier.aguilarfruto@amd.com>
Code Review
This pull request introduces a series of compilation passes designed to optimize Multi-Head Latent Attention (MLA) on ROCm by leveraging AITER fused kernels. The changes include new passes for lifting query preparation and fusing RoPE, KV cache updates, and query concatenation into single operations. Additionally, the PR adds pattern matching support for DeepSeek-style scaling rotary embeddings and updates the MLAAttention layer to support pre-prepared query tensors. Extensive unit tests and parity checks are included to ensure correctness and CUDA graph stability. Feedback from the reviewer identifies potential safety issues in the defunctionalization logic that could lead to runtime errors and suggests ensuring tensor contiguity after transpositions in the query preparation methods to maintain compatibility with downstream kernels.
I am having trouble creating individual review comments. Click here to see my feedback.
vllm/compilation/passes/utility/fix_functionalization.py (193-218)
This block lacks safety checks and initialization for copy_temp, slice_temp, and slice_scatter_temp. If the expected aten.copy.default or aten.slice_scatter.default nodes are not found in the graph, this will raise an UnboundLocalError or AttributeError. Additionally, it assumes that getitem_nodes contains indices 1 and 2 without checking, which could lead to a KeyError. Please follow the safer pattern used in the subsequent elif block (lines 242-269).
getitem_nodes = self.getitem_users(node)
if 1 in getitem_nodes:
    q_pe_out = getitem_nodes[1]
    copy_temp = None
    for user in list(q_pe_out.users):
        if is_func(user, torch.ops.aten.copy.default):
            copy_temp = user
            break
    if copy_temp is not None:
        slice_temp = copy_temp.args[0]
        slice_scatter_temp = None
        for user in list(copy_temp.users):
            if is_func(user, torch.ops.aten.slice_scatter.default):
                slice_scatter_temp = user
                break
        if slice_scatter_temp is not None:
            view_temp = slice_scatter_temp.args[0]
            view_orig = slice_temp.args[0]
            slice_scatter_temp.replace_all_uses_with(view_orig)
            self._remove(slice_scatter_temp)
            self._remove(copy_temp)
            self._remove(slice_temp)
            self._remove(view_temp)
            self._remove(q_pe_out)
# defunctionalize k_pe manually; self.replace_users_with_mutated_args
# does not support only replacing specific kwargs
if 2 in getitem_nodes:
    k_pe_in = node.kwargs["k_pe"]
    k_pe_out = getitem_nodes[2]
    k_pe_out.replace_all_uses_with(k_pe_in)
    self._remove(k_pe_out)
vllm/model_executor/layers/attention/mla_attention.py (564)
The fallback path returns a non-contiguous tensor due to the transpose operation. Since AITER kernels and other downstream operations often expect contiguous inputs for performance and correctness, it is safer to ensure the result is contiguous.
ql_nope = ql_nope.transpose(0, 1).contiguous()
vllm/model_executor/layers/attention/mla_attention.py (629)
The fallback path returns a non-contiguous tensor due to the transpose operation. Since AITER kernels and other downstream operations often expect contiguous inputs for performance and correctness, it is safer to ensure the result is contiguous.
ql_nope = ql_nope.transpose(0, 1).contiguous()
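To illustrate why a transpose on its own yields a non-contiguous result, here is a minimal pure-Python sketch of row-major stride bookkeeping. It models only shapes and strides, not tensors. The helper names (row_major_strides, transpose, is_contiguous) and the sizes are illustrative assumptions, not vLLM or PyTorch APIs, and the contiguity check is simplified (it does not special-case size-1 dimensions).

```python
def row_major_strides(shape):
    """Strides (in elements) of a densely packed row-major array."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def transpose(shape, strides, a, b):
    """Swap two dims the way a transpose view does: metadata only, no copy."""
    shape, strides = list(shape), list(strides)
    shape[a], shape[b] = shape[b], shape[a]
    strides[a], strides[b] = strides[b], strides[a]
    return tuple(shape), tuple(strides)


def is_contiguous(shape, strides):
    """True when the strides match the dense row-major layout for this shape."""
    return tuple(strides) == row_major_strides(shape)


# Hypothetical layout: (num_heads, num_tokens, kv_lora) before the transpose.
shape = (16, 4, 512)
strides = row_major_strides(shape)
t_shape, t_strides = transpose(shape, strides, 0, 1)

# The transposed view keeps the old memory order, so it is not dense
# row-major for its new shape; a copy (what .contiguous() does) would be.
view_is_dense = is_contiguous(t_shape, t_strides)
copy_is_dense = is_contiguous(t_shape, row_major_strides(t_shape))
```

A kernel that indexes its input assuming a dense row-major layout would read the wrong elements from the transposed view, which is the concern the suggestion above addresses.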
This pull request has merge conflicts that must be resolved before it can be merged.
TL;DR
Builds on top of #40392 to additionally fuse the decode-side q path
(q-absorb BMM, q-concat, FP8 q-quant) into AITER's
fused_qk_rope_concat_and_cache_mla kernel, collapsing 4 ops into 1
on every decode step on AMD MI300X / MI355X. Decode-bucket only by
construction; prefill graphs are byte-for-byte identical to #40392.
Disabled by default; opt-in via
pass_config.fuse_aiter_qk_rope_kvcache_mla=True.
Purpose
After #40392, the decode hot path on ROCm + AITER looks like:
The q-prep stage is composed of four small ops launched per layer per
decode step (BMM, split, concat, FP8 quant). AITER ships a single
kernel, fused_qk_rope_concat_and_cache_mla, that does all of it,
including the RoPE+KV-cache half that #40392 already fuses. Hooking it up
needs the q-prep ops to be visible to the FX graph (currently they live
inside forward_impl).
This PR:
1. Lifts the q-prep ops into a custom op (mla_decode_q_prep) above unified_mla_attention_with_output, in a new compilation pass (MLADecodeQPrepLiftPass).
2. Fuses the pair (fused_rope_unified_mla_kv_cache_update, mla_decode_q_prep) into one AITER call, in a second new pass (MLAAiterQkRopeKVCacheFusionPass).
3. Gates both passes on compile_range.end <= max_num_seqs × (1 + num_speculative_tokens), the same formula CudaGraphManager uses to classify decode-mode captures.
Net result: one fused decode kernel for RoPE + KV-cache + q-absorb +
q-cat + q-quant, with zero overhead on the prefill / mixed graphs.
Design choices (the parts reviewers will ask about)
1. INVARIANT 1: mla_decode_q_prep_impl does not lie about its shape
A previous attempt at this fusion (closed by request; it was the predecessor of #41568)
declared an mla_decode_q_prep whose fake_impl shape was q.shape but whose
real impl returned q[:num_decode]. Inductor sized downstream ops to the
full T; the runtime returned 0 rows during high-range CUDA-graph warmup;
static_per_tensor_quant launched with grid_dim = T against an empty
buffer, causing a null-pointer GPU fault on the (4682, 16384) compile
range. That attempt also had other design flaws.
The fix: the impl processes every row of q and never slices on attention
metadata. The fake_impl declares [q.size(0), num_heads, kv_lora + qk_rope]
and the real impl honors it. The decode-bucket gate (next section) is what
makes this allocation free.
There's an explicit unit test, test_mla_decode_q_prep_invariant_1, that
asserts output.size(0) == q.size(0) for T ∈ {1, 16, 64, 256}. There's also
a CUDA-graph capture/replay regression test
(test_mla_aiter_fusion_cuda_graph_capture) that exercises both ends of
the decode bucket end-to-end.
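The invariant can be sketched with a shape-only model, which is enough to show where the earlier attempt diverged. This is a hedged illustration, not the PR's code: the function names, the head and dimension sizes, and the num_decode=0 warmup scenario are assumptions chosen to mirror the description above.

```python
# Shape-only model of INVARIANT 1; no tensors are allocated.
# All names and sizes below are illustrative assumptions.


def fake_impl_shape(num_tokens, num_heads, kv_lora, qk_rope):
    """Shape the compiler is told to expect: one output row per input row."""
    return (num_tokens, num_heads, kv_lora + qk_rope)


def buggy_real_shape(num_tokens, num_decode, num_heads, kv_lora, qk_rope):
    """Earlier attempt: the real impl sliced q[:num_decode] at runtime."""
    rows = min(num_decode, num_tokens)
    return (rows, num_heads, kv_lora + qk_rope)


def fixed_real_shape(num_tokens, num_heads, kv_lora, qk_rope):
    """Fixed behavior: every row of q is processed, regardless of metadata."""
    return (num_tokens, num_heads, kv_lora + qk_rope)


def invariant_holds(num_tokens, num_heads=128, kv_lora=512, qk_rope=64):
    """True when the real output shape equals the declared fake shape."""
    declared = fake_impl_shape(num_tokens, num_heads, kv_lora, qk_rope)
    actual = fixed_real_shape(num_tokens, num_heads, kv_lora, qk_rope)
    return declared == actual


# Same token counts as the unit test described above.
all_ok = all(invariant_holds(t) for t in (1, 16, 64, 256))

# Warmup-style case with no decode rows: declared rows vs. rows produced.
declared_rows = fake_impl_shape(16384, 128, 512, 64)[0]
buggy_rows = buggy_real_shape(16384, 0, 128, 512, 64)[0]
```

In the buggy variant the downstream op is sized for declared_rows while the buffer holds buggy_rows, which is the mismatch behind the fault described above.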
2. Auto-derived decode-bucket threshold
MLADecodeQPrepLiftPass and MLAAiterQkRopeKVCacheFusionPass only fire
for compile ranges with end <= aiter_qk_rope_kvcache_fusion_max_token_num.
The default value is auto-derived in VllmConfig._set_compile_ranges as
max_num_seqs × (1 + num_speculative_tokens).
This is exactly the formula CudaGraphManager._init_candidates already
uses to classify decode-mode CUDA-graph captures. Keeping the pass gate
aligned with that classification eliminates a footgun (tuning one no
longer risks the other silently going stale) and removes the need
for a manual knob in 99% of deployments. An explicit value is still
honored as an override.
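As a rough sketch of the gate described in this section, assuming only the formula max_num_seqs × (1 + num_speculative_tokens) and the override behavior stated above: the two standalone functions below are hypothetical stand-ins, not the actual VllmConfig or pass methods.

```python
from typing import Optional


def derive_decode_bucket_max_tokens(
    max_num_seqs: int,
    num_speculative_tokens: int = 0,
    explicit_override: Optional[int] = None,
) -> int:
    """Threshold for the decode bucket; an explicit value takes precedence."""
    if explicit_override is not None:
        return explicit_override
    # Each sequence contributes one token plus its speculative tokens.
    return max_num_seqs * (1 + num_speculative_tokens)


def is_applicable_for_range(range_end: int, max_token_num: int) -> bool:
    """The passes fire only when the compile range ends inside the bucket."""
    return range_end <= max_token_num


# Illustrative values: 256 sequences, no speculative decoding.
threshold = derive_decode_bucket_max_tokens(max_num_seqs=256)
decode_range_fires = is_applicable_for_range(256, threshold)
prefill_range_fires = is_applicable_for_range(16384, threshold)
```

With speculative decoding the bucket widens accordingly, for example 256 sequences with 3 speculative tokens give a threshold of 1024.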
3. Building on top of PR #40392's fused-RoPE+KVCache
- Pass order: MLARoPEKVCacheCatFusionPass ([Performance][DSR1]: Fused RoPE+KVCache+q_concat for MLA #40392) → MLADecodeQPrepLiftPass → MLAAiterQkRopeKVCacheFusionPass. The AITER pass matches the pair (auto_functionalized(fused_rope_unified_mla_kv_cache_update, ...), mla_decode_q_prep) keyed by layer_name and folds them into one call.
- fuse_aiter_qk_rope_kvcache_mla auto-enables fuse_rope_kvcache_cat_mla (it's a strict prerequisite). A clear log line is emitted.
- _unwrap_q_orig: #40392 leaves the model's q[..., qk_nope:] = q_pe_rotated write functionalized as slice_scatter(q_orig, copy(slice_dst, getitem(frmkv, 1))). Naively reusing that as q for the new fused node closes a cycle (new_node → slice_scatter → new_q_pe = new_node[1]). We walk back to q_orig (which AITER doesn't need rotated since the kernel does RoPE itself and only consumes q_nope), breaking the cycle. There's a corresponding tweak in FixFunctionalizationPass so view_temp is not erased; it's now a live input to the new fused op.
4. vLLM stores FP8 KV cache as torch.uint8
STR_DTYPE_TO_TORCH_DTYPE["fp8"] -> torch.uint8. AITER's torch→AITER
dtype mapping rejects uint8 for kv_cache (AITER_DTYPE_u8 isn't in the
kernel's whitelist) and crashes at warmup with
[AITER] kv cache data type is not supported. PR #40392's path goes
through vLLM's own _C_cache_ops.concat_and_cache_mla_rope_fused, which
takes an explicit kv_cache_dtype: str and accepts uint8, so the
issue doesn't surface there. We zero-copy-view the kv_cache as
current_platform.fp8_dtype() (float8_e4m3fn on gfx950, float8_e4m3fnuz
on gfx94) before dispatch when is_quantized_kv_cache(kv_cache_dtype).
Compatibility / no-effect cases
- __post_init__ disables fuse_aiter_qk_rope_kvcache_mla with a warning. Pass is never built.
- is_applicable_for_range returns False, pass is skipped, prefill graphs are unchanged from #40392.
- pure no-op (modulo refactored q-prep helpers in mla_attention.py, which preserve identical behavior).