[Bugfix] Fix FP8 cast order in _compute_prefill_context for chunked prefill#39841
[Bugfix] Fix FP8 cast order in _compute_prefill_context for chunked prefill#39841qiching wants to merge 1 commit into
_compute_prefill_context for chunked prefill#39841Conversation
…refill Move FP8 quantization of K/V tensors to after _concat_k_nope_k_pe, since flashinfer_concat_mla_k only supports BF16/FP16 inputs. The incorrect cast order caused 100% crash for any workload using chunked prefill + FP8 prefill (use_prefill_query_quantization). Also ensure k_pe (FP8 from workspace) is cast to k_nope's dtype (BF16) before concatenation to avoid dtype mismatch. This aligns _compute_prefill_context with the already-correct forward_mha path. Signed-off-by: Albert Cheng (Engrg-Hardware 1) <albecheng@lyris0001.lyris.clusters.nvidia.com>
There was a problem hiding this comment.
Code Review
This pull request modifies the _compute_prefill_context method in mla_attention.py to adjust the sequence of data type conversions when use_fp8_prefill is active. It ensures k_pe is cast to the same type as k_nope prior to concatenation for kernel compatibility, and then casts the resulting key and value tensors to the target FP8 type. I have no feedback to provide.
pavanimajety
left a comment
There was a problem hiding this comment.
Thanks for the PR!
Seems like we should fix concat_mla_k here from the Flashinfer end. I think we shouldn't concat BF16 tensors when we can get away with concatenating FP8 tensors. Less data movement. That was my original reasoning as well, sorry I missed the > 8192 case. Since the gains are from actual Prefill Quantization, and not from moving the cast itself, I recommend we tackle this from Flashinfer end.
…t prefill performance and refactor type dispatch for BF16/FP16 (#3129) ## Summary Enable FP8 (E4M3/E5M2) support in `concat_mla_k`, fixing a crash that blocks FP8 chunked prefill for all MLA models (DeepSeek-V2/V3/R1) on long-context workloads. ## Motivation For long-context inference (ISL >= 4K) in vLLM, chunked prefill + FP8 quantization (`use_prefill_query_quantization: true`) is critical for reducing TTFT and improve throughput. The FP8 FMHA kernel is ~1.35x faster than BF16 at 128K context, but this path was complete unusable because `concat_mla_k` rejected FP8 inputs. In vLLM's `_compute_prefill_context` (the chunked prefill path for MLA), K/V tensors are cast to FP8 before being passed to `flashinfer_concat_mla_k`: ```python if use_fp8_prefill: kv_nope = kv_nope.to(prefill_metadata.q_data_type) # BF16 to FP8 k_pe = k_pe.to(prefill_metadata.q_data_type) k_nope, v = kv_nope.split(...) k = self._concat_k_nope_k_pe(k_nope, k_pe) # ← crash: FP8 not supported ``` The kernel uses `DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16` which only dispatches BF16/FP16, causing RuntimeError A vLLM-side workaround (my [PR #39841](vllm-project/vllm#39841) reordering cast after concat) works, but it introduces an extra BF16 to FP8 round-trip and does not address the root cause. The proper fix is enabling FP8 at the kernel level. We keep it for temporal workaround. ## Changes | File | Change | |---|---| | `include/flashinfer/utils.cuh` | Add `ld_na_global_s16` / `st_na_global_s16` for 2-byte vectorized load and store (FP8 rope = 64 elements × 1B = 2B/thread) | | `include/flashinfer/concat_mla.cuh` | Add `ConcatMLAVecTraits<DType>` template for compile time vector type selection (BF16/FP16 to int2/int, FP8 → int or short) with `if constexpr` dispatch | | `csrc/concat_mla.cu` | Add `DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16_FP8` macro extending dispatch to FP8 E4M3/E5M2 | | `flashinfer/concat_ops.py` | Update docstring to list supported FP8 dtypes | | `tests/utils/test_concat_mla.py` | Add full pytest covering BF16, FP16, FP8-E4M3, FP8-E5M2 with bit exact correctness checks | ## Design The key insight is that `concat_mla_k` is pure memory movement, so FP8 support we adjust vectorized load and store widths: - **BF16/FP16 (2B/elem)**: nope 128 elem × 2B = 256B to `int2` (8B/thread × 32 threads), rope 64 elem × 2B = 128B to `int` (4B/thread × 32 threads) - **FP8 (1B/elem)**: nope 128 elem × 1B = 128B to `int` (4B/thread × 32 threads), rope 64 elem × 1B = 64B to `short` (2B/thread × 32 threads) `if constexpr` ensures that we do not addition runtime overhead. ## Benchmark Results End-to-end on GB300 (DeepSeek-R1-0528-FP4, DP=4, chunked prefill, ISL=128K, 16 requests): | Metric | BF16 (baseline) | FP8 (this fix) | Delta | |---|---|---|---| | Median TTFT | 42.0 s | 30.1 s | **-28.3%** | | Mean TTFT | 41.7 s | 30.5 s | **-27.0%** | | P99 TTFT | 43.8s | 33.5s | **-23.5%** | | Token throughput | 12,069 tok/s | 16,524 tok/s | **+37.0%** | ## Test Plan - [x] Unit test: `pytest tests/utils/test_concat_mla.py`, bit exact correctness for all 4 dtypes - [x] E2E crash check: ISL=128K with `use_prefill_query_quantization=true`, all succeed - [x] Performance: FP8 prefill -27% Median TTFT vs BF16 at long context - [x] No regression: BF16 baseline all succeed with identical perf to stock flashinfer <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added support for two FP8 formats alongside FP16 and BF16 in the concat operation. * **Documentation** * Updated docs to list supported dtypes and clarify compile-time dtype dispatch semantics. * **Refactor** * Generalized vector and memory access handling to uniformly support additional low-precision dtypes. * **Tests** * Added comprehensive tests for BF16/FP16/FP8 correctness, edge cases, strided inputs, and dtype-mismatch checks. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Albert Cheng (Engrg-Hardware 1) <albecheng@login-lyris02.lyris.clusters.nvidia.com>
|
should we close this? |
|
I have already enabled FP8 in concat_mla_k for optimize long-context prefill performance in FI: flashinfer-ai/flashinfer#3129. I think we can close this PR. @pavanimajety How do you think? |
The quality regression shows up when MLA prefill query quantization is combined with an FP8 KV cache. That combination matters because use_prefill_query_quantization is inert unless the KV cache dtype is quantized and the selected MLA prefill backend claims support. In the K2.5-style config without kv-cache-dtype=fp8_e4m3 the flag does not change prefill query dtype; in the K2.6 config it does. The old MLA prefill FP8 path was not scale-correct once it became active: * forward_mha and chunked prefill converted q, k, and v with plain .to(fp8). That conversion assumes an implicit scale of 1 and ignores the calibrated q/k/v scale buffers loaded for FP8 attention. * The TRT-LLM ragged DeepSeek prefill wrapper always passed bmm1_scale=self.scale and bmm2_scale=1.0. The TRT-LLM/FlashInfer FP8 attention convention is bmm1_scale=q_scale*k_scale*sm_scale and bmm2_scale=v_scale (modulo output scale when present). Existing generic TRT-LLM attention benchmarks and tests use that convention. * For chunked context, the FP8 path used cp_gather_cache. That only copies cache bytes; the C++ comment on that path explicitly says scaled KV cache is not supported there. Scaled FP8 cache must be gathered through gather_and_maybe_dequant_cache so the cached latent KV is descaled before the kv_b projection. Fix this by making MLA FP8 prefill explicit about scale support: * Restrict backend_supports_prefill_query_quantization to TRTLLM_RAGGED. FlashInfer, FlashAttention, and TokenSpeed MLA prefill do not currently expose the BMM scale hooks needed for scaled FP8 inputs, so the flag should not activate FP8 prefill for those backends. * Add MLAAttention._scaled_fp8_prefill_input, which uses the existing static QuantFP8 op and layer scale buffers instead of plain casts. * Pass the MLAAttention layer into the MHA-style impl path so the common MLA impl can reuse the layer-owned quant op and q/k/v scale tensors. * Quantize prefill q with _q_scale, expanded k with _k_scale, and projected v with _v_scale. This differs from MLA decode, where BMM2 uses _k_scale because decode attends over the latent KV cache rather than the projected value tensor. * Plumb optional q/k/v descale factors through the MLA prefill backend API. Non-supporting backends assert that no scaled FP8 inputs were provided. TRTLLM_RAGGED maps the descales to bmm1_scale=self.scale*q_scale*k_scale and bmm2_scale=v_scale. * Allocate the non-DCP chunked prefill workspace in model dtype and always use gather_and_maybe_dequant_cache for chunked context, including FP8 KV cache. Correctness evidence: * The activation condition is covered by unit tests: FP8 prefill query dtype is selected only when the cache dtype is FP8, use_prefill_query_quantization is set, the device family is SM100, and the MLA prefill backend is TRTLLM_RAGGED. * The backend gate test proves FlashInfer, FlashAttention, and TokenSpeed no longer accidentally activate the scaled-FP8 path. * The TRTLLM_RAGGED scale helper test checks both the unscaled legacy case and the scaled case, including that partial scale plumbing is rejected. * The scale formulas match existing TRT-LLM FP8 attention tests and benchmarks: bmm1 receives q_scale*k_scale*softmax_scale, and bmm2 receives value scale. * The chunked-context cache gather now uses the only gather path that accepts a KV scale and can dequantize scaled FP8 cache before projection. Validation run locally: * .venv/bin/python -m pytest tests/v1/attention/test_mla_prefill_selector.py -q -> 22 passed, 16 warnings * pre-commit run ruff-format --files <changed files> -> passed * pre-commit run ruff-check --files <changed files> -> passed * .venv/bin/python -m py_compile <changed files> -> passed * pre-commit run mypy-local --files <changed files> -> passed This is related to, but not duplicative of, upstream PR vllm-project#39841 and vllm-project#40304. vllm-project#39841 only changes FP8 cast ordering in chunked prefill, and vllm-project#40304 is about static FP8 prefill output. This change addresses scaled FP8 input correctness and backend gating for MLA prefill. Co-authored-by: OpenAI Codex <codex@openai.com>
The quality regression shows up when MLA prefill query quantization is combined with an FP8 KV cache. That combination matters because use_prefill_query_quantization is inert unless the KV cache dtype is quantized and the selected MLA prefill backend claims support. In the K2.5-style config without kv-cache-dtype=fp8_e4m3 the flag does not change prefill query dtype; in the K2.6 config it does. The old MLA prefill FP8 path was not scale-correct once it became active: * forward_mha and chunked prefill converted q, k, and v with plain .to(fp8). That conversion assumes an implicit scale of 1 and ignores the calibrated q/k/v scale buffers loaded for FP8 attention. * The TRT-LLM ragged DeepSeek prefill wrapper always passed bmm1_scale=self.scale and bmm2_scale=1.0. The TRT-LLM/FlashInfer FP8 attention convention is bmm1_scale=q_scale*k_scale*sm_scale and bmm2_scale=v_scale (modulo output scale when present). Existing generic TRT-LLM attention benchmarks and tests use that convention. * For chunked context, the FP8 path used cp_gather_cache. That only copies cache bytes; the C++ comment on that path explicitly says scaled KV cache is not supported there. Scaled FP8 cache must be gathered through gather_and_maybe_dequant_cache so the cached latent KV is descaled before the kv_b projection. Fix this by making MLA FP8 prefill explicit about scale support: * Restrict backend_supports_prefill_query_quantization to TRTLLM_RAGGED. FlashInfer, FlashAttention, and TokenSpeed MLA prefill do not currently expose the BMM scale hooks needed for scaled FP8 inputs, so the flag should not activate FP8 prefill for those backends. * Add MLAAttention._scaled_fp8_prefill_input, which uses the existing static QuantFP8 op and layer scale buffers instead of plain casts. * Pass the MLAAttention layer into the MHA-style impl path so the common MLA impl can reuse the layer-owned quant op and q/k/v scale tensors. * Quantize prefill q with _q_scale, expanded k with _k_scale, and projected v with _v_scale. This differs from MLA decode, where BMM2 uses _k_scale because decode attends over the latent KV cache rather than the projected value tensor. * Plumb optional q/k/v descale factors through the MLA prefill backend API. Non-supporting backends assert that no scaled FP8 inputs were provided. TRTLLM_RAGGED maps the descales to bmm1_scale=self.scale*q_scale*k_scale and bmm2_scale=v_scale. * Allocate the non-DCP chunked prefill workspace in model dtype and always use gather_and_maybe_dequant_cache for chunked context, including FP8 KV cache. Correctness evidence: * The activation condition is covered by unit tests: FP8 prefill query dtype is selected only when the cache dtype is FP8, use_prefill_query_quantization is set, the device family is SM100, and the MLA prefill backend is TRTLLM_RAGGED. * The backend gate test proves FlashInfer, FlashAttention, and TokenSpeed no longer accidentally activate the scaled-FP8 path. * The non-TRT prefill backend tests prove those backends fail closed if scaled FP8 q/k/v inputs are accidentally passed through the new API. * The scaled prefill input helper test proves the quant path flattens only the leading dimensions, preserves the original shape, makes the quant input contiguous, and passes the supplied scale tensor to the static QuantFP8 op. * The TRTLLM_RAGGED scale helper test checks both the unscaled legacy case and the scaled case, including that partial scale plumbing is rejected. * The scale formulas match existing TRT-LLM FP8 attention tests and benchmarks: bmm1 receives q_scale*k_scale*softmax_scale, and bmm2 receives value scale. * The chunked-context cache gather now uses the only gather path that accepts a KV scale and can dequantize scaled FP8 cache before projection. Validation run locally: * .venv/bin/python -m pytest tests/v1/attention/test_mla_prefill_selector.py -q -> 26 passed, 16 warnings * pre-commit run ruff-format --files <changed files> -> passed * pre-commit run ruff-check --files <changed files> -> passed * .venv/bin/python -m py_compile <changed files> -> passed * pre-commit run mypy-local --files <changed files> -> passed This is related to, but not duplicative of, upstream PR vllm-project#39841 and vllm-project#40304. vllm-project#39841 only changes FP8 cast ordering in chunked prefill, and vllm-project#40304 is about static FP8 prefill output. This change addresses scaled FP8 input correctness and backend gating for MLA prefill. Co-authored-by: OpenAI Codex <codex@openai.com>
The quality regression shows up when MLA prefill query quantization is combined with an FP8 KV cache. That combination matters because use_prefill_query_quantization is inert unless the KV cache dtype is quantized and the selected MLA prefill backend claims support. In the K2.5-style config without kv-cache-dtype=fp8_e4m3 the flag does not change prefill query dtype; in the K2.6 config it does. The old MLA prefill FP8 path was not scale-correct once it became active: * forward_mha and chunked prefill converted q, k, and v with plain .to(fp8). That conversion assumes an implicit scale of 1 and ignores the calibrated q/k/v scale buffers loaded for FP8 attention. * The TRT-LLM ragged DeepSeek prefill wrapper always passed bmm1_scale=self.scale and bmm2_scale=1.0. The TRT-LLM/FlashInfer FP8 attention convention is bmm1_scale=q_scale*k_scale*sm_scale and bmm2_scale=v_scale (modulo output scale when present). Existing generic TRT-LLM attention benchmarks and tests use that convention. * For chunked context, the FP8 path used cp_gather_cache. That only copies cache bytes; the C++ comment on that path explicitly says scaled KV cache is not supported there. Scaled FP8 cache must be gathered through gather_and_maybe_dequant_cache so the cached latent KV is descaled before the kv_b projection. Fix this by making MLA FP8 prefill explicit about scale support: * Restrict backend_supports_prefill_query_quantization to TRTLLM_RAGGED. FlashInfer, FlashAttention, and TokenSpeed MLA prefill do not currently expose the BMM scale hooks needed for scaled FP8 inputs, so the flag should not activate FP8 prefill for those backends. * Add MLAAttention._scaled_fp8_prefill_input, which uses the existing static QuantFP8 op and layer scale buffers instead of plain casts. * Pass the MLAAttention layer into the MHA-style impl path so the common MLA impl can reuse the layer-owned quant op and q/k/v scale tensors. * Quantize prefill q with _q_scale, expanded k with _k_scale, and projected v with _v_scale. This differs from MLA decode, where BMM2 uses _k_scale because decode attends over the latent KV cache rather than the projected value tensor. * Plumb optional q/k/v descale factors through the MLA prefill backend API. Non-supporting backends assert that no scaled FP8 inputs were provided. TRTLLM_RAGGED maps the descales to bmm1_scale=self.scale*q_scale*k_scale and bmm2_scale=v_scale. * Allocate the non-DCP chunked prefill workspace in model dtype and always use gather_and_maybe_dequant_cache for chunked context, including FP8 KV cache. Correctness evidence: * The activation condition is covered by unit tests: FP8 prefill query dtype is selected only when the cache dtype is FP8, use_prefill_query_quantization is set, the device family is SM100, and the MLA prefill backend is TRTLLM_RAGGED. * The backend gate test proves FlashInfer, FlashAttention, and TokenSpeed no longer accidentally activate the scaled-FP8 path. * The non-TRT prefill backend tests prove those backends fail closed if scaled FP8 q/k/v inputs are accidentally passed through the new API. * The scaled prefill input helper test proves the quant path flattens only the leading dimensions, preserves the original shape, makes the quant input contiguous, and passes the supplied scale tensor to the static QuantFP8 op. * The TRTLLM_RAGGED scale helper test checks both the unscaled legacy case and the scaled case, including that partial scale plumbing is rejected. * The scale formulas match existing TRT-LLM FP8 attention tests and benchmarks: bmm1 receives q_scale*k_scale*softmax_scale, and bmm2 receives value scale. * The chunked-context cache gather now uses the only gather path that accepts a KV scale and can dequantize scaled FP8 cache before projection. Validation run locally: * .venv/bin/python -m pytest tests/v1/attention/test_mla_prefill_selector.py -q -> 26 passed, 16 warnings * pre-commit run ruff-format --files <changed files> -> passed * pre-commit run ruff-check --files <changed files> -> passed * .venv/bin/python -m py_compile <changed files> -> passed * pre-commit run mypy-local --files <changed files> -> passed Duplicate-work check: * vllm-project#39841 only changes FP8 cast ordering in chunked prefill; it does not address calibrated q/k/v input quantization, TRT-LLM ragged BMM descales, or backend gating. * vllm-project#40304 and vllm-project#40908 focus on static FP8 prefill output / merge-state fusion, while this change fixes scaled FP8 input correctness. * vllm-project#42509 is ROCm/AITER dense MLA prefill work on gfx950, not the Blackwell TRT-LLM ragged MLA prefill path touched here. * vllm-project#40609 / vllm-project#34795 are DCP FP8-KV efforts. This change intentionally handles non-DCP chunked context and keeps DCP as separate work. Co-authored-by: OpenAI Codex <codex@openai.com>
The quality regression shows up for public MLA model configurations when
prefill query quantization is combined with an FP8 KV cache. That combination
matters because use_prefill_query_quantization is inert unless the KV cache dtype
is quantized and the selected MLA prefill backend claims support. A serve
configuration that omits --kv-cache-dtype=fp8_e4m3 therefore does not exercise
this path, while the same model with FP8 KV cache and TRTLLM_RAGGED prefill does.
The old MLA prefill FP8 path was not scale-correct once it became active:
* forward_mha and chunked prefill converted q, k, and v with plain .to(fp8).
That conversion assumes an implicit scale of 1 and ignores the calibrated
q/k/v scale buffers loaded for FP8 attention.
* The TRT-LLM ragged DeepSeek prefill wrapper always passed
bmm1_scale=self.scale and bmm2_scale=1.0. The TRT-LLM/FlashInfer FP8
attention convention is bmm1_scale=q_scale*k_scale*sm_scale and
bmm2_scale=v_scale (modulo output scale when present). Existing generic
TRT-LLM attention benchmarks and tests use that convention.
* For chunked context, the FP8 path used cp_gather_cache. That only copies
cache bytes; the C++ comment on that path explicitly says scaled KV cache is
not supported there. Scaled FP8 cache must be gathered through
gather_and_maybe_dequant_cache so the cached latent KV is descaled before the
kv_b projection.
Fix this by making MLA FP8 prefill explicit about scale support:
* Restrict backend_supports_prefill_query_quantization to TRTLLM_RAGGED.
FlashInfer, FlashAttention, and TokenSpeed MLA prefill do not currently expose
the BMM scale hooks needed for scaled FP8 inputs, so the flag should not
activate FP8 prefill for those backends.
* Add MLAAttention._scaled_fp8_prefill_input, which uses the existing static
QuantFP8 op and layer scale buffers instead of plain casts.
* Pass the MLAAttention layer into the MHA-style impl path so the common MLA impl
can reuse the layer-owned quant op and q/k/v scale tensors.
* Quantize prefill q with _q_scale, expanded k with _k_scale, and projected v
with _v_scale. This differs from MLA decode, where BMM2 uses _k_scale because
decode attends over the latent KV cache rather than the projected value tensor.
* Plumb optional q/k/v descale factors through the MLA prefill backend API.
Non-supporting backends assert that no scaled FP8 inputs were provided.
TRTLLM_RAGGED maps the descales to bmm1_scale=self.scale*q_scale*k_scale and
bmm2_scale=v_scale.
* Allocate the non-DCP chunked prefill workspace in model dtype and always use
gather_and_maybe_dequant_cache for chunked context, including FP8 KV cache.
Open-source reproduction sketch:
Use a public Moonshot MLA checkpoint that exercises the Kimi-K2.x MLA path, for
example moonshotai/Kimi-K2.5 or the corresponding Moonshot Kimi-K2.6 Hugging
Face checkpoint when available. Choose TP/PP sizes that fit the local Blackwell
host; the important repro knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and
use_prefill_query_quantization.
Baseline server, FP8 KV cache without prefill query quantization:
MODEL_ID=moonshotai/Kimi-K2.5
vllm serve "$MODEL_ID" \
--host 0.0.0.0 \
--port 8000 \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
--served-model-name kimi-mla-fp8-baseline
Scaled-FP8 prefill server under test:
MODEL_ID=moonshotai/Kimi-K2.5
vllm serve "$MODEL_ID" \
--host 0.0.0.0 \
--port 8000 \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A quick deterministic quality probe is to ask small math questions with known
answers and temperature 0:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: 17_b=b+7 and 97_b=9b+7, so b+7 divides 56;
the divisors greater than 16 are 28 and 56, giving bases 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with m=n+2, divisibility reduces to m | 39,
so n is 1, 11, or 37.
Before this fix, the FP8-KV + prefill-query-quantized server can produce empty,
garbled, timed-out, or mathematically incorrect responses while the baseline
server remains coherent. After this fix, the scaled-FP8 prefill server should
match the baseline quality on these probes. A control run without
--kv-cache-dtype=fp8_e4m3 should also remain unaffected by the flag, because the
FP8 prefill query path is intentionally inactive unless the cache is quantized.
Correctness evidence:
* The activation condition is covered by unit tests: FP8 prefill query dtype is
selected only when the cache dtype is FP8, use_prefill_query_quantization is
set, the device family is SM100, and the MLA prefill backend is TRTLLM_RAGGED.
* The backend gate test proves FlashInfer, FlashAttention, and TokenSpeed no
longer accidentally activate the scaled-FP8 path.
* The non-TRT prefill backend tests prove those backends fail closed if scaled
FP8 q/k/v inputs are accidentally passed through the new API.
* The scaled prefill input helper test proves the quant path flattens only the
leading dimensions, preserves the original shape, makes the quant input
contiguous, and passes the supplied scale tensor to the static QuantFP8 op.
* The TRTLLM_RAGGED scale helper test checks both the unscaled legacy case and
the scaled case, including that partial scale plumbing is rejected.
* The scale formulas match existing TRT-LLM FP8 attention tests and benchmarks:
bmm1 receives q_scale*k_scale*softmax_scale, and bmm2 receives value scale.
* The chunked-context cache gather now uses the only gather path that accepts a
KV scale and can dequantize scaled FP8 cache before projection.
Validation run locally:
* .venv/bin/python -m pytest tests/v1/attention/test_mla_prefill_selector.py -q
-> 26 passed, 16 warnings
* pre-commit run ruff-format --files <changed files> -> passed
* pre-commit run ruff-check --files <changed files> -> passed
* .venv/bin/python -m py_compile <changed files> -> passed
* pre-commit run mypy-local --files <changed files> -> passed
Duplicate-work check refreshed on 2026-05-15:
* vllm-project#39841 only changes FP8 cast ordering in chunked prefill; it does not address
calibrated q/k/v input quantization, TRT-LLM ragged BMM descales, or backend
gating.
* vllm-project#40304 and vllm-project#40908 focus on static FP8 prefill output / merge-state fusion,
while this change fixes scaled FP8 input correctness.
* vllm-project#42509 is ROCm/AITER dense MLA prefill work on gfx950, not the Blackwell
TRT-LLM ragged MLA prefill path touched here.
* vllm-project#40609 / vllm-project#34795 are DCP FP8-KV efforts. This change intentionally handles
non-DCP chunked context and keeps DCP as separate work.
* vllm-project#41568 lifts decode Q-prep work out of forward_impl; it is decode performance
work and does not repair prefill q/k/v input descales.
Co-authored-by: OpenAI Codex <codex@openai.com>
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend supports scaled FP8 inputs. One observed public configuration is Moonshot MLA checkpoints such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`. The old FP8 prefill path was not scale-correct once active: prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received incomplete BMM scales, and non-DCP chunked context gathered FP8 cache bytes through a path that cannot dequantize scaled KV cache before the `kv_b` projection. That can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent. This change makes scaled FP8 MLA prefill explicit: - gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks; - quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts; - plumb q/k/v descale factors through the MLA prefill backend API, mapping TRTLLM_RAGGED to `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`; - fail closed in backends that do not support scaled FP8 inputs; - use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection. Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering. Related upstream work reviewed: - vllm-project#39841 adjusts FP8 cast ordering in chunked prefill, but does not address calibrated q/k/v quantization, TRTLLM_RAGGED BMM scales, or backend gating. - vllm-project#40304 and vllm-project#40908 focus on static FP8 prefill output and merge-state fusion, not scaled FP8 input correctness. - vllm-project#42509 is ROCm/AITER dense MLA prefill work on gfx950, not Blackwell TRTLLM_RAGGED MLA prefill. - vllm-project#40609 and vllm-project#34795 cover DCP FP8-KV work; this change covers the non-DCP chunked-context path. - vllm-project#41568 is decode Q-prep refactoring and does not repair prefill input descales.
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend supports scaled FP8 inputs. One observed public configuration is Moonshot MLA checkpoints such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.
The old FP8 prefill path was not scale-correct once active: prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received incomplete BMM scales, and non-DCP chunked context gathered FP8 cache bytes through a path that cannot dequantize scaled KV cache before the `kv_b` projection. That can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.
This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API, mapping TRTLLM_RAGGED to `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.
Reproduction sketch:
Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.
Baseline server:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
--served-model-name kimi-mla-fp8-baseline
Scaled-FP8 prefill server under test:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is to ask a small math question with temperature 0:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.
Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.
Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.
Related upstream work reviewed:
- vllm-project#39841 adjusts FP8 cast ordering in chunked prefill, but does not address calibrated q/k/v quantization, TRTLLM_RAGGED BMM scales, or backend gating.
- vllm-project#40304 and vllm-project#40908 focus on static FP8 prefill output and merge-state fusion, not scaled FP8 input correctness.
- vllm-project#42509 is ROCm/AITER dense MLA prefill work on gfx950, not Blackwell TRTLLM_RAGGED MLA prefill.
- vllm-project#40609 and vllm-project#34795 cover DCP FP8-KV work; this change covers the non-DCP chunked-context path.
- vllm-project#41568 is decode Q-prep refactoring and does not repair prefill input descales.
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend can consume scaled FP8 inputs. One public configuration that exercises this path is a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.
The old path treated FP8 prefill as a dtype conversion rather than a scaled quantization contract. Prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received legacy BMM scales, and non-DCP chunked context copied FP8 cache bytes through a gather path that cannot dequantize scaled KV cache before the `kv_b` projection. Once active, that can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.
For scaled FP8 prefill inputs, the backend must interpret the encoded tensors with their calibrated descales. In the TRTLLM_RAGGED path that means `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`; passing only `softmax_scale` and `1.0` is correct only for the unscaled legacy path.
This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API and map them to the TRTLLM_RAGGED BMM scales;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.
Reproduction sketch:
Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.
Baseline server:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
--served-model-name kimi-mla-fp8-baseline
Scaled-FP8 prefill server under test:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is to ask a small math question with temperature 0:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.
Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.
Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.
Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the FP8 cast, but it still leaves raw FP8 casts, missing q/k/v descales, missing TRTLLM_RAGGED BMM scale mapping, and no backend gate.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion. Those operate after the attention inputs; they do not repair incorrectly scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware and backend path from NVIDIA Blackwell TRTLLM_RAGGED.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize behavior in the DCP path. That is useful DCP work, but it does not make non-DCP TRTLLM_RAGGED prefill query quantization scale-correct, and DCP prefill-query FP8 support is a separate surface.
- vllm-project#41568 refactors decode Q-prep. This bug is in prefill input quantization and prefill backend scale plumbing, not decode Q preparation.
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend can consume scaled FP8 inputs. One public configuration that exercises this path is a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.
The old path treated FP8 prefill as a dtype conversion rather than a scaled quantization contract. Prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received legacy BMM scales, and non-DCP chunked context copied FP8 cache bytes through a gather path that cannot dequantize scaled KV cache before the `kv_b` projection. Once active, that can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.
For scaled FP8 prefill inputs, the backend must interpret the encoded tensors with their calibrated descales. In the TRTLLM_RAGGED path that means `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`; passing only `softmax_scale` and `1.0` is correct only for the unscaled legacy path.
This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API and map them to the TRTLLM_RAGGED BMM scales;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.
Reproduction sketch:
Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.
Baseline server:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
--served-model-name kimi-mla-fp8-baseline
Scaled-FP8 prefill server under test:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is to ask a small math question with temperature 0:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.
Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.
Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.
Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the FP8 cast, but it still leaves raw FP8 casts, missing q/k/v descales, missing TRTLLM_RAGGED BMM scale mapping, and no backend gate.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion. Those operate after the attention inputs; they do not repair incorrectly scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware and backend path from NVIDIA Blackwell TRTLLM_RAGGED.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize behavior in the DCP path. That is useful DCP work, but it does not make non-DCP TRTLLM_RAGGED prefill query quantization scale-correct, and DCP prefill-query FP8 support is a separate surface.
- vllm-project#41568 refactors decode Q-prep. This bug is in prefill input quantization and prefill backend scale plumbing, not decode Q preparation.
Prefill query quantization changes the MLA prefill contract when the KV cache
is FP8. The q/k/v tensors passed to the prefill backend are no longer ordinary
BF16 inputs: they are static per-tensor FP8 values and the backend must apply
their calibrated descales in the attention math.
The previous path converted prefill q/k/v to FP8 but did not consistently pass
those descales into the selected prefill backend. That makes the attention
scores and values numerically wrong:
expected scores = (q_fp8 * q_scale) @ (k_fp8 * k_scale).T * softmax_scale
expected output = softmax(expected scores) @ (v_fp8 * v_scale)
Without the descales, the kernel effectively computes logits and P@V in the
encoded FP8 domain. In practice this can show up as empty, garbled, timed-out,
or mathematically incorrect responses for Moonshot MLA checkpoints served with
FP8 KV cache and `use_prefill_query_quantization=true`, while the same model
without prefill query quantization remains coherent.
Make scaled FP8 MLA prefill explicit end to end:
- quantize prefill q, expanded k, and projected v with the layer-owned static
QuantFP8 op and calibrated q/k/v scale tensors instead of raw dtype casts;
- pass q/k/v descales through the MLA prefill backend API from both new-token
prefill and chunked-context prefill;
- map descales to TRT-LLM ragged DeepSeek prefill as
`bmm1_scale = softmax_scale * q_scale * k_scale` and
`bmm2_scale = v_scale`;
- forward descales to FlashInfer ragged prefill. vLLM pins
`flashinfer-python==0.6.8.post1`, whose
`BatchPrefillWithRaggedKVCacheWrapper.run()` accepts q/k/v scale arguments,
and vLLM constructs that wrapper with `backend="cutlass"`, whose run path
passes those scales to `fmha_varlen`;
- support TokenSpeed MLA prefill without changing its external API. TokenSpeed
exposes `softmax_scale` and returns BF16 output; for static per-tensor FP8,
folding `q_scale * k_scale` into `softmax_scale` and applying `v_scale` to
the BF16 output is algebraically equivalent to dequantizing Q/K before QK
and V before PV. LSE is left unchanged by the V/output scale;
- keep FlashAttention MLA prefill fail-closed for scaled FP8 inputs. Although
the FA3 varlen API has descale slots, this GB200 path uses the FA4 wrapper,
and the current FA4 call does not pass q/k/v descales;
- dequantize scaled FP8 KV cache in non-DCP chunked context before the `kv_b`
projection with `gather_and_maybe_dequant_cache`.
This preserves the user-facing MLA prefill backend support matrix where the
backend has a concrete scale-correct path: TRTLLM_RAGGED, FLASHINFER, and
TOKENSPEED_MLA remain supported for FP8 prefill query quantization on GB200;
FLASH_ATTN remains unsupported instead of silently producing unscaled FP8 math.
Reproduction sketch:
Use a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or
`moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV
cache, MLA prefill query quantization, and an MLA prefill backend that accepts
or can equivalently apply q/k/v descales.
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so
`b + 7` divides 56 and the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to
`m | 39`, so `n` is 1, 11, or 37.
Tests cover the activation gate, q/k/v scale plumbing, TRT-LLM BMM scale
mapping, FlashInfer scale forwarding, TokenSpeed scale equivalence, scaled FP8
input quantization, FlashAttention fail-closed behavior, and chunked-context
cache gather/dequantization.
Local validation:
- `.venv/bin/python -m pytest tests/v1/attention/test_mla_prefill_selector.py -q`
- `uvx ruff check` on the touched FP8 MLA files and tests
- `.venv/bin/python -m py_compile` on the touched FP8 MLA files and tests
- `git diff --check`
Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the
FP8 cast, but it still leaves missing q/k/v descales and backend scale
plumbing.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion.
Those operate after the attention inputs; they do not repair incorrectly
scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware
and backend path from NVIDIA Blackwell MLA prefill.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize
behavior in the DCP path. That does not make non-DCP scaled FP8 prefill
query quantization scale-correct.
- vllm-project#41568 refactors decode Q preparation. This bug is in prefill input
quantization and prefill backend scale plumbing, not decode Q preparation.
Prefill query quantization changes the MLA prefill contract when the KV cache
is FP8. The q/k/v tensors passed to the prefill backend are no longer ordinary
BF16 inputs: they are static per-tensor FP8 values and the backend must apply
their calibrated descales in the attention math.
The previous path converted prefill q/k/v to FP8 but did not consistently pass
those descales into the selected prefill backend. That makes the attention
scores and values numerically wrong:
expected scores = (q_fp8 * q_scale) @ (k_fp8 * k_scale).T * softmax_scale
expected output = softmax(expected scores) @ (v_fp8 * v_scale)
Without the descales, the kernel effectively computes logits and P@V in the
encoded FP8 domain. In practice this can show up as empty, garbled, timed-out,
or mathematically incorrect responses for Moonshot MLA checkpoints served with
FP8 KV cache and `use_prefill_query_quantization=true`, while the same model
without prefill query quantization remains coherent.
Make scaled FP8 MLA prefill explicit end to end:
- quantize prefill q, expanded k, and projected v with the layer-owned static
QuantFP8 op and calibrated q/k/v scale tensors instead of raw dtype casts;
- pass q/k/v descales through the MLA prefill backend API from both new-token
prefill and chunked-context prefill;
- map descales to TRT-LLM ragged DeepSeek prefill as
`bmm1_scale = softmax_scale * q_scale * k_scale` and
`bmm2_scale = v_scale`;
- forward descales to FlashInfer ragged prefill. vLLM pins
`flashinfer-python==0.6.8.post1`, whose
`BatchPrefillWithRaggedKVCacheWrapper.run()` accepts q/k/v scale arguments,
and vLLM constructs that wrapper with `backend="cutlass"`, whose run path
passes those scales to `fmha_varlen`;
- support TokenSpeed MLA prefill without changing its external API. TokenSpeed
exposes `softmax_scale` and returns BF16 output; for static per-tensor FP8,
folding `q_scale * k_scale` into `softmax_scale` and applying `v_scale` to
the BF16 output is algebraically equivalent to dequantizing Q/K before QK
and V before PV. LSE is left unchanged by the V/output scale;
- keep FlashAttention MLA prefill fail-closed for scaled FP8 inputs. Although
the FA3 varlen API has descale slots, this GB200 path uses the FA4 wrapper,
and the current FA4 call does not pass q/k/v descales;
- dequantize scaled FP8 KV cache in non-DCP chunked context before the `kv_b`
projection with `gather_and_maybe_dequant_cache`.
This preserves the user-facing MLA prefill backend support matrix where the
backend has a concrete scale-correct path: TRTLLM_RAGGED, FLASHINFER, and
TOKENSPEED_MLA remain supported for FP8 prefill query quantization on GB200;
FLASH_ATTN remains unsupported instead of silently producing unscaled FP8 math.
Reproduction sketch:
Use a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or
`moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV
cache, MLA prefill query quantization, and an MLA prefill backend that accepts
or can equivalently apply q/k/v descales.
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so
`b + 7` divides 56 and the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to
`m | 39`, so `n` is 1, 11, or 37.
Tests cover the activation gate, q/k/v scale plumbing, TRT-LLM BMM scale
mapping, FlashInfer scale forwarding, TokenSpeed scale equivalence, scaled FP8
input quantization, FlashAttention fail-closed behavior, and chunked-context
cache gather/dequantization.
Local validation:
- `.venv/bin/python -m pytest tests/v1/attention/test_mla_prefill_selector.py -q`
- `uvx ruff check` on the touched FP8 MLA files and tests
- `.venv/bin/python -m py_compile` on the touched FP8 MLA files and tests
- `git diff --check`
Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the
FP8 cast, but it still leaves missing q/k/v descales and backend scale
plumbing.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion.
Those operate after the attention inputs; they do not repair incorrectly
scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware
and backend path from NVIDIA Blackwell MLA prefill.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize
behavior in the DCP path. That does not make non-DCP scaled FP8 prefill
query quantization scale-correct.
- vllm-project#41568 refactors decode Q preparation. This bug is in prefill input
quantization and prefill backend scale plumbing, not decode Q preparation.
Co-authored-by: Codex
Prefill query quantization changes the MLA prefill contract when the KV cache
is FP8. The q/k/v tensors passed to the prefill backend are no longer ordinary
BF16 inputs: they are static per-tensor FP8 values and the backend must apply
their calibrated descales in the attention math.
The previous path converted prefill q/k/v to FP8 but did not consistently pass
those descales into the selected prefill backend. That makes the attention
scores and values numerically wrong:
expected scores = (q_fp8 * q_scale) @ (k_fp8 * k_scale).T * softmax_scale
expected output = softmax(expected scores) @ (v_fp8 * v_scale)
Without the descales, the kernel effectively computes logits and P@V in the
encoded FP8 domain. In practice this can show up as empty, garbled, timed-out,
or mathematically incorrect responses for Moonshot MLA checkpoints served with
FP8 KV cache and `use_prefill_query_quantization=true`, while the same model
without prefill query quantization remains coherent.
Make scaled FP8 MLA prefill explicit end to end:
- quantize prefill q, expanded k, and projected v with the layer-owned static
QuantFP8 op and calibrated q/k/v scale tensors instead of raw dtype casts;
- pass q/k/v descales through the MLA prefill backend API from both new-token
prefill and non-DCP chunked-context prefill;
- map descales to TRT-LLM ragged DeepSeek prefill as
`bmm1_scale = softmax_scale * q_scale * k_scale` and
`bmm2_scale = v_scale`;
- forward descales to FlashInfer ragged prefill. vLLM pins
`flashinfer-python==0.6.8.post1`, whose
`BatchPrefillWithRaggedKVCacheWrapper.run()` accepts q/k/v scale arguments,
and vLLM constructs that wrapper with `backend="cutlass"`, whose run path
passes those scales to `fmha_varlen`;
- support TokenSpeed MLA prefill without changing its external API. TokenSpeed
exposes `softmax_scale` and returns BF16 output; for static per-tensor FP8,
folding `q_scale * k_scale` into `softmax_scale` and applying `v_scale` to
the BF16 output is algebraically equivalent to dequantizing Q/K before QK
and V before PV. LSE is left unchanged by the V/output scale;
- keep FlashAttention MLA prefill fail-closed for scaled FP8 inputs. Although
the FA3 varlen API has descale slots, this GB200 path uses the FA4 wrapper,
and the current FA4 call does not pass q/k/v descales;
- dequantize scaled FP8 KV cache in non-DCP chunked context before the `kv_b`
projection with `gather_and_maybe_dequant_cache`.
DCP chunked-context MLA prefill still uses a separate distributed gather path
that does not thread q/k/v descales into context attention. Rather than silently
claim scale-correct support there, reject DCP chunked-context when scaled FP8
prefill query quantization is active. Also keep the ROCm/AITER `forward_mha`
override signature compatible with the widened common MHA interface so the
shared `layer` argument does not break fallback paths.
This preserves the user-facing MLA prefill backend support matrix where the
backend has a concrete scale-correct path: TRTLLM_RAGGED, FLASHINFER, and
TOKENSPEED_MLA remain supported for FP8 prefill query quantization on GB200;
FLASH_ATTN and DCP chunked-context remain unsupported instead of silently
producing unscaled FP8 math.
Reproduction sketch:
Use a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or
`moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV
cache, MLA prefill query quantization, and an MLA prefill backend that accepts
or can equivalently apply q/k/v descales.
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so
`b + 7` divides 56 and the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to
`m | 39`, so `n` is 1, 11, or 37.
Tests cover the activation gate, q/k/v scale plumbing, TRT-LLM BMM scale
mapping, FlashInfer scale forwarding, TokenSpeed scale equivalence, scaled FP8
input quantization, FlashAttention fail-closed behavior, DCP fail-closed
behavior, ROCm/AITER interface compatibility, and chunked-context cache
gather/dequantization.
Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the
FP8 cast, but it still leaves missing q/k/v descales and backend scale
plumbing.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion.
Those operate after the attention inputs; they do not repair incorrectly
scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware
and backend path from NVIDIA Blackwell MLA prefill.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize
behavior in the DCP path. That does not make non-DCP scaled FP8 prefill
query quantization scale-correct.
- vllm-project#41568 refactors decode Q preparation. This bug is in prefill input
quantization and prefill backend scale plumbing, not decode Q preparation.
Co-authored-by: Codex <codex@openai.com>
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend can consume scaled FP8 inputs. One public configuration that exercises this path is a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.
The old path treated FP8 prefill as a dtype conversion rather than a scaled quantization contract. Prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received legacy BMM scales, and non-DCP chunked context copied FP8 cache bytes through a gather path that cannot dequantize scaled KV cache before the `kv_b` projection. Once active, that can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.
For scaled FP8 prefill inputs, the backend must interpret the encoded tensors with their calibrated descales. In the TRTLLM_RAGGED path that means `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`; passing only `softmax_scale` and `1.0` is correct only for the unscaled legacy path.
This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API and map them to the TRTLLM_RAGGED BMM scales;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.
Reproduction sketch:
Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.
Baseline server:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
--served-model-name kimi-mla-fp8-baseline
Scaled-FP8 prefill server under test:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is to ask a small math question with temperature 0:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.
Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.
Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.
Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the FP8 cast, but it still leaves raw FP8 casts, missing q/k/v descales, missing TRTLLM_RAGGED BMM scale mapping, and no backend gate.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion. Those operate after the attention inputs; they do not repair incorrectly scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware and backend path from NVIDIA Blackwell TRTLLM_RAGGED.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize behavior in the DCP path. That is useful DCP work, but it does not make non-DCP TRTLLM_RAGGED prefill query quantization scale-correct, and DCP prefill-query FP8 support is a separate surface.
- vllm-project#41568 refactors decode Q-prep. This bug is in prefill input quantization and prefill backend scale plumbing, not decode Q preparation.
Assisted-by: OpenAI Codex
Prefill query quantization only affects MLA when the KV cache is quantized and the selected prefill backend can consume scaled FP8 inputs. One public configuration that exercises this path is a Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` served with FP8 KV cache, TRTLLM_RAGGED MLA prefill, and `use_prefill_query_quantization=true`.
The old path treated FP8 prefill as a dtype conversion rather than a scaled quantization contract. Prefill q/k/v were converted with raw `.to(fp8)`, TRTLLM_RAGGED received legacy BMM scales, and non-DCP chunked context copied FP8 cache bytes through a gather path that cannot dequantize scaled KV cache before the `kv_b` projection. Once active, that can produce empty, garbled, timed-out, or mathematically incorrect responses while the same model without prefill query quantization remains coherent.
For scaled FP8 prefill inputs, the backend must interpret the encoded tensors with their calibrated descales. In the TRTLLM_RAGGED path that means `bmm1_scale = softmax_scale * q_scale * k_scale` and `bmm2_scale = v_scale`; passing only `softmax_scale` and `1.0` is correct only for the unscaled legacy path.
This change makes scaled FP8 MLA prefill explicit:
- gate prefill query quantization to TRTLLM_RAGGED, the backend that exposes the needed BMM scale hooks;
- quantize prefill q, expanded k, and projected v through the layer-owned static QuantFP8 op and calibrated scale tensors instead of raw casts;
- plumb q/k/v descale factors through the MLA prefill backend API and map them to the TRTLLM_RAGGED BMM scales;
- fail closed in backends that do not support scaled FP8 inputs;
- use `gather_and_maybe_dequant_cache` for non-DCP chunked context so scaled FP8 KV cache is dequantized before projection.
Reproduction sketch:
Use a public Moonshot MLA checkpoint such as `moonshotai/Kimi-K2.5` or `moonshotai/Kimi-K2.6` on a Blackwell host. The important knobs are FP8 KV cache, TRTLLM_RAGGED MLA prefill, and toggling `use_prefill_query_quantization`.
Baseline server:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":false}' \
--served-model-name kimi-mla-fp8-baseline
Scaled-FP8 prefill server under test:
MODEL_ID=moonshotai/Kimi-K2.6
vllm serve "$MODEL_ID" \
--tensor-parallel-size 8 \
--trust-remote-code \
--max-model-len 32768 \
--kv-cache-dtype fp8_e4m3 \
--attention-config '{"mla_prefill_backend":"TRTLLM_RAGGED","use_prefill_query_quantization":true}' \
--served-model-name kimi-mla-fp8-pqq
A deterministic quality probe is to ask a small math question with temperature 0:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all integer bases b>9 for which 17_b is a divisor of 97_b. Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 70: `17_b = b + 7` and `97_b = 9b + 7`, so `b + 7` divides 56; the valid bases are 21 and 49.
A second probe is:
curl -sS http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model":"kimi-mla-fp8-pqq",
"messages":[{"role":"user","content":"Find the sum of all positive integers n such that n+2 divides 3(n+3)(n^2+9). Give the final answer only at the end."}],
"temperature":0,
"max_tokens":2048
}'
The expected final answer is 49: with `m = n + 2`, divisibility reduces to `m | 39`, so `n` is 1, 11, or 37.
Before this fix, the scaled-FP8 prefill server can produce empty, garbled, timed-out, or mathematically incorrect responses while the baseline server remains coherent. A control run without `--kv-cache-dtype=fp8_e4m3` should remain unaffected by the flag because the FP8 prefill query path is inactive unless the cache is quantized.
This revision also fails closed for modes that need separate scale semantics before FP8 prefill query quantization can be correct: decode context parallelism, runtime `calculate_kv_scales`, and the ROCm AITER MLA override whose `forward_mha` signature must match the shared `layer` argument. Those cases now fall back to model-dtype prefill or the parent implementation instead of silently running an incomplete scaled-FP8 path.
Tests cover the activation gate, backend fail-closed behavior, scaled FP8 quant helper shape/contiguity/scale use, TRTLLM_RAGGED scale mapping including partial-scale rejection, and chunked-context cache gathering.
Related upstream work reviewed:
- vllm-project#39841 fixes chunked-prefill FP8 cast ordering so K concat happens before the FP8 cast, but it still leaves raw FP8 casts, missing q/k/v descales, missing TRTLLM_RAGGED BMM scale mapping, and no backend gate.
- vllm-project#40304 and vllm-project#40908 improve static FP8 prefill output and merge-state fusion. Those operate after the attention inputs; they do not repair incorrectly scaled QK logits or P@V values.
- vllm-project#42509 adds ROCm/AITER dense MLA FP8 prefill on gfx950, a different hardware and backend path from NVIDIA Blackwell TRTLLM_RAGGED.
- vllm-project#40609 and vllm-project#34795 improve DCP FP8-KV handling, especially gather/dequantize behavior in the DCP path. That is useful DCP work, but it does not make non-DCP TRTLLM_RAGGED prefill query quantization scale-correct, and DCP prefill-query FP8 support is a separate surface.
- vllm-project#41568 refactors decode Q-prep. This bug is in prefill input quantization and prefill backend scale plumbing, not decode Q preparation.
Assisted-by: OpenAI Codex
|
closing this PR per @qiching's comment |
Summary
Optimize in
_compute_prefill_contextwhere FP8 quantization of K/V tensors happens before_concat_k_nope_k_pe, causingflashinfer_concat_mla_kto crash since it only supports BF16/FP16 inputs. This makes FP8 prefill (use_prefill_query_quantization: true) completely broken for any workload that requires chunked prefill (i.e., 128K long-context sequences).Root Cause
In
_compute_prefill_context(the chunked prefill path), the original code castskv_nopeandk_peto FP8 before the concat operation:However,
flashinfer_concat_mla_kusesDISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16, which only dispatches BF16/FP16, and also requires all three tensors to share the same dtype.pass FP8 tensors causes the vLLM server to crash, making FP8 chunked prefill can not use normally.
Fix
k_pe(FP8 from workspace) matchesk_nope(BF16 fromkv_b_proj) dtype before concatkandvto FP8 after concat, aligning withforward_mhaImpact
enable-chunked-prefill+use_prefill_query_quantization(FP8 prefill) crashes. This includes all long-context sequences (e.g., ISL >= 4K) that require chunking.forward_mhawhich was already correct). the bf16 path in_compute_prefill_contextis also unchanged.Benchmark Results
end-to-end A/B comparison on GB200 (DeepSeek-R1-0528-FP4, DP=4, chunked prefill, ISL=128K, 16 requests):
Test Plan
test_mla_backends.pytests pass