[Bugfix] Fix TRITON_MLA FP8 KV cache decode on Blackwell GPUs #35833

Closed
ricky-chaoju wants to merge 4 commits into
vllm-project:mainfrom
ricky-chaoju:triton-mla-fp8-kv-cache

ricky-chaoju commented Mar 3, 2026

Contributor

Summary

  • Cast query to bfloat16 before the Triton decode kernel when FP8 KV cache is enabled, avoiding FP8 tl.dot instructions that produce illegal instruction errors on Blackwell (SM 12.x).
  • Reduce Triton pipeline stages to 1 for FP8 to prevent shared memory overflow from float32 dequantization intermediates.
  • KV cache dequantization is still performed on-the-fly inside the kernel, so there is no full-cache copy overhead.
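The first two bullets amount to a small launch-time decision. A minimal sketch of that decision follows, assuming hypothetical names (`DecodeLaunchConfig`, `choose_decode_config`) and an assumed default of two pipeline stages; none of these are taken from the vLLM source, and dtypes are represented as plain strings to keep the example dependency-free.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DecodeLaunchConfig:
    """Hypothetical container for the two settings this PR changes."""

    query_dtype: str  # dtype the query is in when it reaches the kernel
    num_stages: int   # Triton software-pipelining depth


def choose_decode_config(
    kv_cache_dtype: str,
    query_dtype: str,
    default_num_stages: int = 2,  # assumed default, not read from vLLM
) -> DecodeLaunchConfig:
    """Pick the query dtype and pipeline depth for the decode kernel."""
    if kv_cache_dtype.startswith("fp8"):
        # FP8 KV cache: feed the kernel a bfloat16 query so it does not
        # emit FP8 dot products, and use a single pipeline stage to keep
        # the float32 dequantization intermediates within shared memory.
        return DecodeLaunchConfig(query_dtype="bfloat16", num_stages=1)
    # Any other cache dtype: leave both settings unchanged.
    return DecodeLaunchConfig(query_dtype=query_dtype,
                              num_stages=default_num_stages)
```

For example, `choose_decode_config("fp8", "float8_e4m3fn")` selects a bfloat16 query and one stage, while `choose_decode_config("auto", "bfloat16")` leaves the query dtype and the stage count untouched.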

Fixes #35577

Test plan

  • Verified on NVIDIA GB10 (SM 12.1) with GLM-4.7-Flash-NVFP4, --kv-cache-dtype fp8, generation throughput ~32–80 tok/s
  • FP8 vs bfloat16 kernel output: cosine similarity 0.9997, max abs diff 0.0006
  • Memory savings confirmed: FP8 cache uses 50% of bfloat16 baseline
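The two accuracy figures in the second bullet are standard vector comparisons. A minimal pure-Python sketch of how they can be computed over flattened kernel outputs is shown here; the function names are illustrative, and the PR does not state which tooling produced the reported numbers.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Largest element-wise absolute difference between two vectors."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))
```

Cosine similarity measures whether the two outputs point in the same direction, while the maximum absolute difference bounds the worst single-element error, so the two numbers are complementary.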

mergify Bot added the v1 and bug labels Mar 3, 2026
ricky-chaoju force-pushed the triton-mla-fp8-kv-cache branch from 66a1ab8 to e7edcf5 March 3, 2026 05:41

gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request adds FP8 KV cache support to Triton MLA and fixes an illegal instruction error on Blackwell GPUs. It casts the query to bfloat16 to avoid the problematic FP8 tl.dot instructions and reduces the Triton pipeline stages to prevent shared memory overflow. The changes are generally correct, but I've identified a critical issue: hard-casting the query to bfloat16 can cause a data type mismatch and a runtime error in the downstream V up-projection when the model's data type is float16. I've provided a detailed explanation and a suggested fix for this issue.
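The suggested fix itself is not reproduced in this conversation. One plausible way to avoid the mismatch described in the review, offered here as an assumption and not as the reviewer's actual suggestion, is to cast the query to the model's own 16-bit dtype instead of a hard-coded bfloat16. The helper name `pick_query_cast_dtype` and the string dtype names are hypothetical.

```python
from typing import Optional

# 16-bit dtypes the query could be cast to without using FP8 dot products.
HALF_DTYPES = ("float16", "bfloat16")


def pick_query_cast_dtype(kv_cache_dtype: str, model_dtype: str) -> Optional[str]:
    """Return the dtype to cast the query to, or None if no cast is needed.

    Hypothetical helper: follows the model dtype so that the attention
    output matches the weights used by the later V up-projection.
    """
    if not kv_cache_dtype.startswith("fp8"):
        # Non-FP8 cache: the query is left as it is.
        return None
    if model_dtype in HALF_DTYPES:
        # Match the model dtype (float16 stays float16).
        return model_dtype
    # Assumed fallback for any other model dtype.
    return "bfloat16"
```

With this choice a float16 model keeps a float16 query, so the output dtype agrees with the float16 weights downstream, while bfloat16 models behave as in the original patch.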

Comment thread vllm/v1/attention/backends/mla/triton_mla.py
Signed-off-by: vllm-dev <ricky.chen@infinirc.com>
Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
Signed-off-by: vllm-dev <ricky.chen@infinirc.com>
Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
Signed-off-by: vllm-dev <ricky.chen@infinirc.com>
ricky-chaoju force-pushed the triton-mla-fp8-kv-cache branch from 41bf30e to 71a42cb March 3, 2026 05:48

mergify Bot commented Mar 3, 2026

Contributor

Hi @ricky-chaoju, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>

Signed-off-by: "vllm-dev" <ricky.chen@infinirc.com>

mergify Bot commented Mar 3, 2026

Contributor

Documentation preview: https://vllm--35833.org.readthedocs.build/en/35833/

mergify Bot added the documentation label Mar 3, 2026
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 4, 2026
Cherry-pick upstream fixes for GB10 Spark (SM121):

- PR vllm-project#35568: Recognize SM121 as SM120 family for Marlin/CUTLASS FP8
  kernels (generate_kernels.py, ops.cu, scaled_mm*.cuh, marlin_utils.py)
- PR vllm-project#35675: Fix Qwen3.5 MTP fc layer weight shape mismatch with NVFP4
  by using ReplicatedLinear with quant_config=None
- PR vllm-project#35833: FP8 KV cache for Triton MLA decode on Blackwell — adds
  on-the-fly FP8 dequantization in Triton kernels
- PR vllm-project#35936: tool_choice="required" falls back to tool_parser for
  non-JSON (XML) tool calls from Qwen3 models

Local patches:
- Patch FlashInfer TRTLLM JIT to compile for SM12x
  (supported_major_versions=[10] → [10, 12])
- Skip VLLM_TEST_FORCE_FP8_MARLIN for NVFP4 MoE (not SM121-ready)
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 5, 2026
(Same commit message as the Mar 4, 2026 commit above.)

mergify Bot commented Mar 9, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ricky-chaoju.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify Bot added the needs-rebase label Mar 9, 2026
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 18, 2026
(Same commit message as the Mar 4, 2026 commit above.)
mergify Bot removed the needs-rebase label Apr 23, 2026

mergify Bot commented Apr 23, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ricky-chaoju.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify Bot added the needs-rebase label Apr 23, 2026
@ricky-chaoju

Contributor Author

Closing since this has been superseded by #34597, with follow-up FP8 MLA scale fixes in #37054.


Labels

bug, documentation, needs-rebase, v1


Development

Successfully merging this pull request may close these issues.

[Feature] TRITON_MLA: support FP8 KV cache (needed for SM12.0 / Blackwell)

1 participant