[Performance][MLA] Lift decode Q-prep (q-absorb + cat + FP8 quant) out of forward_impl#41568
[Performance][MLA] Lift decode Q-prep (q-absorb + cat + FP8 quant) out of forward_impl#41568xaguilar-amd wants to merge 5 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the MLAAttention decode query preparation by lifting the Q-absorption BMM, head-dimension concatenation, and FP8 quantization out of the core implementation and into the forward pass as discrete FX nodes. This change is intended to facilitate downstream graph-level optimizations. A potential performance and memory issue was identified where the new preparation logic processes the entire input tensor (including prefill tokens) rather than just the decode tokens, which could lead to significant memory overhead or OOM errors in large-context scenarios.
| def _maybe_prepare_decode_mqa_q( | ||
| self, | ||
| q: torch.Tensor, | ||
| layer_name_encoded: LayerNameType, | ||
| ) -> torch.Tensor | None: | ||
| """Run the lifted decode q-prep (q-absorb + cat + FP8 quant). | ||
|
|
||
| Returns the FP8-quantized ``mqa_q`` (shape ``(B, N, kv_lora_rank + | ||
| qk_rope_head_dim)``) when the trace-time gate is enabled, otherwise | ||
| ``None`` and ``forward_impl`` runs the legacy in-place path. | ||
|
|
||
| We compute the prep over the full input ``q`` (decode + prefill). | ||
| ``forward_impl`` slices it to ``[:num_mqa_tokens]`` for the decode | ||
| kernel; prefill rows continue to flow through ``forward_mha`` with | ||
| the original BF16 ``q``. This wastes a small amount of compute on | ||
| mixed/pure-prefill batches but keeps the prep trace-time-constant | ||
| and graph-visible, which is the whole point. | ||
| """ | ||
| if not self._lift_q_decode_quant: | ||
| return None | ||
|
|
||
| q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) | ||
|
|
||
| # Discrete FX node #1: q-absorption BMM (BF16 output). | ||
| ql_nope = torch.ops.vllm.unified_mla_q_absorb(q_nope, layer_name_encoded) | ||
|
|
||
| # Discrete FX node #2: concat along head_dim. | ||
| q_full = torch.cat((ql_nope, q_pe), dim=-1) | ||
|
|
||
| # Discrete FX node #3: static per-tensor FP8 quantization. | ||
| # Reshape mirrors ``_DecodeConcatQuantFP8.forward`` so QuantFP8 sees | ||
| # the same 2-D shape it has always been called with. | ||
| q_flat = q_full.reshape(q_full.shape[0], -1) | ||
| mqa_q_flat, _ = self._quant_fp8_op(q_flat, self._q_scale) | ||
| return mqa_q_flat.view(q_full.shape) |
There was a problem hiding this comment.
The current implementation of _maybe_prepare_decode_mqa_q processes the entire input tensor q, which includes both decode and prefill tokens (as well as padding in CUDA graph scenarios). For models with large context windows (e.g., DeepSeek-V3), a single prefill can contain tens of thousands of tokens. Running the Q-absorption BMM, concatenation, and FP8 quantization on the full token extent can lead to significant memory waste and potential Out-Of-Memory (OOM) errors. For instance, with 32k tokens and typical MLA dimensions, the intermediate ql_nope tensor alone could consume several gigabytes of GPU memory. Since these results are discarded for prefill tokens inside forward_impl, this is a significant overhead. Consider slicing q to only include the decode tokens (num_decode_tokens) before performing these operations, provided that attn_metadata is available in the forward context.
There was a problem hiding this comment.
Good catch!
New commit adds a vllm::mla_decode_q_take(q, layer_name) -> Tensor custom op at the very top of _maybe_prepare_decode_mqa_q. It reads num_decode_tokens from forward_context.attn_metadata at execution time and slices q to the first num_decode_tokens rows, so every downstream op in the lifted block (unified_mla_q_absorb, aten.cat, static_scaled_fp8_quant) now operates on the decode-only slice. As a side benefit, the prepared_mqa_q[:num_mqa_tokens] re-slice inside forward_impl collapses into an assert since the tensor arrives pre-trimmed.
A few details worth flagging for the review:
Fake-impl shape: mla_decode_q_take_fake returns a tensor sized to q.shape[0] (backed SymInt, used as an upper bound) rather than calling torch.library.get_ctx().new_dynamic_size(). Using an unbacked SymInt here propagates into QuantFP8.forward_native and breaks Dynamo with a data-dependent guard error; this matches what unified_mla_q_absorb_fake already does.
Why not get_attention_context(): the op deliberately reaches forward_context.attn_metadata directly instead of going through get_attention_context(). The latter requires an nn.Module with a populated kv_cache, which the unit-test fixture for the wiring tier doesn't provide; the direct lookup is also strictly cheaper at runtime since we only need num_decode_tokens.
CUDA-graph padding: num_decode_tokens is the right slice point because it already excludes pad rows in the CUDA-graph case (the scheduler computes it from real decode requests, not the padded extent), which is exactly the case you mentioned.
|
This looks good to me, now. |
|
Don't review yet, I am improving something. |
prep on AITER. This PR is a pure refactor: it lifts the q-absorption BMM, the q_nope/q_pe concat, and the static FP8 quant out of ``MLAAttention.forward_impl`` and into ``MLAAttention.forward()`` so ``torch.compile`` sees them as discrete FX nodes instead of an opaque custom op body. No new kernels, no perf change expected -- bit-exact to the legacy in-place path on every gate-on configuration. A trace-time gate (``_lift_q_decode_quant``) keeps the lift off for configurations that aren't ready (TritonMLA on FP8, sparse MLA, ``q_pad_num_heads`` set, DCP > 1, BF16 KV cache, fp8_ds_mla). Those paths run the unchanged legacy block. The added ``prepared_mqa_q`` kwarg on ``unified_mla_attention_with_output`` is additive (trailing default ``None``) and is filtered out by ``_TargetArgsExpr._match`` in the existing ``MLAAttnQuantFusionPass``, so no current pattern matchers break. Signed-off-by: Xavier Aguilar <xavier.aguilarfruto@amd.com>
Address review feedback: the lifted q-prep block was running the q-absorption BMM, the q_nope/q_pe concat, and the static FP8 quant on the *full* token extent (decode + prefill + CUDA-graph padding). For large prefills (DeepSeek-V3-class, 32k+ tokens) the discarded ql_nope/cat/fp8 intermediates wasted multiple GiB of HBM per forward. Add a new vllm::mla_decode_q_take(q, layer_name) -> Tensor custom op that runs at the very top of _maybe_prepare_decode_mqa_q and slices q to the first num_decode_tokens rows (read from forward_context.attn_metadata at execution time). All downstream ops in the lifted block (unified_mla_q_absorb, aten.cat, static_scaled_fp8_quant) now operate on the decode-only slice, so forward_impl's prepared_mqa_q[:num_mqa_tokens] re-slice becomes an assertion instead of a real op. The fake impl returns a tensor sized to q.shape[0] (a backed SymInt upper bound) to keep Dynamo from emitting data-dependent guards on unbacked SymInts, matching unified_mla_q_absorb_fake's pattern. Tests: - Tier 1 bit-exact parity: 2/2 PASS (atol=0) - Tier 2 gate truth-table: 8/8 PASS - Tier 3 wiring (incl. mla_decode_q_take + assert): 6/6 PASS - Tier 4 FX visibility (new node + topology): 1/1 PASS - Tier 5 MLA backends (test_mla_backends.py incl. DSR1 fp8 KV): 124/124 PASS - Pure-decode latency A/B (Kimi-K2.5-MXFP4, TP=4, fp8 KV): median neutral, p90/p99 -20%/-28% (tail improves; legacy paid for the wasted prefill-row work). - Accuracy, no degradation Signed-off-by: Xavier Aguilar <xavier.aguilarfruto@amd.com>
Signed-off-by: Xavier Aguilar <xavier.aguilarfruto@amd.com>
…on-turnansky) and vllm-project#40392 (Rohan138): - Replace VLLM_MLA_LIFT_DECODE_Q_PREP env var with the cc.pass_config.lift_mla_decode_q_prep flag (default False), matching the cc.pass_config.fuse_rope_kvcache_cat_mla shape from vllm-project#40392 and composing cleanly with cc.use_inductor_graph_partition. - Refactor MLAAttention.forward() into a wrap_if_exposed-style outer/inner split: _opaque_forward (default, single vllm::unified_mla_attention_with_output op) and _lifted_inner_forward (lifted path), mirroring ProExpertProg's inner-forward sketch on vllm-project#39346. - Rename the slice op from vllm::mla_decode_q_take to vllm::mla_split_batch and make it SymInt-returning. Same name and signature as vllm-project#39346, so the two PRs converge cleanly. - Drop the vllm::unified_mla_q_absorb wrapper; the q-absorb BMM is now a plain aten::bmm node. Drop bmm(out=) from the plain BMM path: preallocating a new_empty buffer with an unbacked SymInt batch dim added a data-dependent shape-equality guard Dynamo cannot discharge. - Replace the row-mismatch warning + silent recompute in forward_impl with a strict AssertionError that names the cudagraph capture/replay SymInt freeze and points at vllm-project#39346 / vllm-project#41839 for the fix. Fail fast instead of producing wrong outputs that look right. - custom_op.py: short-circuit dynamic_arg_dims wrappers under torch.compiler.is_compiling() so mark_dynamic doesn't fire inside enclosing fullgraph traces (required for the lifted static FP8 quant to trace cleanly through QuantFP8). - Add MLA_FX_DUMP=<path> to the FX-shape unit test to dump the captured pre-pass FX graph for offline review. - envs.py: drop a stray orphan comment fragment that was sitting between two unrelated env vars in upstream. Phase 1 is performance-neutral by construction (gate default-off, opaque path bit-exact with upstream). Phase 2 will bring the actual perf wins when matching the nodes into fused_qk_rope_concat_and_cache_mla. Signed-off-by: Xavier Aguilar <xavier.aguilarfruto@amd.com>
6c7a98f to
9838932
Compare
…lity The lifted decode q-prep chain in MLAAttention runs through QuantFP8.forward_native, which calls group_broadcast with a target shape carrying an unbacked SymInt (the decode-row count returned by vllm::mla_split_batch). The original ordering `t_dim_size != s and t_dim_size != 1` evaluates the int-vs-SymInt comparison first and raises GuardOnDataDependentSymNode under fullgraph capture, even though the immediately-following branch is the standard PyTorch extent-1 broadcasting case. Swap to `t_dim_size != 1 and t_dim_size != s` so the static-1 case short-circuits the AND before any SymInt comparison happens. Bit-equivalent for eager; required for tests/kernels/core/test_mla_q_quant_separation_fx.py to compile. Signed-off-by: Xavier Aguilar <xavier.aguilarfruto@amd.com>
|
Refactored the code addressing some issues (detailed in the commit message, and added here as well for clarity):
|
| Larger batch sizes e.g. during prefill will use the unfused kernels. | ||
| """ | ||
|
|
||
| lift_mla_decode_q_prep: bool = False |
There was a problem hiding this comment.
This shouldn't be controlled by a flag. We can lift this out by default if we're confident it works correctly.
| ) | ||
|
|
||
| @staticmethod | ||
| def _compute_lift_q_decode_quant( |
There was a problem hiding this comment.
Do we need all of these conditions? I get why the Inductor graph partition might be necessary
| @@ -0,0 +1,128 @@ | |||
| # SPDX-License-Identifier: Apache-2.0 | |||
There was a problem hiding this comment.
Please consolidate and minimze the two UTs
| _encode_layer_name(self.layer_name), | ||
| torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, encoded) | ||
|
|
||
| # ProExpertProg-style outer/inner-forward split (see PR #39346 |
There was a problem hiding this comment.
nit: minimze/delete comment
|
|
||
| @functools.wraps(fn) | ||
| def wrapper(*args, **kwargs): | ||
| # If this wrapper is reached while tracing an enclosing |
There was a problem hiding this comment.
What's the context/reason for this change?
| # If tensor has fewer dimensions than target shape, treat missing | ||
| # dimensions as size 1 (standard PyTorch broadcasting behavior) | ||
| t_dim_size = t.shape[i] if i < t.ndim else 1 | ||
| if t_dim_size != s and t_dim_size != 1: |
The quality regression shows up for public MLA model configurations when
prefill query quantization is combined with an FP8 KV cache. That combination
matters because use_prefill_query_quantization is inert unless the KV cache dtype
is quantized and the selected MLA prefill backend claims support. A serve
configuration that omits --kv-cache-dtype=fp8_e4m3 therefore does not exercise
this path, while the same model with FP8 KV cache and TRTLLM_RAGGED prefill does.
The old MLA prefill FP8 path was not scale-correct once it became active:
* forward_mha and chunked prefill converted q, k, and v with plain .to(fp8).
That conversion assumes an implicit scale of 1 and ignores the calibrated
q/k/v scale buffers loaded for FP8 attention.
* The TRT-LLM ragged DeepSeek prefill wrapper always passed
bmm1_scale=self.scale and bmm2_scale=1.0. The TRT-LLM/FlashInfer FP8
attention convention is bmm1_scale=q_scale*k_scale*sm_scale and
bmm2_scale=v_scale (modulo output scale when present). Existing generic
TRT-LLM attention benchmarks and tests use that convention.
* For chunked context, the FP8 path used cp_gather_cache. That only copies
cache bytes; the C++ comment on that path explicitly says scaled KV cache is
not supported there. Scaled FP8 cache must be gathered through
gather_and_maybe_dequant_cache so the cached latent KV is descaled before the
kv_b projection.
Fix this by making MLA FP8 prefill explicit about scale support:
* Restrict backend_supports_prefill_query_quantization to TRTLLM_RAGGED.
FlashInfer, FlashAttention, and TokenSpeed MLA prefill do not currently expose
the BMM scale hooks needed for scaled FP8 inputs, so the flag should not
activate FP8 prefill for those backends.
* Add MLAAttention._scaled_fp8_prefill_input, which uses the existing static
QuantFP8 op and layer scale buffers instead of plain casts.
* Pass the MLAAttention layer into the MHA-style impl path so the common MLA impl
can reuse the layer-owned quant op and q/k/v scale tensors.
* Quantize prefill q with _q_scale, expanded k with _k_scale, and projected v
with _v_scale. This differs from MLA decode, where BMM2 uses _k_scale because
decode attends over the latent KV cache rather than the projected value tensor.
* Plumb optional q/k/v descale factors through the MLA prefill backend API.
Non-supporting backends assert that no scaled FP8 inputs were provided.
TRTLLM_RAGGED maps the descales to bmm1_scale=self.scale*q_scale*k_scale and
bmm2_scale=v_scale.
* Allocate the non-DCP chunked prefill workspace in model dtype and always use
gather_and_maybe_dequant_cache for chunked context, including FP8 KV cache.
Open-source reproduction sketch:
Use a public Moonshot MLA checkpoint that exercises the Kimi-K2.x MLA path, for
example moonshotai/Kimi-K2.5 or the corresponding Moonshot Kimi-K2.6 Hugging
Face checkpoint when available. Choose TP/PP sizes that fit the local Blackwell
host; the important repro knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and
use_prefill_query_quantization.
Baseline server, FP8 KV cache without prefill query quantization:
MODEL_ID=moonshotai/Kimi-K2.5
vllm serve "$MODEL_ID" \
--host 0.0.0.0 \
--port 8000 \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
--served-model-name kimi-mla-fp8-baseline
Scaled-FP8 prefill server under test:
MODEL_ID=moonshotai/Kimi-K2.5
vllm serve "$MODEL_ID" \
--host 0.0.0.0 \
--port 8000 \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A quick deterministic quality probe is to ask small math questions with known
answers and temperature 0:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: 17_b=b+7 and 97_b=9b+7, so b+7 divides 56;
the divisors greater than 16 are 28 and 56, giving bases 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with m=n+2, divisibility reduces to m | 39,
so n is 1, 11, or 37.
Before this fix, the FP8-KV + prefill-query-quantized server can produce empty,
garbled, timed-out, or mathematically incorrect responses while the baseline
server remains coherent. After this fix, the scaled-FP8 prefill server should
match the baseline quality on these probes. A control run without
--kv-cache-dtype=fp8_e4m3 should also remain unaffected by the flag, because the
FP8 prefill query path is intentionally inactive unless the cache is quantized.
Correctness evidence:
* The activation condition is covered by unit tests: FP8 prefill query dtype is
selected only when the cache dtype is FP8, use_prefill_query_quantization is
set, the device family is SM100, and the MLA prefill backend is TRTLLM_RAGGED.
* The backend gate test proves FlashInfer, FlashAttention, and TokenSpeed no
longer accidentally activate the scaled-FP8 path.
* The non-TRT prefill backend tests prove those backends fail closed if scaled
FP8 q/k/v inputs are accidentally passed through the new API.
* The scaled prefill input helper test proves the quant path flattens only the
leading dimensions, preserves the original shape, makes the quant input
contiguous, and passes the supplied scale tensor to the static QuantFP8 op.
* The TRTLLM_RAGGED scale helper test checks both the unscaled legacy case and
the scaled case, including that partial scale plumbing is rejected.
* The scale formulas match existing TRT-LLM FP8 attention tests and benchmarks:
bmm1 receives q_scale*k_scale*softmax_scale, and bmm2 receives value scale.
* The chunked-context cache gather now uses the only gather path that accepts a
KV scale and can dequantize scaled FP8 cache before projection.
Validation run locally:
* .venv/bin/python -m pytest tests/v1/attention/test_mla_prefill_selector.py -q
-> 26 passed, 16 warnings
* pre-commit run ruff-format --files <changed files> -> passed
* pre-commit run ruff-check --files <changed files> -> passed
* .venv/bin/python -m py_compile <changed files> -> passed
* pre-commit run mypy-local --files <changed files> -> passed
Duplicate-work check refreshed on 2026-05-15:
* vllm-project#39841 only changes FP8 cast ordering in chunked prefill; it does not address
calibrated q/k/v input quantization, TRT-LLM ragged BMM descales, or backend
gating.
* vllm-project#40304 and vllm-project#40908 focus on static FP8 prefill output / merge-state fusion,
while this change fixes scaled FP8 input correctness.
* vllm-project#42509 is ROCm/AITER dense MLA prefill work on gfx950, not the Blackwell
TRT-LLM ragged MLA prefill path touched here.
* vllm-project#40609 / vllm-project#34795 are DCP FP8-KV efforts. This change intentionally handles
non-DCP chunked context and keeps DCP as separate work.
* vllm-project#41568 lifts decode Q-prep work out of forward_impl; it is decode performance
work and does not repair prefill q/k/v input descales.
Co-authored-by: OpenAI Codex <codex@openai.com>
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend supports scaled FP8 inputs. One observed public configuration is Moonshot MLA checkpoints such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`. The old FP8 prefill path was not scale-correct once active: prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received incomplete BMM scales, and non-DCP chunked context gathered FP8 cache bytes through a path that cannot dequantize scaled KV cache before the `kv_b` projection. That can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent. This change makes scaled FP8 MLA prefill explicit: - gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks; - quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts; - plumb q/k/v descale factors through the MLA prefill backend API, mapping TRTLLM_RAGGED to `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`; - fail closed in backends that do not support scaled FP8 inputs; - use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection. Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering. Related upstream work reviewed: - vllm-project#39841 adjusts FP8 cast ordering in chunked prefill, but does not address calibrated q/k/v quantization, TRTLLM_RAGGED BMM scales, or backend gating. - vllm-project#40304 and vllm-project#40908 focus on static FP8 prefill output and merge-state fusion, not scaled FP8 input correctness. - vllm-project#42509 is ROCm/AITER dense MLA prefill work on gfx950, not Blackwell TRTLLM_RAGGED MLA prefill. - vllm-project#40609 and vllm-project#34795 cover DCP FP8-KV work; this change covers the non-DCP chunked-context path. - vllm-project#41568 is decode Q-prep refactoring and does not repair prefill input descales.
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend supports scaled FP8 inputs. One observed public configuration is Moonshot MLA checkpoints such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.
The old FP8 prefill path was not scale-correct once active: prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received incomplete BMM scales, and non-DCP chunked context gathered FP8 cache bytes through a path that cannot dequantize scaled KV cache before the `kv_b` projection. That can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.
This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API, mapping TRTLLM_RAGGED to `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.
Reproduction sketch:
Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.
Baseline server:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
--served-model-name kimi-mla-fp8-baseline
Scaled-FP8 prefill server under test:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is to ask a small math question with temperature 0:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.
Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.
Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.
Related upstream work reviewed:
- vllm-project#39841 adjusts FP8 cast ordering in chunked prefill, but does not address calibrated q/k/v quantization, TRTLLM_RAGGED BMM scales, or backend gating.
- vllm-project#40304 and vllm-project#40908 focus on static FP8 prefill output and merge-state fusion, not scaled FP8 input correctness.
- vllm-project#42509 is ROCm/AITER dense MLA prefill work on gfx950, not Blackwell TRTLLM_RAGGED MLA prefill.
- vllm-project#40609 and vllm-project#34795 cover DCP FP8-KV work; this change covers the non-DCP chunked-context path.
- vllm-project#41568 is decode Q-prep refactoring and does not repair prefill input descales.
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend can consume scaled FP8 inputs. One public configuration that exercises this path is a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.
The old path treated FP8 prefill as a dtype conversion rather than a scaled quantization contract. Prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received legacy BMM scales, and non-DCP chunked context copied FP8 cache bytes through a gather path that cannot dequantize scaled KV cache before the `kv_b` projection. Once active, that can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.
For scaled FP8 prefill inputs, the backend must interpret the encoded tensors with their calibrated descales. In the TRTLLM_RAGGED path that means `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`; passing only `softmax_scale` and `1.0` is correct only for the unscaled legacy path.
This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API and map them to the TRTLLM_RAGGED BMM scales;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.
Reproduction sketch:
Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.
Baseline server:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
--served-model-name kimi-mla-fp8-baseline
Scaled-FP8 prefill server under test:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is to ask a small math question with temperature 0:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.
Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.
Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.
Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the FP8 cast, but it still leaves raw FP8 casts, missing q/k/v descales, missing TRTLLM_RAGGED BMM scale mapping, and no backend gate.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion. Those operate after the attention inputs; they do not repair incorrectly scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware and backend path from NVIDIA Blackwell TRTLLM_RAGGED.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize behavior in the DCP path. That is useful DCP work, but it does not make non-DCP TRTLLM_RAGGED prefill query quantization scale-correct, and DCP prefill-query FP8 support is a separate surface.
- vllm-project#41568 refactors decode Q-prep. This bug is in prefill input quantization and prefill backend scale plumbing, not decode Q preparation.
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend can consume scaled FP8 inputs. One public configuration that exercises this path is a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.
The old path treated FP8 prefill as a dtype conversion rather than a scaled quantization contract. Prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received legacy BMM scales, and non-DCP chunked context copied FP8 cache bytes through a gather path that cannot dequantize scaled KV cache before the `kv_b` projection. Once active, that can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.
For scaled FP8 prefill inputs, the backend must interpret the encoded tensors with their calibrated descales. In the TRTLLM_RAGGED path that means `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`; passing only `softmax_scale` and `1.0` is correct only for the unscaled legacy path.
This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API and map them to the TRTLLM_RAGGED BMM scales;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.
Reproduction sketch:
Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.
Baseline server:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
--served-model-name kimi-mla-fp8-baseline
Scaled-FP8 prefill server under test:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is to ask a small math question with temperature 0:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.
Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.
Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.
Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the FP8 cast, but it still leaves raw FP8 casts, missing q/k/v descales, missing TRTLLM_RAGGED BMM scale mapping, and no backend gate.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion. Those operate after the attention inputs; they do not repair incorrectly scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware and backend path from NVIDIA Blackwell TRTLLM_RAGGED.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize behavior in the DCP path. That is useful DCP work, but it does not make non-DCP TRTLLM_RAGGED prefill query quantization scale-correct, and DCP prefill-query FP8 support is a separate surface.
- vllm-project#41568 refactors decode Q-prep. This bug is in prefill input quantization and prefill backend scale plumbing, not decode Q preparation.
Prefill query quantization changes the MLA prefill contract when the KV cache
is FP8. The q/k/v tensors passed to the prefill backend are no longer ordinary
BF16 inputs: they are static per-tensor FP8 values and the backend must apply
their calibrated descales in the attention math.
The previous path converted prefill q/k/v to FP8 but did not consistently pass
those descales into the selected prefill backend. That makes the attention
scores and values numerically wrong:
expected scores = (q_fp8 * q_scale) @ (k_fp8 * k_scale).T * softmax_scale
expected output = softmax(expected scores) @ (v_fp8 * v_scale)
Without the descales, the kernel effectively computes logits and P@V in the
encoded FP8 domain. In practice this can show up as empty, garbled, timed-out,
or mathematically incorrect responses for Moonshot MLA checkpoints served with
FP8 KV cache and `use_prefill_query_quantization=true`, while the same model
without prefill query quantization remains coherent.
Make scaled FP8 MLA prefill explicit end to end:
- quantize prefill q, expanded k, and projected v with the layer-owned static
QuantFP8 op and calibrated q/k/v scale tensors instead of raw dtype casts;
- pass q/k/v descales through the MLA prefill backend API from both new-token
prefill and chunked-context prefill;
- map descales to TRT-LLM ragged DeepSeek prefill as
`bmm1_scale = softmax_scale * q_scale * k_scale` and
`bmm2_scale = v_scale`;
- forward descales to FlashInfer ragged prefill. vLLM pins
`flashinfer-python==0.6.8.post1`, whose
`BatchPrefillWithRaggedKVCacheWrapper.run()` accepts q/k/v scale arguments,
and vLLM constructs that wrapper with `backend="cutlass"`, whose run path
passes those scales to `fmha_varlen`;
- support TokenSpeed MLA prefill without changing its external API. TokenSpeed
exposes `softmax_scale` and returns BF16 output; for static per-tensor FP8,
folding `q_scale * k_scale` into `softmax_scale` and applying `v_scale` to
the BF16 output is algebraically equivalent to dequantizing Q/K before QK
and V before PV. LSE is left unchanged by the V/output scale;
- keep FlashAttention MLA prefill fail-closed for scaled FP8 inputs. Although
the FA3 varlen API has descale slots, this GB200 path uses the FA4 wrapper,
and the current FA4 call does not pass q/k/v descales;
- dequantize scaled FP8 KV cache in non-DCP chunked context before the `kv_b`
projection with `gather_and_maybe_dequant_cache`.
This preserves the user-facing MLA prefill backend support matrix where the
backend has a concrete scale-correct path: TRTLLM_RAGGED, FLASHINFER, and
TOKENSPEED_MLA remain supported for FP8 prefill query quantization on GB200;
FLASH_ATTN remains unsupported instead of silently producing unscaled FP8 math.
Reproduction sketch:
Use a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or
`moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV
cache, MLA prefill query quantization, and an MLA prefill backend that accepts
or can equivalently apply q/k/v descales.
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so
`b + 7` divides 56 and the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to
`m | 39`, so `n` is 1, 11, or 37.
Tests cover the activation gate, q/k/v scale plumbing, TRT-LLM BMM scale
mapping, FlashInfer scale forwarding, TokenSpeed scale equivalence, scaled FP8
input quantization, FlashAttention fail-closed behavior, and chunked-context
cache gather/dequantization.
Local validation:
- `.venv/bin/python -m pytest tests/v1/attention/test_mla_prefill_selector.py -q`
- `uvx ruff check` on the touched FP8 MLA files and tests
- `.venv/bin/python -m py_compile` on the touched FP8 MLA files and tests
- `git diff --check`
Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the
FP8 cast, but it still leaves missing q/k/v descales and backend scale
plumbing.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion.
Those operate after the attention inputs; they do not repair incorrectly
scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware
and backend path from NVIDIA Blackwell MLA prefill.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize
behavior in the DCP path. That does not make non-DCP scaled FP8 prefill
query quantization scale-correct.
- vllm-project#41568 refactors decode Q preparation. This bug is in prefill input
quantization and prefill backend scale plumbing, not decode Q preparation.
Prefill query quantization changes the MLA prefill contract when the KV cache
is FP8. The q/k/v tensors passed to the prefill backend are no longer ordinary
BF16 inputs: they are static per-tensor FP8 values and the backend must apply
their calibrated descales in the attention math.
The previous path converted prefill q/k/v to FP8 but did not consistently pass
those descales into the selected prefill backend. That makes the attention
scores and values numerically wrong:
expected scores = (q_fp8 * q_scale) @ (k_fp8 * k_scale).T * softmax_scale
expected output = softmax(expected scores) @ (v_fp8 * v_scale)
Without the descales, the kernel effectively computes logits and P@V in the
encoded FP8 domain. In practice this can show up as empty, garbled, timed-out,
or mathematically incorrect responses for Moonshot MLA checkpoints served with
FP8 KV cache and `use_prefill_query_quantization=true`, while the same model
without prefill query quantization remains coherent.
Make scaled FP8 MLA prefill explicit end to end:
- quantize prefill q, expanded k, and projected v with the layer-owned static
QuantFP8 op and calibrated q/k/v scale tensors instead of raw dtype casts;
- pass q/k/v descales through the MLA prefill backend API from both new-token
prefill and chunked-context prefill;
- map descales to TRT-LLM ragged DeepSeek prefill as
`bmm1_scale = softmax_scale * q_scale * k_scale` and
`bmm2_scale = v_scale`;
- forward descales to FlashInfer ragged prefill. vLLM pins
`flashinfer-python==0.6.8.post1`, whose
`BatchPrefillWithRaggedKVCacheWrapper.run()` accepts q/k/v scale arguments,
and vLLM constructs that wrapper with `backend="cutlass"`, whose run path
passes those scales to `fmha_varlen`;
- support TokenSpeed MLA prefill without changing its external API. TokenSpeed
exposes `softmax_scale` and returns BF16 output; for static per-tensor FP8,
folding `q_scale * k_scale` into `softmax_scale` and applying `v_scale` to
the BF16 output is algebraically equivalent to dequantizing Q/K before QK
and V before PV. LSE is left unchanged by the V/output scale;
- keep FlashAttention MLA prefill fail-closed for scaled FP8 inputs. Although
the FA3 varlen API has descale slots, this GB200 path uses the FA4 wrapper,
and the current FA4 call does not pass q/k/v descales;
- dequantize scaled FP8 KV cache in non-DCP chunked context before the `kv_b`
projection with `gather_and_maybe_dequant_cache`.
This preserves the user-facing MLA prefill backend support matrix where the
backend has a concrete scale-correct path: TRTLLM_RAGGED, FLASHINFER, and
TOKENSPEED_MLA remain supported for FP8 prefill query quantization on GB200;
FLASH_ATTN remains unsupported instead of silently producing unscaled FP8 math.
Reproduction sketch:
Use a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or
`moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV
cache, MLA prefill query quantization, and an MLA prefill backend that accepts
or can equivalently apply q/k/v descales.
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so
`b + 7` divides 56 and the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to
`m | 39`, so `n` is 1, 11, or 37.
Tests cover the activation gate, q/k/v scale plumbing, TRT-LLM BMM scale
mapping, FlashInfer scale forwarding, TokenSpeed scale equivalence, scaled FP8
input quantization, FlashAttention fail-closed behavior, and chunked-context
cache gather/dequantization.
Local validation:
- `.venv/bin/python -m pytest tests/v1/attention/test_mla_prefill_selector.py -q`
- `uvx ruff check` on the touched FP8 MLA files and tests
- `.venv/bin/python -m py_compile` on the touched FP8 MLA files and tests
- `git diff --check`
Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the
FP8 cast, but it still leaves missing q/k/v descales and backend scale
plumbing.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion.
Those operate after the attention inputs; they do not repair incorrectly
scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware
and backend path from NVIDIA Blackwell MLA prefill.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize
behavior in the DCP path. That does not make non-DCP scaled FP8 prefill
query quantization scale-correct.
- vllm-project#41568 refactors decode Q preparation. This bug is in prefill input
quantization and prefill backend scale plumbing, not decode Q preparation.
Co-authored-by: Codex
Prefill query quantization changes the MLA prefill contract when the KV cache
is FP8. The q/k/v tensors passed to the prefill backend are no longer ordinary
BF16 inputs: they are static per-tensor FP8 values and the backend must apply
their calibrated descales in the attention math.
The previous path converted prefill q/k/v to FP8 but did not consistently pass
those descales into the selected prefill backend. That makes the attention
scores and values numerically wrong:
expected scores = (q_fp8 * q_scale) @ (k_fp8 * k_scale).T * softmax_scale
expected output = softmax(expected scores) @ (v_fp8 * v_scale)
Without the descales, the kernel effectively computes logits and P@V in the
encoded FP8 domain. In practice this can show up as empty, garbled, timed-out,
or mathematically incorrect responses for Moonshot MLA checkpoints served with
FP8 KV cache and `use_prefill_query_quantization=true`, while the same model
without prefill query quantization remains coherent.
Make scaled FP8 MLA prefill explicit end to end:
- quantize prefill q, expanded k, and projected v with the layer-owned static
QuantFP8 op and calibrated q/k/v scale tensors instead of raw dtype casts;
- pass q/k/v descales through the MLA prefill backend API from both new-token
prefill and non-DCP chunked-context prefill;
- map descales to TRT-LLM ragged DeepSeek prefill as
`bmm1_scale = softmax_scale * q_scale * k_scale` and
`bmm2_scale = v_scale`;
- forward descales to FlashInfer ragged prefill. vLLM pins
`flashinfer-python==0.6.8.post1`, whose
`BatchPrefillWithRaggedKVCacheWrapper.run()` accepts q/k/v scale arguments,
and vLLM constructs that wrapper with `backend="cutlass"`, whose run path
passes those scales to `fmha_varlen`;
- support TokenSpeed MLA prefill without changing its external API. TokenSpeed
exposes `softmax_scale` and returns BF16 output; for static per-tensor FP8,
folding `q_scale * k_scale` into `softmax_scale` and applying `v_scale` to
the BF16 output is algebraically equivalent to dequantizing Q/K before QK
and V before PV. LSE is left unchanged by the V/output scale;
- keep FlashAttention MLA prefill fail-closed for scaled FP8 inputs. Although
the FA3 varlen API has descale slots, this GB200 path uses the FA4 wrapper,
and the current FA4 call does not pass q/k/v descales;
- dequantize scaled FP8 KV cache in non-DCP chunked context before the `kv_b`
projection with `gather_and_maybe_dequant_cache`.
DCP chunked-context MLA prefill still uses a separate distributed gather path
that does not thread q/k/v descales into context attention. Rather than silently
claim scale-correct support there, reject DCP chunked-context when scaled FP8
prefill query quantization is active. Also keep the ROCm/AITER `forward_mha`
override signature compatible with the widened common MHA interface so the
shared `layer` argument does not break fallback paths.
This preserves the user-facing MLA prefill backend support matrix where the
backend has a concrete scale-correct path: TRTLLM_RAGGED, FLASHINFER, and
TOKENSPEED_MLA remain supported for FP8 prefill query quantization on GB200;
FLASH_ATTN and DCP chunked-context remain unsupported instead of silently
producing unscaled FP8 math.
Reproduction sketch:
Use a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or
`moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV
cache, MLA prefill query quantization, and an MLA prefill backend that accepts
or can equivalently apply q/k/v descales.
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so
`b + 7` divides 56 and the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to
`m | 39`, so `n` is 1, 11, or 37.
Tests cover the activation gate, q/k/v scale plumbing, TRT-LLM BMM scale
mapping, FlashInfer scale forwarding, TokenSpeed scale equivalence, scaled FP8
input quantization, FlashAttention fail-closed behavior, DCP fail-closed
behavior, ROCm/AITER interface compatibility, and chunked-context cache
gather/dequantization.
Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the
FP8 cast, but it still leaves missing q/k/v descales and backend scale
plumbing.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion.
Those operate after the attention inputs; they do not repair incorrectly
scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware
and backend path from NVIDIA Blackwell MLA prefill.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize
behavior in the DCP path. That does not make non-DCP scaled FP8 prefill
query quantization scale-correct.
- vllm-project#41568 refactors decode Q preparation. This bug is in prefill input
quantization and prefill backend scale plumbing, not decode Q preparation.
Co-authored-by: Codex <codex@openai.com>
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend can consume scaled FP8 inputs. One public configuration that exercises this path is a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.
The old path treated FP8 prefill as a dtype conversion rather than a scaled quantization contract. Prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received legacy BMM scales, and non-DCP chunked context copied FP8 cache bytes through a gather path that cannot dequantize scaled KV cache before the `kv_b` projection. Once active, that can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.
For scaled FP8 prefill inputs, the backend must interpret the encoded tensors with their calibrated descales. In the TRTLLM_RAGGED path that means `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`; passing only `softmax_scale` and `1.0` is correct only for the unscaled legacy path.
This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API and map them to the TRTLLM_RAGGED BMM scales;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.
Reproduction sketch:
Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.
Baseline server:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
--served-model-name kimi-mla-fp8-baseline
Scaled-FP8 prefill server under test:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is to ask a small math question with temperature 0:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.
Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.
Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.
Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the FP8 cast, but it still leaves raw FP8 casts, missing q/k/v descales, missing TRTLLM_RAGGED BMM scale mapping, and no backend gate.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion. Those operate after the attention inputs; they do not repair incorrectly scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware and backend path from NVIDIA Blackwell TRTLLM_RAGGED.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize behavior in the DCP path. That is useful DCP work, but it does not make non-DCP TRTLLM_RAGGED prefill query quantization scale-correct, and DCP prefill-query FP8 support is a separate surface.
- vllm-project#41568 refactors decode Q-prep. This bug is in prefill input quantization and prefill backend scale plumbing, not decode Q preparation.
Assisted-by: OpenAI Codex
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend can consume scaled FP8 inputs. One public configuration that exercises this path is a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.
The old path treated FP8 prefill as a dtype conversion rather than a scaled quantization contract. Prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received legacy BMM scales, and non-DCP chunked context copied FP8 cache bytes through a gather path that cannot dequantize scaled KV cache before the `kv_b` projection. Once active, that can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.
For scaled FP8 prefill inputs, the backend must interpret the encoded tensors with their calibrated descales. In the TRTLLM_RAGGED path that means `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`; passing only `softmax_scale` and `1.0` is correct only for the unscaled legacy path.
This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API and map them to the TRTLLM_RAGGED BMM scales;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.
Reproduction sketch:
Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.
Baseline server:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
--served-model-name kimi-mla-fp8-baseline
Scaled-FP8 prefill server under test:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is to ask a small math question with temperature 0:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.
Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.
This revision also fails closed for modes that need separate scale semantics before FP8 prefill query quantization can be correct: decode context parallelism, runtime `calculate_kv_scales`, and the ROCm AITER MLA override whose `forward_mha` signature must match the shared `layer` argument. Those cases now fall back to model-dtype prefill or the parent implementation instead of silently running an incomplete scaled-FP8 path.
Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.
Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the FP8 cast, but it still leaves raw FP8 casts, missing q/k/v descales, missing TRTLLM_RAGGED BMM scale mapping, and no backend gate.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion. Those operate after the attention inputs; they do not repair incorrectly scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware and backend path from NVIDIA Blackwell TRTLLM_RAGGED.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize behavior in the DCP path. That is useful DCP work, but it does not make non-DCP TRTLLM_RAGGED prefill query quantization scale-correct, and DCP prefill-query FP8 support is a separate surface.
- vllm-project#41568 refactors decode Q-prep. This bug is in prefill input quantization and prefill backend scale plumbing, not decode Q preparation.
Assisted-by: OpenAI Codex
|
This pull request has merge conflicts that must be resolved before it can be |
Purpose
Phase 1 of a two-phase plan to enable end-to-end fusion of the MLA decode
q-prep chain on AITER. This PR is a pure refactor — it does not change
math, kernels, or the user-visible default path. It lifts the decode q-prep
chain (q-absorb BMM, q_nope/q_pe head-dim concat, static FP8 quant) out of
MLAAttention.forward_impland intoMLAAttention.forward()sotorch.compilesees each step as a discrete FX node instead of an opaquecustom-op body.
This is the FX shape that Phase 2 will pattern-match into the AITER
fused_qk_rope_concat_and_cache_mlakernel. Phase 1 ships only the visibility refactor; no kernel work, no
expected perf change against the opaque path.
Borrows ideas from #39346
(@morrison-turnansky — "Expose MLA to torch.compile") — same
vllm::mla_split_batchSymInt op convention, same outer/inner-forwardsplit idea — and follows the same
cc.pass_config.<flag>shape that#40392 (@Rohan138 —
MLARoPEKVCacheCatFusionPass) settled on for MLA-side toggles.Design
Outer/inner forward split
MLAAttention.forward()now picks between two paths in one place — exactlythe
wrap_if_exposed/inner_forwardshape ProExpertProg sketched in the#39346 review thread:
_opaque_forward()— historical path. Onevllm::unified_mla_attention_with_outputop, no FX-visible q-prep.This is the user-visible default.
_lifted_inner_forward()— runs the prep inline as plain torch ops:vllm::mla_split_batch(SymInt) →aten.narrow→ split → q-absorb BMM→ cat → static FP8 quant. The attention kernel call stays opaque (still
one
vllm::unified_mla_attention_with_outputop), so the backendretains its own correctness/perf boundary; only the prep chain is
exposed to the FX graph.
Single policy site, both branches share the same downstream code, so the
diff a reviewer has to read is one
if/elseinforward().Gating
A new
cc.pass_config.lift_mla_decode_q_prep: bool = Falseflag (mirrorsthe shape of
cc.pass_config.fuse_rope_kvcache_cat_mlafrom #40392)controls the lift. The full predicate, computed once at layer-init time
(
_compute_lift_q_decode_quant), is:The
use_inductor_graph_partitionrequirement matches #39346's recipe formaking the unbacked SymInt safe under piecewise/full cudagraph capture.
The lift is off by default, so the PR has zero impact on users who do
not opt in.
vllm::mla_split_batch(SymInt op)A thin custom op (same name as in #39346 for forward compatibility) reads
forward_context.attn_metadata.num_decode_tokensat execution time andreturns it. Its fake impl returns
ctx.new_dynamic_size(), so downstreamops (
q.narrow(0, 0, num_decode), the inline BMM, cat, FP8 quant) carrythis unbacked SymInt as their leading dim. Wrapping the metadata read in a
custom op keeps it invisible to
torch.compile— surrounding ops stayfully traceable instead of graph-breaking on the metadata access.
Strict mode (no silent fallback)
Earlier revisions of this branch had a row-mismatch warning + recompute
fallback in
forward_impl. That is gone: when the lift gate is on,forward_implassertsprepared_mqa_q.shape[0] == num_mqa_tokenswith an actionable message that names the suspected cause (cudagraph
capture/replay SymInt freeze on the lifted chain) and tells the user how
to fall back (
cc.pass_config.lift_mla_decode_q_prep=False). Rationalematches the @morrison-turnansky / @ProExpertProg "fail fast" stance in
#39346 on size-0 cases: a silent recompute hides cudagraph freeze and
produces wrong outputs that look right; an assert surfaces them on the
first decode step.
Schema / wire compatibility
unified_mla_attention_with_outputkeeps its existingprepared_mqa_qtrailing optional kwarg (default
None). The Inductor pattern matcher(
torch._inductor.pattern_matcher._TargetArgsExpr._match) normalisestarget-side kwargs through
torch.fx.operator_schemas.normalize_functionand filters them down to the keys each pattern lists, so existing
MLAAttnQuantFusionPasspatterns (MLAAttnFp8StaticQuantPattern,MLAAttnNvfp4QuantPattern,MLAAttnFp8GroupQuantPattern) continue tomatch without modification.
Files changed
vllm/config/compilation.pyPassConfig.lift_mla_decode_q_prep: bool = False.vllm/model_executor/layers/attention/mla_attention.pyvllm::mla_split_batchop, strict assert inforward_impl,_run_decode_q_prep_kernelsshared between lifted and legacy paths.tests/kernels/core/test_mla_q_quant_separation_fx.pyMLA_FX_DUMP=<path>env to dump the captured pre-pass graph for offline review.tests/kernels/core/test_mla_q_quant_separation.pyatol=0) between the lifted Python sequence and the legacy in-place block on the same inputs.vllm/model_executor/custom_op.pydynamic_arg_dimswrapper is reached whiletorch.compiler.is_compiling(), skipmark_dynamic(Dynamo forbids it during fullgraph trace) and call the eager-native function directly so the outer trace captures it. Required for the lifted path's static FP8 quant to compile cleanly.vllm/model_executor/layers/quantization/utils/quant_utils.pygroup_broadcast: short-circuit on the extent-1 case (t_dim_size != 1) before the SymInt equality (t_dim_size != s). The lifted FP8 quant chain runs throughQuantFP8.forward_native → group_broadcastwith a target shape carrying the unbacked SymInt fromvllm::mla_split_batch; the original ordering forcedint != SymIntfirst and raisedGuardOnDataDependentSymNodeunder fullgraph capture. Bit-equivalent for eager; required fortest_mla_q_quant_separation_fx.pyto compile.Test plan
E2E server tests and gsm8k 5-shot tests done as well.
Test results
Phase 1 FX-shape test — captured graph
pytest tests/kernels/core/test_mla_q_quant_separation_fx.py -v→ 1 passed.With
MLA_FX_DUMP=…set, the pre-pass FX graph is the contract Phase 2'smatcher binds to:
The required topology (
vllm::mla_split_batch(SymInt) →aten.slice→split →
aten.bmm→aten.cat→ static FP8 quant chain → reshape onu0) is asserted by the unit test; the dump above is just the human-readable version of the same contract.
Bit-exact parity unit test
pytest tests/kernels/core/test_mla_q_quant_separation.py -v→ 2/2 pass,atol=0. The lifted Python sequence is bit-identical to the legacyin-place
BMM + _DecodeConcatQuantFP8block on the same inputs, atseq_len=1andseq_len=7, with canonical Kimi-K2.5 / DeepSeek-V3 dims(
qk_nope_head_dim=128,qk_rope_head_dim=64,kv_lora_rank=512,num_heads=128,bf16).E2E accuracy (
amd/Kimi-K2-Thinking-MXFP4, ROCm AITER, FP8 KV, TP=4)GSM8K 5-shot, same compiled binary, three configurations:
lift_mla_decode_q_prepuse_inductor_graph_partitionprepared_mqa_q row mismatch (16 vs 512 decode rows)) — see "Known limitation"The opaque path (rows 1 and 2) is bit-exact with clean upstream, so the PR
has zero impact on users who do not opt in.
Known limitation: cudagraph SymInt freeze on the lifted chain
Row 3 above is the expected Phase 1 outcome on a stack that does not
yet include #39346's "stable cuda graph piece replay" work (commit
0aa9492). Under piecewise/full cudagraph capture, the unbacked SymIntreturned by
vllm::mla_split_batchis resolved to the capture-timedecode count and baked into the captured graph's tensor sizes / kernel
grids. At replay with a different runtime
num_decode_tokens, thecaptured chain still operates at the captured size — the lifted
prepared_mqa_qarrives with the wrong leading dim and the strictassert in
forward_implfires.This is the cudagraph capture/replay SymInt freeze that PR #39346
sidesteps with
cudagraph_mode=FULL_AND_PIECEWISE+use_inductor_graph_partition=true, and that Phase 2 makesmoot by collapsing the entire chain into a single fused op that reads
the runtime decode size inside the op — at which point the boundary-
crossing tensor that the freeze acts on simply does not exist.
The PR therefore intentionally:
behaviour, bit-exact with clean upstream);
instead of silently recomputing the prep, which earlier revisions
did and which masked the issue);
isolation (no cudagraph) and is independent of the freeze.
End users with stacks that include #39346 — or, eventually, anyone
running with Phase 2 — get the lift; everyone else stays on the opaque
path.
Notes for reviewers
vllm::mla_split_batch,same name and signature as in [Refactor][MLA]: Expose mla to torch.compile #39346.
unified_mla_attention_with_outputkeeps its existingprepared_mqa_qtrailing optional kwarg (defaultNone); existingpositional callers and FX patterns are unaffected.
_run_decode_q_prep_kernelsis the shared helper between thelifted path and the legacy in-impl path — same kernels, same math,
so the lifted path is bit-exact when the gate is on. The plain-bmm
fallback in this helper deliberately does not preallocate
out=:with
q_nope_nbcarrying an unbacked SymInt batch dim, anout=buffer constructed viaq_nope_nb.new_empty((N, B, L))addsa data-dependent shape-equality guard that Dynamo cannot discharge
(
GuardOnDataDependentSymNode: u0). Lettingtorch.bmminfer theoutput shape avoids the guard and is bit-equivalent.
forward_implkeeps theq_pad_num_headsand DCP branchesinline. The lift gate disables itself for both, so the lifted
helper never reaches them.
custom_op.pymark_dynamicguard is a small one-liner thatfixes a fullgraph-trace error on the lifted path. When a
dynamic_arg_dims-wrapped op (here,QuantFP8) is reached whileDynamo is already tracing an enclosing graph,
torch._dynamo.mark_dynamicraisesAttempt to trace forbidden callable mark_dynamic. The wrapper now short-circuits to theeager-native function under
torch.compiler.is_compiling(), lettingthe outer trace capture it normally. Bounded scope, no behaviour
change outside fullgraph capture.
group_broadcastshort-circuit invllm/model_executor/layers/quantization/utils/quant_utils.pyis a one-line reorder of the two equality checks. The lifted FP8
quant chain runs through
QuantFP8.forward_native → group_broadcastwith a target shape carrying the unbacked SymInt from
vllm::mla_split_batch. The original orderingt_dim_size != s and t_dim_size != 1evaluates theint != SymIntcomparison first and trips
GuardOnDataDependentSymNodeunderfullgraph capture, even though the very next branch is the standard
PyTorch extent-1 broadcasting case. Swapping to
t_dim_size != 1 and t_dim_size != sshort-circuits on the static-1case before any SymInt comparison happens. Bit-equivalent for eager;
required for
test_mla_q_quant_separation_fx.pyto compile.Risks & follow-ups
exact with clean upstream).
forward_implis theonly way the gate-on path can produce wrong output, and it does
not — it stops the run with an actionable message.
is the path to actual perf wins. Phase 1 by itself is performance-
neutral by construction (same kernels, same math, only the FX
visibility changes).
vllm::mla_split_batchop name, the same outer/inner-forward split shape, and the same
use_inductor_graph_partitionpre-condition; once [Refactor][MLA]: Expose mla to torch.compile #39346 lands, itsdecode branch can call
_maybe_prepare_decode_mqa_qdirectlywithout any additional refactor.
cc @ProExpertProg @LucasWilkinson @MatthewBonanni @morrison-turnansky
AI assistance disclosure
This change was prepared with AI assistance (Cursor / Claude). I
(Xavier Aguilar) reviewed every changed line, ran every test command
listed above on local hardware, and stand behind the implementation
end-to-end. Duplicate-work checks against #39346 and #40392 were
performed; this PR is materially different — it is a pure refactor of
the decode q-prep chain only, not the prefill/decode split (#39346's
scope) and not the RoPE/KV-cache fusion (#40392's scope). Phase 2
will consume the FX shape this PR establishes.