Skip to content

[Core] Enable FP8 KV cache with Decode Context Parallel (DCP) for MLA#34795

Open
grimulkan wants to merge 1 commit into
vllm-project:mainfrom
grimulkan:dcp-fp8-mla
Open

[Core] Enable FP8 KV cache with Decode Context Parallel (DCP) for MLA#34795
grimulkan wants to merge 1 commit into
vllm-project:mainfrom
grimulkan:dcp-fp8-mla

Conversation

@grimulkan

@grimulkan grimulkan commented Feb 18, 2026

Copy link
Copy Markdown
Contributor

Previously, MLA attention blocked the combination of FP8 KV cache (kv_cache_dtype=fp8) with DCP > 1 via hard asserts. This patch enables the combination by:

  • Restructuring the decode Q path to allgather in BF16, then optionally quantize to FP8 post-gather for backends with supports_quant_query_input
  • Replacing cp_gather_cache (dtype-strict) with gather_and_maybe_dequant_cache for FP8 KV cache in the prefill DCP gather path
  • Passing k_scale through to the DCP prefill path (was hardcoded None)
  • Adding a clear guard for the unsupported use_fp8_prefill + DCP > 1 case
  • Adding FP8 DCP test parameterization to test_context_parallel.py

Purpose

MLA attention previously blocked the combination of FP8 KV cache (kv_cache_dtype=fp8) with Decode Context Parallel (DCP) > 1 via two hard asserts:

  • Decode path: assert not fp8_attention, "DCP not support fp8 kvcache now."
  • Prefill DCP gather: assert k_scale is None, "DCP not support scaled kvcache now."

This meant users had to choose between FP8 KV cache (memory savings) and DCP (more memory savings, latency reduction). This PR enables both to work together, maintaining numerical correctness through storage-only FP8 (no new FP8 compute paths).

Changes

mla_attention.py:

  1. Decode Q restructure: DCP > 1 always allgathers Q in BF16 first, then optionally quantizes to FP8 post-gather if supports_quant_query_input=True. This avoids the type mismatch where _DecodeConcatQuantFP8 produces a single FP8 tensor incompatible with the DCP tuple->cat->allgather flow. DCP = 1 path is unchanged.

  2. FP8-aware prefill gather: cp_gather_cache has a strict TORCH_CHECK(src.dtype == dst.dtype) that crashes when FP8 cache meets BF16 workspace. For FP8 KV cache (excluding fp8_ds_mla), the code now calls gather_and_maybe_dequant_cache which fuses gather + FP8->BF16 dequantization. Non-FP8 path continues to use cp_gather_cache.

  3. Metadata additions: Added padded_local_token_to_seq and padded_local_chunk_total_token fields to ChunkedContextMetadata, computed in build(), required by gather_and_maybe_dequant_cache.

  4. k_scale passthrough: forward_mha now passes the real k_scale to the DCP prefill path instead of a hardcoded None.

  5. Guard for use_fp8_prefill + DCP > 1: Added a clear assert with actionable error message. This combination would require FP8 workspace allocation (only for sm10x + FlashInfer/TRT-LLM + use_prefill_query_quantization), and not supported.

test_context_parallel.py:

  1. FP8 DCP test parameterization: Added kv_cache_dtype support to CPTestOptions/CPTestSettings.detailed() and added a new test entry for DeepSeek-V2-Lite-Chat with dcp=4, kv_cache_dtype=fp8.

Test Plan

New test:

  • test_context_parallel.py::test_cp_generation with kv_cache_dtype="fp8", dcp_size=4 — GSM8K 256-question 5-shot accuracy eval with DeepSeek-V2-Lite-Chat
  • End-to-end test with lm_eval (GSM8K) with Kimi K2.5 on sm120 using new settings

Regression tests:

  • test_context_parallel.py::test_cp_generation with kv_cache_dtype=auto, dcp_size=4 — existing DCP test
  • test_mla_backends.py::test_backend_correctness — MLA backend unit tests (DCP=1 forward paths)
  • End-to-end test with lm_eval (GSM8K) with Kimi K2.5 on sm120 using existing previously supported settings

Test Results

Environment: 16x GPUs, sm120, TritonMLA backend (the only MLA backend that works on sm120)

Test Config Result
test_cp_generation (regression) tp=4, dcp=4, kv_cache_dtype=auto PASS (GSM8K accuracy ≥ 0.64)
test_cp_generation (new) tp=4, dcp=4, kv_cache_dtype=fp8 PASS (GSM8K accuracy ≥ 0.64)
test_backend_correctness 48 parameterizations, DCP=1 16 pass, 32 fail (pre-existing) — all failures are pre-existing sm120 backend issues: Cutlass MLA (RuntimeError: Error Internal) and FlashInfer MLA (XQA MLA only supports fp8 on SM120). Not related to this PR.

Full lm_eval GSM8K results (Kimi-K2.5, tp=16, TritonMLA, 5-shot):

Config exact_match (flexible) exact_match (strict)
dcp=16, kv_cache_dtype=auto (baseline) 0.9363 ± 0.0067 0.9363 ± 0.0067
dcp=1, kv_cache_dtype=fp8 (baseline) 0.9378 ± 0.0067 0.9371 ± 0.0067
dcp=16, kv_cache_dtype=fp8 (this PR) 0.9371 ± 0.0067 0.9371 ± 0.0067

FP8 + DCP=16 accuracy matches both baselines within error margins.

Known Limitations

  • fp8_ds_mla + DCP > 1: Not supported (different storage format with embedded block scales). Falls through to cp_gather_cache which will dtype-check at runtime.
  • use_fp8_prefill + DCP > 1: Explicitly guarded with assert. Would require FP8 workspace allocation; currently only possible with sm10x + FlashInfer/TRT-LLM + user selecting use_prefill_query_quantization.
  • supports_quant_query_input=True backends: Post-gather FP8 quant path added but not tested on this (sm120) hardware (requires sm90a for FlashMLA/CutlassMLA). The path uses an additional ops.scaled_fp8_quant which is a well-tested primitive, with all other commands being the same as the Triton MLA sm120 path. Risk is minimal.
  • gather_and_maybe_dequant_cache: Baseline vllm has hardcoded head_dim == 576 constraint, currently limiting FP8 DCP to DeepSeek V2/V3/R1 family models. This constraint is duplicated in this PR.

Note

To run this PR on sm120, it also requires that Triton MLA support kv-cache-dtype fp8 from #34597 since that's the only backend that supports it. At this time, that PR is not yet merged, but the two features are independent and can be merged separately.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Copilot AI review requested due to automatic review settings February 18, 2026 10:49

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request enables the use of FP8 KV cache with Decode Context Parallel (DCP) for MLA, which was previously unsupported. The changes are well-structured, including restructuring the decode path to quantize after all-gathering, updating the prefill path to use a new gather_and_maybe_dequant_cache operation, and adding necessary metadata and test coverage. The implementation looks solid. I have one suggestion to add an explicit guard for a known unsupported configuration to improve user experience and error reporting.

Comment thread vllm/model_executor/layers/attention/mla_attention.py

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR enables the combination of FP8 KV cache (kv_cache_dtype=fp8) with Decode Context Parallel (DCP) > 1 for MLA attention, which was previously blocked by hard assertions. The changes maintain numerical correctness through storage-only FP8 (no new FP8 compute paths) and are motivated by allowing users to benefit from both FP8 memory savings and DCP's latency/memory improvements simultaneously.

Changes:

  • Restructured the decode Q path to allgather in BF16 first, then optionally quantize to FP8 post-gather for backends with supports_quant_query_input=True
  • Replaced cp_gather_cache with gather_and_maybe_dequant_cache for FP8 KV cache in the prefill DCP gather path to handle dtype mismatches
  • Added padded_local_token_to_seq and padded_local_chunk_total_token metadata fields to ChunkedContextMetadata for FP8 DCP support
  • Passed k_scale through to the DCP prefill path instead of hardcoding None
  • Added a guard for the unsupported use_fp8_prefill + DCP > 1 combination
  • Extended test parameterization to include FP8 DCP testing

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description
vllm/model_executor/layers/attention/mla_attention.py Restructured decode Q path for DCP+FP8 compatibility, added FP8-aware prefill gather logic, added metadata fields for FP8 DCP support, passed k_scale to DCP path, and guarded against unsupported FP8 prefill with DCP
tests/distributed/test_context_parallel.py Added kv_cache_dtype parameter support and FP8 DCP test configuration for DeepSeek-V2-Lite-Chat

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@grimulkan

grimulkan commented Feb 18, 2026

Copy link
Copy Markdown
Contributor Author

This approach does differ slightly from @LucasWilkinson in that it uses bf16 all-gathers for Q, which is not the best bandwidth efficiency. I can see the problem is because my Triton MLA fp8 PR has supports_quant_query_input = False (and internally uses bf16 attention). But the other backends expect quantized Q. There may be future backends that have supports_quant_query_input = False. My current approach maximizes compatibility, but is less efficient.

A better approach would be to switch on supports_quant_query_input:

if self.impl.dcp_world_size > 1:
    if fp8_attention and self.impl.supports_quant_query_input:
        # FP8 quant first, all_gather in FP8 (half bandwidth)
        mqa_q = self._decode_concat_quant_fp8_op(
            mqa_ql_nope, mqa_q_pe, self._q_scale)
        mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)
    else:
        # BF16 all_gather (TritonMLA and non-FP8 cases)
        mqa_q = torch.cat((mqa_ql_nope, mqa_q_pe), dim=-1)
        mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)

But I would need someone with sm90 or sm100 to help test the supports_quant_query_input = True path in that case.

EDIT: I found a simple way to include the merge without impacting compatibility. I still cannot test the sm90/100 path, but now the only difference between the 2 paths is an additional fp8 quantization, so we should have the best of both worlds with lower risk on the untested path.

@grimulkan

Copy link
Copy Markdown
Contributor Author

Some speed/VRAM benchmarks on sm120.

Kimi K2.5 on RTX 6000 Pro** (native int4 experts, Marlin gemm, Triton MLA)

Cards TP DCP PP KV Cache Total KV Cache Space Generation Speed (@ 0 context)
8 8 8 1 fp8 3M tok 68 tok/s
8 8 1 1 fp8 380K tok 79 tok/s
8 8 8 1 bf16 1.5M tok 67 tok/s
8 8 1 1 bf16 190K tok 78 tok/s
16 16 16 1 fp8 20M tok 43 tok/s
16 16 1 1 fp8 1.25M tok 64 tok/s
16 16 16 1 bf16 10M tok 42 tok/s
16 16 1 1 bf16 638K tok 60 tok/s

The fp8 versions also require #34597 on sm120
Likely some of this would need to be rebased after #33529 is merged (the above results don't have those improvements).

@voipmonitor

Copy link
Copy Markdown
Contributor

confirming that this is working on 8x RTX PRO AMD Turin:

NCCL_P2P_LEVEL=SYS VLLM_LOG_STATS_INTERVAL=1 NCCL_GRAPH_FILE=/mnt/nccl_graph_opt.xml VLLM_TEST_FORCE_FP8_MARLIN=1 VLLM_MARLIN_USE_ATOMIC_ADD=1 VLARLIN_INPUT_DTYPE=fp8 vllm serve moonshotai/Kimi-K2.5 --served-model-name Kimi-K2.5 --trust-remote-code --host 0.0.0.0 --port 5000 --tensor-parallel-size 8 --pipeline-parallel-size 1 --enable-chunked-prefill --enable-prefix-caching --load-format fastsafetensors --tool-call-parser kimi_k2 --enable-auto-tool-choice --reasoning-parser kimi_k2 --async-scheduling --gpu-memory-utilization 0.93 --max-num-batched-tokens 4096 --mm-processor-cache-gb 0 --mm-encoder-tp-mode weights --language-model-only --attention-backend TRITON_MLA --kv-cache-dtype fp8

GPU KV cache size: 449,600 tokens
speed: 79tok/sec

when --decode-context-parallel-size 8 is used (more KV cache):
GPU KV cache size: 3,621,504 tokens

speed: 66tok/sec

Previously, MLA attention blocked the combination of FP8 KV cache
(kv_cache_dtype=fp8) with DCP > 1 via hard asserts. This patch enables
the combination by:

- Restructuring the decode Q path to allgather in BF16, then optionally
  quantize to FP8 post-gather for backends with supports_quant_query_input
- Replacing cp_gather_cache (dtype-strict) with gather_and_maybe_dequant_cache
  for FP8 KV cache in the prefill DCP gather path
- Passing k_scale through to the DCP prefill path (was hardcoded None)
- Adding a clear guard for the unsupported use_fp8_prefill + DCP > 1 case
- Adding FP8 DCP test parameterization to test_context_parallel.py

Signed-off-by: grimulkan <grimulkan@gmail.com>
@grimulkan

Copy link
Copy Markdown
Contributor Author

Rebased, no change in performance or functionality.

alexeldeib added a commit to alexeldeib/vllm that referenced this pull request May 15, 2026
The quality regression shows up when MLA prefill query quantization is
combined with an FP8 KV cache.  That combination matters because
use_prefill_query_quantization is inert unless the KV cache dtype is quantized
and the selected MLA prefill backend claims support.  In the K2.5-style config
without kv-cache-dtype=fp8_e4m3 the flag does not change prefill query dtype;
in the K2.6 config it does.

The old MLA prefill FP8 path was not scale-correct once it became active:

* forward_mha and chunked prefill converted q, k, and v with plain .to(fp8).
  That conversion assumes an implicit scale of 1 and ignores the calibrated
  q/k/v scale buffers loaded for FP8 attention.
* The TRT-LLM ragged DeepSeek prefill wrapper always passed
  bmm1_scale=self.scale and bmm2_scale=1.0.  The TRT-LLM/FlashInfer FP8
  attention convention is bmm1_scale=q_scale*k_scale*sm_scale and
  bmm2_scale=v_scale (modulo output scale when present).  Existing generic
  TRT-LLM attention benchmarks and tests use that convention.
* For chunked context, the FP8 path used cp_gather_cache.  That only copies
  cache bytes; the C++ comment on that path explicitly says scaled KV cache is
  not supported there.  Scaled FP8 cache must be gathered through
  gather_and_maybe_dequant_cache so the cached latent KV is descaled before the
  kv_b projection.

Fix this by making MLA FP8 prefill explicit about scale support:

* Restrict backend_supports_prefill_query_quantization to TRTLLM_RAGGED.
  FlashInfer, FlashAttention, and TokenSpeed MLA prefill do not currently
  expose the BMM scale hooks needed for scaled FP8 inputs, so the flag should
  not activate FP8 prefill for those backends.
* Add MLAAttention._scaled_fp8_prefill_input, which uses the existing static
  QuantFP8 op and layer scale buffers instead of plain casts.
* Pass the MLAAttention layer into the MHA-style impl path so the common MLA
  impl can reuse the layer-owned quant op and q/k/v scale tensors.
* Quantize prefill q with _q_scale, expanded k with _k_scale, and projected v
  with _v_scale.  This differs from MLA decode, where BMM2 uses _k_scale
  because decode attends over the latent KV cache rather than the projected
  value tensor.
* Plumb optional q/k/v descale factors through the MLA prefill backend API.
  Non-supporting backends assert that no scaled FP8 inputs were provided.
  TRTLLM_RAGGED maps the descales to bmm1_scale=self.scale*q_scale*k_scale and
  bmm2_scale=v_scale.
* Allocate the non-DCP chunked prefill workspace in model dtype and always use
  gather_and_maybe_dequant_cache for chunked context, including FP8 KV cache.

Correctness evidence:

* The activation condition is covered by unit tests: FP8 prefill query dtype is
  selected only when the cache dtype is FP8, use_prefill_query_quantization is
  set, the device family is SM100, and the MLA prefill backend is TRTLLM_RAGGED.
* The backend gate test proves FlashInfer, FlashAttention, and TokenSpeed no
  longer accidentally activate the scaled-FP8 path.
* The non-TRT prefill backend tests prove those backends fail closed if scaled
  FP8 q/k/v inputs are accidentally passed through the new API.
* The scaled prefill input helper test proves the quant path flattens only the
  leading dimensions, preserves the original shape, makes the quant input
  contiguous, and passes the supplied scale tensor to the static QuantFP8 op.
* The TRTLLM_RAGGED scale helper test checks both the unscaled legacy case and
  the scaled case, including that partial scale plumbing is rejected.
* The scale formulas match existing TRT-LLM FP8 attention tests and benchmarks:
  bmm1 receives q_scale*k_scale*softmax_scale, and bmm2 receives value scale.
* The chunked-context cache gather now uses the only gather path that accepts a
  KV scale and can dequantize scaled FP8 cache before projection.

Validation run locally:

* .venv/bin/python -m pytest tests/v1/attention/test_mla_prefill_selector.py -q
  -> 26 passed, 16 warnings
* pre-commit run ruff-format --files <changed files> -> passed
* pre-commit run ruff-check --files <changed files> -> passed
* .venv/bin/python -m py_compile <changed files> -> passed
* pre-commit run mypy-local --files <changed files> -> passed

Duplicate-work check:

* vllm-project#39841 only changes FP8 cast ordering in chunked prefill; it does not address
  calibrated q/k/v input quantization, TRT-LLM ragged BMM descales, or backend
  gating.
* vllm-project#40304 and vllm-project#40908 focus on static FP8 prefill output / merge-state fusion,
  while this change fixes scaled FP8 input correctness.
* vllm-project#42509 is ROCm/AITER dense MLA prefill work on gfx950, not the Blackwell
  TRT-LLM ragged MLA prefill path touched here.
* vllm-project#40609 / vllm-project#34795 are DCP FP8-KV efforts.  This change intentionally handles
  non-DCP chunked context and keeps DCP as separate work.

Co-authored-by: OpenAI Codex <codex@openai.com>
alexeldeib added a commit to alexeldeib/vllm that referenced this pull request May 15, 2026
The quality regression shows up for public MLA model configurations when
prefill query quantization is combined with an FP8 KV cache.  That combination
matters because use_prefill_query_quantization is inert unless the KV cache dtype
is quantized and the selected MLA prefill backend claims support.  A serve
configuration that omits --kv-cache-dtype=fp8_e4m3 therefore does not exercise
this path, while the same model with FP8 KV cache and TRTLLM_RAGGED prefill does.

The old MLA prefill FP8 path was not scale-correct once it became active:

* forward_mha and chunked prefill converted q, k, and v with plain .to(fp8).
  That conversion assumes an implicit scale of 1 and ignores the calibrated
  q/k/v scale buffers loaded for FP8 attention.
* The TRT-LLM ragged DeepSeek prefill wrapper always passed
  bmm1_scale=self.scale and bmm2_scale=1.0.  The TRT-LLM/FlashInfer FP8
  attention convention is bmm1_scale=q_scale*k_scale*sm_scale and
  bmm2_scale=v_scale (modulo output scale when present).  Existing generic
  TRT-LLM attention benchmarks and tests use that convention.
* For chunked context, the FP8 path used cp_gather_cache.  That only copies
  cache bytes; the C++ comment on that path explicitly says scaled KV cache is
  not supported there.  Scaled FP8 cache must be gathered through
  gather_and_maybe_dequant_cache so the cached latent KV is descaled before the
  kv_b projection.

Fix this by making MLA FP8 prefill explicit about scale support:

* Restrict backend_supports_prefill_query_quantization to TRTLLM_RAGGED.
  FlashInfer, FlashAttention, and TokenSpeed MLA prefill do not currently expose
  the BMM scale hooks needed for scaled FP8 inputs, so the flag should not
  activate FP8 prefill for those backends.
* Add MLAAttention._scaled_fp8_prefill_input, which uses the existing static
  QuantFP8 op and layer scale buffers instead of plain casts.
* Pass the MLAAttention layer into the MHA-style impl path so the common MLA impl
  can reuse the layer-owned quant op and q/k/v scale tensors.
* Quantize prefill q with _q_scale, expanded k with _k_scale, and projected v
  with _v_scale.  This differs from MLA decode, where BMM2 uses _k_scale because
  decode attends over the latent KV cache rather than the projected value tensor.
* Plumb optional q/k/v descale factors through the MLA prefill backend API.
  Non-supporting backends assert that no scaled FP8 inputs were provided.
  TRTLLM_RAGGED maps the descales to bmm1_scale=self.scale*q_scale*k_scale and
  bmm2_scale=v_scale.
* Allocate the non-DCP chunked prefill workspace in model dtype and always use
  gather_and_maybe_dequant_cache for chunked context, including FP8 KV cache.

Open-source reproduction sketch:

Use a public Moonshot MLA checkpoint that exercises the Kimi-K2.x MLA path, for
example moonshotai/Kimi-K2.5 or the corresponding Moonshot Kimi-K2.6 Hugging
Face checkpoint when available.  Choose TP/PP sizes that fit the local Blackwell
host; the important repro knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and
use_prefill_query_quantization.

Baseline server, FP8 KV cache without prefill query quantization:

    MODEL_ID=moonshotai/Kimi-K2.5
    vllm serve "$MODEL_ID" \
      --host 0.0.0.0 \
      --port 8000 \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
      --served-model-name kimi-mla-fp8-baseline

Scaled-FP8 prefill server under test:

    MODEL_ID=moonshotai/Kimi-K2.5
    vllm serve "$MODEL_ID" \
      --host 0.0.0.0 \
      --port 8000 \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
      --served-model-name kimi-mla-fp8-pqq

A quick deterministic quality probe is to ask small math questions with known
answers and temperature 0:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 70: 17_b=b+7 and 97_b=9b+7, so b+7 divides 56;
the divisors greater than 16 are 28 and 56, giving bases 21 and 49.

A second probe is:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 49: with m=n+2, divisibility reduces to m | 39,
so n is 1, 11, or 37.

Before this fix, the FP8-KV + prefill-query-quantized server can produce empty,
garbled, timed-out, or mathematically incorrect responses while the baseline
server remains coherent.  After this fix, the scaled-FP8 prefill server should
match the baseline quality on these probes.  A control run without
--kv-cache-dtype=fp8_e4m3 should also remain unaffected by the flag, because the
FP8 prefill query path is intentionally inactive unless the cache is quantized.

Correctness evidence:

* The activation condition is covered by unit tests: FP8 prefill query dtype is
  selected only when the cache dtype is FP8, use_prefill_query_quantization is
  set, the device family is SM100, and the MLA prefill backend is TRTLLM_RAGGED.
* The backend gate test proves FlashInfer, FlashAttention, and TokenSpeed no
  longer accidentally activate the scaled-FP8 path.
* The non-TRT prefill backend tests prove those backends fail closed if scaled
  FP8 q/k/v inputs are accidentally passed through the new API.
* The scaled prefill input helper test proves the quant path flattens only the
  leading dimensions, preserves the original shape, makes the quant input
  contiguous, and passes the supplied scale tensor to the static QuantFP8 op.
* The TRTLLM_RAGGED scale helper test checks both the unscaled legacy case and
  the scaled case, including that partial scale plumbing is rejected.
* The scale formulas match existing TRT-LLM FP8 attention tests and benchmarks:
  bmm1 receives q_scale*k_scale*softmax_scale, and bmm2 receives value scale.
* The chunked-context cache gather now uses the only gather path that accepts a
  KV scale and can dequantize scaled FP8 cache before projection.

Validation run locally:

* .venv/bin/python -m pytest tests/v1/attention/test_mla_prefill_selector.py -q
  -> 26 passed, 16 warnings
* pre-commit run ruff-format --files <changed files> -> passed
* pre-commit run ruff-check --files <changed files> -> passed
* .venv/bin/python -m py_compile <changed files> -> passed
* pre-commit run mypy-local --files <changed files> -> passed

Duplicate-work check refreshed on 2026-05-15:

* vllm-project#39841 only changes FP8 cast ordering in chunked prefill; it does not address
  calibrated q/k/v input quantization, TRT-LLM ragged BMM descales, or backend
  gating.
* vllm-project#40304 and vllm-project#40908 focus on static FP8 prefill output / merge-state fusion,
  while this change fixes scaled FP8 input correctness.
* vllm-project#42509 is ROCm/AITER dense MLA prefill work on gfx950, not the Blackwell
  TRT-LLM ragged MLA prefill path touched here.
* vllm-project#40609 / vllm-project#34795 are DCP FP8-KV efforts.  This change intentionally handles
  non-DCP chunked context and keeps DCP as separate work.
* vllm-project#41568 lifts decode Q-prep work out of forward_impl; it is decode performance
  work and does not repair prefill q/k/v input descales.

Co-authored-by: OpenAI Codex <codex@openai.com>
alexeldeib added a commit to alexeldeib/vllm that referenced this pull request May 15, 2026
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend supports scaled FP8 inputs. One observed public configuration is Moonshot MLA checkpoints such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.

The old FP8 prefill path was not scale-correct once active: prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received incomplete BMM scales, and non-DCP chunked context gathered FP8 cache bytes through a path that cannot dequantize scaled KV cache before the `kv_b` projection. That can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.

This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API, mapping TRTLLM_RAGGED to `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.

Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.

Related upstream work reviewed:
- vllm-project#39841 adjusts FP8 cast ordering in chunked prefill, but does not address calibrated q/k/v quantization, TRTLLM_RAGGED BMM scales, or backend gating.
- vllm-project#40304 and vllm-project#40908 focus on static FP8 prefill output and merge-state fusion, not scaled FP8 input correctness.
- vllm-project#42509 is ROCm/AITER dense MLA prefill work on gfx950, not Blackwell TRTLLM_RAGGED MLA prefill.
- vllm-project#40609 and vllm-project#34795 cover DCP FP8-KV work; this change covers the non-DCP chunked-context path.
- vllm-project#41568 is decode Q-prep refactoring and does not repair prefill input descales.
alexeldeib added a commit to alexeldeib/vllm that referenced this pull request May 15, 2026
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend supports scaled FP8 inputs. One observed public configuration is Moonshot MLA checkpoints such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.

The old FP8 prefill path was not scale-correct once active: prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received incomplete BMM scales, and non-DCP chunked context gathered FP8 cache bytes through a path that cannot dequantize scaled KV cache before the `kv_b` projection. That can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.

This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API, mapping TRTLLM_RAGGED to `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.

Reproduction sketch:

Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.

Baseline server:

    MODEL_ID=moonshotai/Kimi-K2.6
    vllm serve "$MODEL_ID" \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
      --served-model-name kimi-mla-fp8-baseline

Scaled-FP8 prefill server under test:

    MODEL_ID=moonshotai/Kimi-K2.6
    vllm serve "$MODEL_ID" \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
      --served-model-name kimi-mla-fp8-pqq

A deterministic quality probe is to ask a small math question with temperature 0:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.

A second probe is:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.

Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.

Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.

Related upstream work reviewed:
- vllm-project#39841 adjusts FP8 cast ordering in chunked prefill, but does not address calibrated q/k/v quantization, TRTLLM_RAGGED BMM scales, or backend gating.
- vllm-project#40304 and vllm-project#40908 focus on static FP8 prefill output and merge-state fusion, not scaled FP8 input correctness.
- vllm-project#42509 is ROCm/AITER dense MLA prefill work on gfx950, not Blackwell TRTLLM_RAGGED MLA prefill.
- vllm-project#40609 and vllm-project#34795 cover DCP FP8-KV work; this change covers the non-DCP chunked-context path.
- vllm-project#41568 is decode Q-prep refactoring and does not repair prefill input descales.
alexeldeib added a commit to alexeldeib/vllm that referenced this pull request May 15, 2026
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend can consume scaled FP8 inputs. One public configuration that exercises this path is a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.

The old path treated FP8 prefill as a dtype conversion rather than a scaled quantization contract. Prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received legacy BMM scales, and non-DCP chunked context copied FP8 cache bytes through a gather path that cannot dequantize scaled KV cache before the `kv_b` projection. Once active, that can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.

For scaled FP8 prefill inputs, the backend must interpret the encoded tensors with their calibrated descales. In the TRTLLM_RAGGED path that means `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`; passing only `softmax_scale` and `1.0` is correct only for the unscaled legacy path.

This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API and map them to the TRTLLM_RAGGED BMM scales;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.

Reproduction sketch:

Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.

Baseline server:

    MODEL_ID=moonshotai/Kimi-K2.6
    vllm serve "$MODEL_ID" \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
      --served-model-name kimi-mla-fp8-baseline

Scaled-FP8 prefill server under test:

    MODEL_ID=moonshotai/Kimi-K2.6
    vllm serve "$MODEL_ID" \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
      --served-model-name kimi-mla-fp8-pqq

A deterministic quality probe is to ask a small math question with temperature 0:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.

A second probe is:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.

Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.

Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.

Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the FP8 cast, but it still leaves raw FP8 casts, missing q/k/v descales, missing TRTLLM_RAGGED BMM scale mapping, and no backend gate.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion. Those operate after the attention inputs; they do not repair incorrectly scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware and backend path from NVIDIA Blackwell TRTLLM_RAGGED.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize behavior in the DCP path. That is useful DCP work, but it does not make non-DCP TRTLLM_RAGGED prefill query quantization scale-correct, and DCP prefill-query FP8 support is a separate surface.
- vllm-project#41568 refactors decode Q-prep. This bug is in prefill input quantization and prefill backend scale plumbing, not decode Q preparation.
alexeldeib added a commit to alexeldeib/vllm that referenced this pull request May 15, 2026
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend can consume scaled FP8 inputs. One public configuration that exercises this path is a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.

The old path treated FP8 prefill as a dtype conversion rather than a scaled quantization contract. Prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received legacy BMM scales, and non-DCP chunked context copied FP8 cache bytes through a gather path that cannot dequantize scaled KV cache before the `kv_b` projection. Once active, that can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.

For scaled FP8 prefill inputs, the backend must interpret the encoded tensors with their calibrated descales. In the TRTLLM_RAGGED path that means `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`; passing only `softmax_scale` and `1.0` is correct only for the unscaled legacy path.

This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API and map them to the TRTLLM_RAGGED BMM scales;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.

Reproduction sketch:

Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.

Baseline server:

    MODEL_ID=moonshotai/Kimi-K2.6
    vllm serve "$MODEL_ID" \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
      --served-model-name kimi-mla-fp8-baseline

Scaled-FP8 prefill server under test:

    MODEL_ID=moonshotai/Kimi-K2.6
    vllm serve "$MODEL_ID" \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
      --served-model-name kimi-mla-fp8-pqq

A deterministic quality probe is to ask a small math question with temperature 0:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.

A second probe is:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.

Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.

Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.

Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the FP8 cast, but it still leaves raw FP8 casts, missing q/k/v descales, missing TRTLLM_RAGGED BMM scale mapping, and no backend gate.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion. Those operate after the attention inputs; they do not repair incorrectly scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware and backend path from NVIDIA Blackwell TRTLLM_RAGGED.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize behavior in the DCP path. That is useful DCP work, but it does not make non-DCP TRTLLM_RAGGED prefill query quantization scale-correct, and DCP prefill-query FP8 support is a separate surface.
- vllm-project#41568 refactors decode Q-prep. This bug is in prefill input quantization and prefill backend scale plumbing, not decode Q preparation.
alexeldeib added a commit to alexeldeib/vllm that referenced this pull request May 15, 2026
Prefill query quantization changes the MLA prefill contract when the KV cache
is FP8. The q/k/v tensors passed to the prefill backend are no longer ordinary
BF16 inputs: they are static per-tensor FP8 values and the backend must apply
their calibrated descales in the attention math.

The previous path converted prefill q/k/v to FP8 but did not consistently pass
those descales into the selected prefill backend. That makes the attention
scores and values numerically wrong:

    expected scores = (q_fp8 * q_scale) @ (k_fp8 * k_scale).T * softmax_scale
    expected output = softmax(expected scores) @ (v_fp8 * v_scale)

Without the descales, the kernel effectively computes logits and P@V in the
encoded FP8 domain. In practice this can show up as empty, garbled, timed-out,
or mathematically incorrect responses for Moonshot MLA checkpoints served with
FP8 KV cache and `use_prefill_query_quantization=true`, while the same model
without prefill query quantization remains coherent.

Make scaled FP8 MLA prefill explicit end to end:

- quantize prefill q, expanded k, and projected v with the layer-owned static
  QuantFP8 op and calibrated q/k/v scale tensors instead of raw dtype casts;
- pass q/k/v descales through the MLA prefill backend API from both new-token
  prefill and chunked-context prefill;
- map descales to TRT-LLM ragged DeepSeek prefill as
  `bmm1_scale = softmax_scale * q_scale * k_scale` and
  `bmm2_scale = v_scale`;
- forward descales to FlashInfer ragged prefill. vLLM pins
  `flashinfer-python==0.6.8.post1`, whose
  `BatchPrefillWithRaggedKVCacheWrapper.run()` accepts q/k/v scale arguments,
  and vLLM constructs that wrapper with `backend="cutlass"`, whose run path
  passes those scales to `fmha_varlen`;
- support TokenSpeed MLA prefill without changing its external API. TokenSpeed
  exposes `softmax_scale` and returns BF16 output; for static per-tensor FP8,
  folding `q_scale * k_scale` into `softmax_scale` and applying `v_scale` to
  the BF16 output is algebraically equivalent to dequantizing Q/K before QK
  and V before PV. LSE is left unchanged by the V/output scale;
- keep FlashAttention MLA prefill fail-closed for scaled FP8 inputs. Although
  the FA3 varlen API has descale slots, this GB200 path uses the FA4 wrapper,
  and the current FA4 call does not pass q/k/v descales;
- dequantize scaled FP8 KV cache in non-DCP chunked context before the `kv_b`
  projection with `gather_and_maybe_dequant_cache`.

This preserves the user-facing MLA prefill backend support matrix where the
backend has a concrete scale-correct path: TRTLLM_RAGGED, FLASHINFER, and
TOKENSPEED_MLA remain supported for FP8 prefill query quantization on GB200;
FLASH_ATTN remains unsupported instead of silently producing unscaled FP8 math.

Reproduction sketch:

Use a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or
`moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV
cache, MLA prefill query quantization, and an MLA prefill backend that accepts
or can equivalently apply q/k/v descales.

    MODEL_ID=moonshotai/Kimi-K2.6
    vllm serve "$MODEL_ID" \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
      --served-model-name kimi-mla-fp8-pqq

A deterministic quality probe is:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so
`b + 7` divides 56 and the valid bases are 21 and 49.

A second probe is:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 49: with `m = n + 2`, divisibility reduces to
`m | 39`, so `n` is 1, 11, or 37.

Tests cover the activation gate, q/k/v scale plumbing, TRT-LLM BMM scale
mapping, FlashInfer scale forwarding, TokenSpeed scale equivalence, scaled FP8
input quantization, FlashAttention fail-closed behavior, and chunked-context
cache gather/dequantization.

Local validation:

- `.venv/bin/python -m pytest tests/v1/attention/test_mla_prefill_selector.py -q`
- `uvx ruff check` on the touched FP8 MLA files and tests
- `.venv/bin/python -m py_compile` on the touched FP8 MLA files and tests
- `git diff --check`

Related upstream work reviewed:

- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the
  FP8 cast, but it still leaves missing q/k/v descales and backend scale
  plumbing.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion.
  Those operate after the attention inputs; they do not repair incorrectly
  scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware
  and backend path from NVIDIA Blackwell MLA prefill.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize
  behavior in the DCP path. That does not make non-DCP scaled FP8 prefill
  query quantization scale-correct.
- vllm-project#41568 refactors decode Q preparation. This bug is in prefill input
  quantization and prefill backend scale plumbing, not decode Q preparation.
alexeldeib added a commit to alexeldeib/vllm that referenced this pull request May 15, 2026
Prefill query quantization changes the MLA prefill contract when the KV cache
is FP8. The q/k/v tensors passed to the prefill backend are no longer ordinary
BF16 inputs: they are static per-tensor FP8 values and the backend must apply
their calibrated descales in the attention math.

The previous path converted prefill q/k/v to FP8 but did not consistently pass
those descales into the selected prefill backend. That makes the attention
scores and values numerically wrong:

    expected scores = (q_fp8 * q_scale) @ (k_fp8 * k_scale).T * softmax_scale
    expected output = softmax(expected scores) @ (v_fp8 * v_scale)

Without the descales, the kernel effectively computes logits and P@V in the
encoded FP8 domain. In practice this can show up as empty, garbled, timed-out,
or mathematically incorrect responses for Moonshot MLA checkpoints served with
FP8 KV cache and `use_prefill_query_quantization=true`, while the same model
without prefill query quantization remains coherent.

Make scaled FP8 MLA prefill explicit end to end:

- quantize prefill q, expanded k, and projected v with the layer-owned static
  QuantFP8 op and calibrated q/k/v scale tensors instead of raw dtype casts;
- pass q/k/v descales through the MLA prefill backend API from both new-token
  prefill and chunked-context prefill;
- map descales to TRT-LLM ragged DeepSeek prefill as
  `bmm1_scale = softmax_scale * q_scale * k_scale` and
  `bmm2_scale = v_scale`;
- forward descales to FlashInfer ragged prefill. vLLM pins
  `flashinfer-python==0.6.8.post1`, whose
  `BatchPrefillWithRaggedKVCacheWrapper.run()` accepts q/k/v scale arguments,
  and vLLM constructs that wrapper with `backend="cutlass"`, whose run path
  passes those scales to `fmha_varlen`;
- support TokenSpeed MLA prefill without changing its external API. TokenSpeed
  exposes `softmax_scale` and returns BF16 output; for static per-tensor FP8,
  folding `q_scale * k_scale` into `softmax_scale` and applying `v_scale` to
  the BF16 output is algebraically equivalent to dequantizing Q/K before QK
  and V before PV. LSE is left unchanged by the V/output scale;
- keep FlashAttention MLA prefill fail-closed for scaled FP8 inputs. Although
  the FA3 varlen API has descale slots, this GB200 path uses the FA4 wrapper,
  and the current FA4 call does not pass q/k/v descales;
- dequantize scaled FP8 KV cache in non-DCP chunked context before the `kv_b`
  projection with `gather_and_maybe_dequant_cache`.

This preserves the user-facing MLA prefill backend support matrix where the
backend has a concrete scale-correct path: TRTLLM_RAGGED, FLASHINFER, and
TOKENSPEED_MLA remain supported for FP8 prefill query quantization on GB200;
FLASH_ATTN remains unsupported instead of silently producing unscaled FP8 math.

Reproduction sketch:

Use a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or
`moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV
cache, MLA prefill query quantization, and an MLA prefill backend that accepts
or can equivalently apply q/k/v descales.

    MODEL_ID=moonshotai/Kimi-K2.6
    vllm serve "$MODEL_ID" \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
      --served-model-name kimi-mla-fp8-pqq

A deterministic quality probe is:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so
`b + 7` divides 56 and the valid bases are 21 and 49.

A second probe is:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 49: with `m = n + 2`, divisibility reduces to
`m | 39`, so `n` is 1, 11, or 37.

Tests cover the activation gate, q/k/v scale plumbing, TRT-LLM BMM scale
mapping, FlashInfer scale forwarding, TokenSpeed scale equivalence, scaled FP8
input quantization, FlashAttention fail-closed behavior, and chunked-context
cache gather/dequantization.

Local validation:

- `.venv/bin/python -m pytest tests/v1/attention/test_mla_prefill_selector.py -q`
- `uvx ruff check` on the touched FP8 MLA files and tests
- `.venv/bin/python -m py_compile` on the touched FP8 MLA files and tests
- `git diff --check`

Related upstream work reviewed:

- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the
  FP8 cast, but it still leaves missing q/k/v descales and backend scale
  plumbing.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion.
  Those operate after the attention inputs; they do not repair incorrectly
  scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware
  and backend path from NVIDIA Blackwell MLA prefill.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize
  behavior in the DCP path. That does not make non-DCP scaled FP8 prefill
  query quantization scale-correct.
- vllm-project#41568 refactors decode Q preparation. This bug is in prefill input
  quantization and prefill backend scale plumbing, not decode Q preparation.

Co-authored-by: Codex
alexeldeib added a commit to alexeldeib/vllm that referenced this pull request May 16, 2026
Prefill query quantization changes the MLA prefill contract when the KV cache
is FP8. The q/k/v tensors passed to the prefill backend are no longer ordinary
BF16 inputs: they are static per-tensor FP8 values and the backend must apply
their calibrated descales in the attention math.

The previous path converted prefill q/k/v to FP8 but did not consistently pass
those descales into the selected prefill backend. That makes the attention
scores and values numerically wrong:

    expected scores = (q_fp8 * q_scale) @ (k_fp8 * k_scale).T * softmax_scale
    expected output = softmax(expected scores) @ (v_fp8 * v_scale)

Without the descales, the kernel effectively computes logits and P@V in the
encoded FP8 domain. In practice this can show up as empty, garbled, timed-out,
or mathematically incorrect responses for Moonshot MLA checkpoints served with
FP8 KV cache and `use_prefill_query_quantization=true`, while the same model
without prefill query quantization remains coherent.

Make scaled FP8 MLA prefill explicit end to end:

- quantize prefill q, expanded k, and projected v with the layer-owned static
  QuantFP8 op and calibrated q/k/v scale tensors instead of raw dtype casts;
- pass q/k/v descales through the MLA prefill backend API from both new-token
  prefill and non-DCP chunked-context prefill;
- map descales to TRT-LLM ragged DeepSeek prefill as
  `bmm1_scale = softmax_scale * q_scale * k_scale` and
  `bmm2_scale = v_scale`;
- forward descales to FlashInfer ragged prefill. vLLM pins
  `flashinfer-python==0.6.8.post1`, whose
  `BatchPrefillWithRaggedKVCacheWrapper.run()` accepts q/k/v scale arguments,
  and vLLM constructs that wrapper with `backend="cutlass"`, whose run path
  passes those scales to `fmha_varlen`;
- support TokenSpeed MLA prefill without changing its external API. TokenSpeed
  exposes `softmax_scale` and returns BF16 output; for static per-tensor FP8,
  folding `q_scale * k_scale` into `softmax_scale` and applying `v_scale` to
  the BF16 output is algebraically equivalent to dequantizing Q/K before QK
  and V before PV. LSE is left unchanged by the V/output scale;
- keep FlashAttention MLA prefill fail-closed for scaled FP8 inputs. Although
  the FA3 varlen API has descale slots, this GB200 path uses the FA4 wrapper,
  and the current FA4 call does not pass q/k/v descales;
- dequantize scaled FP8 KV cache in non-DCP chunked context before the `kv_b`
  projection with `gather_and_maybe_dequant_cache`.

DCP chunked-context MLA prefill still uses a separate distributed gather path
that does not thread q/k/v descales into context attention. Rather than silently
claim scale-correct support there, reject DCP chunked-context when scaled FP8
prefill query quantization is active. Also keep the ROCm/AITER `forward_mha`
override signature compatible with the widened common MHA interface so the
shared `layer` argument does not break fallback paths.

This preserves the user-facing MLA prefill backend support matrix where the
backend has a concrete scale-correct path: TRTLLM_RAGGED, FLASHINFER, and
TOKENSPEED_MLA remain supported for FP8 prefill query quantization on GB200;
FLASH_ATTN and DCP chunked-context remain unsupported instead of silently
producing unscaled FP8 math.

Reproduction sketch:

Use a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or
`moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV
cache, MLA prefill query quantization, and an MLA prefill backend that accepts
or can equivalently apply q/k/v descales.

    MODEL_ID=moonshotai/Kimi-K2.6
    vllm serve "$MODEL_ID" \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
      --served-model-name kimi-mla-fp8-pqq

A deterministic quality probe is:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so
`b + 7` divides 56 and the valid bases are 21 and 49.

A second probe is:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 49: with `m = n + 2`, divisibility reduces to
`m | 39`, so `n` is 1, 11, or 37.

Tests cover the activation gate, q/k/v scale plumbing, TRT-LLM BMM scale
mapping, FlashInfer scale forwarding, TokenSpeed scale equivalence, scaled FP8
input quantization, FlashAttention fail-closed behavior, DCP fail-closed
behavior, ROCm/AITER interface compatibility, and chunked-context cache
gather/dequantization.

Related upstream work reviewed:

- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the
  FP8 cast, but it still leaves missing q/k/v descales and backend scale
  plumbing.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion.
  Those operate after the attention inputs; they do not repair incorrectly
  scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware
  and backend path from NVIDIA Blackwell MLA prefill.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize
  behavior in the DCP path. That does not make non-DCP scaled FP8 prefill
  query quantization scale-correct.
- vllm-project#41568 refactors decode Q preparation. This bug is in prefill input
  quantization and prefill backend scale plumbing, not decode Q preparation.

Co-authored-by: Codex <codex@openai.com>
alexeldeib added a commit to alexeldeib/vllm that referenced this pull request May 17, 2026
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend can consume scaled FP8 inputs. One public configuration that exercises this path is a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.

The old path treated FP8 prefill as a dtype conversion rather than a scaled quantization contract. Prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received legacy BMM scales, and non-DCP chunked context copied FP8 cache bytes through a gather path that cannot dequantize scaled KV cache before the `kv_b` projection. Once active, that can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.

For scaled FP8 prefill inputs, the backend must interpret the encoded tensors with their calibrated descales. In the TRTLLM_RAGGED path that means `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`; passing only `softmax_scale` and `1.0` is correct only for the unscaled legacy path.

This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API and map them to the TRTLLM_RAGGED BMM scales;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.

Reproduction sketch:

Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.

Baseline server:

    MODEL_ID=moonshotai/Kimi-K2.6
    vllm serve "$MODEL_ID" \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
      --served-model-name kimi-mla-fp8-baseline

Scaled-FP8 prefill server under test:

    MODEL_ID=moonshotai/Kimi-K2.6
    vllm serve "$MODEL_ID" \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
      --served-model-name kimi-mla-fp8-pqq

A deterministic quality probe is to ask a small math question with temperature 0:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.

A second probe is:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.

Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.

Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.

Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the FP8 cast, but it still leaves raw FP8 casts, missing q/k/v descales, missing TRTLLM_RAGGED BMM scale mapping, and no backend gate.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion. Those operate after the attention inputs; they do not repair incorrectly scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware and backend path from NVIDIA Blackwell TRTLLM_RAGGED.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize behavior in the DCP path. That is useful DCP work, but it does not make non-DCP TRTLLM_RAGGED prefill query quantization scale-correct, and DCP prefill-query FP8 support is a separate surface.
- vllm-project#41568 refactors decode Q-prep. This bug is in prefill input quantization and prefill backend scale plumbing, not decode Q preparation.

Assisted-by: OpenAI Codex
alexeldeib added a commit to alexeldeib/vllm that referenced this pull request May 17, 2026
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend can consume scaled FP8 inputs. One public configuration that exercises this path is a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.

The old path treated FP8 prefill as a dtype conversion rather than a scaled quantization contract. Prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received legacy BMM scales, and non-DCP chunked context copied FP8 cache bytes through a gather path that cannot dequantize scaled KV cache before the `kv_b` projection. Once active, that can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.

For scaled FP8 prefill inputs, the backend must interpret the encoded tensors with their calibrated descales. In the TRTLLM_RAGGED path that means `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`; passing only `softmax_scale` and `1.0` is correct only for the unscaled legacy path.

This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API and map them to the TRTLLM_RAGGED BMM scales;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.

Reproduction sketch:

Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.

Baseline server:

    MODEL_ID=moonshotai/Kimi-K2.6
    vllm serve "$MODEL_ID" \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
      --served-model-name kimi-mla-fp8-baseline

Scaled-FP8 prefill server under test:

    MODEL_ID=moonshotai/Kimi-K2.6
    vllm serve "$MODEL_ID" \
      --tensor-parallel-size 8 \
      --trust-remote-code \
      --max-model-len 32768 \
      --kv-cache-dtype fp8_e4m3 \
      --attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
      --served-model-name kimi-mla-fp8-pqq

A deterministic quality probe is to ask a small math question with temperature 0:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.

A second probe is:

    curl -sS http://localhost:8000/v1/chat/completions \
      -H 'Content-Type: application/json' \
      -d '{
        "model":"kimi-mla-fp8-pqq",
        "messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
        "temperature":0,
        "max_tokens":2048
      }'

The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.

Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.

This revision also fails closed for modes that need separate scale semantics before FP8 prefill query quantization can be correct: decode context parallelism, runtime `calculate_kv_scales`, and the ROCm AITER MLA override whose `forward_mha` signature must match the shared `layer` argument. Those cases now fall back to model-dtype prefill or the parent implementation instead of silently running an incomplete scaled-FP8 path.

Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.

Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the FP8 cast, but it still leaves raw FP8 casts, missing q/k/v descales, missing TRTLLM_RAGGED BMM scale mapping, and no backend gate.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion. Those operate after the attention inputs; they do not repair incorrectly scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware and backend path from NVIDIA Blackwell TRTLLM_RAGGED.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize behavior in the DCP path. That is useful DCP work, but it does not make non-DCP TRTLLM_RAGGED prefill query quantization scale-correct, and DCP prefill-query FP8 support is a separate surface.
- vllm-project#41568 refactors decode Q-prep. This bug is in prefill input quantization and prefill backend scale plumbing, not decode Q preparation.

Assisted-by: OpenAI Codex
@github-actions

github-actions Bot commented Jun 5, 2026

Copy link
Copy Markdown

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions Bot added the stale Over 90 days of inactivity label Jun 5, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

stale Over 90 days of inactivity

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants