[Feature] TRITON_MLA: support FP8 KV cache (needed for SM12.0 / Blackwell) #35577

@lucaspirola

Problem

The TRITON_MLA backend raises NotImplementedError when FP8 KV cache is requested:

# triton_mla.py, __init__
if is_quantized_kv_cache(self.kv_cache_dtype):
    raise NotImplementedError("TritonMLA V1 with FP8 KV cache not yet supported")

# triton_mla.py, forward_mqa
if self.kv_cache_dtype.startswith("fp8"):
    raise NotImplementedError("FP8 Triton MLA not yet supported")

On SM12.0 (consumer Blackwell, e.g. RTX 5080/5090; the datacenter B100/B200 parts are SM10.0), TRITON_MLA is the only available MLA backend for models like this one:

  • FLASHINFER_MLA requires qk_nope_head_dim=128, but models like GLM-4.7-Flash-REAP have qk_nope_head_dim=192
  • CUTLASS_MLA requires SM10.x

This means FP8 KV cache is completely blocked on SM12.0 for these models. The rest of the stack is already in place: the base class MLACommonImpl handles FP8 cache writes via concat_and_cache_mla (with _k_scale quantization in the CUDA kernel), and FlashMLA supports FP8 on other architectures.
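The two constraints listed above can be written down as a small filter. This is only a sketch of the reasoning in this issue, not vLLM's actual backend selection: the function name `candidate_mla_backends` is hypothetical, it encodes only the two requirements stated here, and it assumes TRITON_MLA has no hardware restriction of its own.

```python
def candidate_mla_backends(sm_major: int, qk_nope_head_dim: int) -> list[str]:
    """Return MLA backends that pass the constraints named in this issue.

    Hypothetical helper: real backend selection checks more conditions.
    """
    backends = []
    if sm_major == 10:
        # CUTLASS_MLA requires SM10.x
        backends.append("CUTLASS_MLA")
    if qk_nope_head_dim == 128:
        # FLASHINFER_MLA requires qk_nope_head_dim=128
        backends.append("FLASHINFER_MLA")
    # Assumption: TRITON_MLA is treated as the always-available fallback.
    backends.append("TRITON_MLA")
    return backends


# RTX 5080 (SM12.0) with a model that has qk_nope_head_dim=192
print(candidate_mla_backends(sm_major=12, qk_nope_head_dim=192))
```

For the configuration in this report, only TRITON_MLA survives the filter, which is why its missing FP8 support blocks FP8 KV cache entirely.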

Impact

For VRAM-constrained GPUs (e.g. RTX 5080 with 16 GB), FP8 KV cache more than doubles the usable context length. Tested with GLM-4.7-Flash-REAP-23B-A3B NVFP4:

  • BF16 KV cache: ~4,928 tokens max context
  • FP8 KV cache: ~11,728 tokens max context (2.38×)
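As a quick check of the numbers above, the sketch below recomputes the measured gain and compares it with the gain expected from element size alone. The helper `mla_cache_bytes_per_token` and the dimensions passed to it (512 and 64) are illustrative placeholders, not values taken from GLM-4.7-Flash-REAP; the layout (one latent vector plus a rope part per token per layer) is an assumption.

```python
# Context lengths reported above (tokens).
bf16_ctx = 4_928
fp8_ctx = 11_728
measured_gain = fp8_ctx / bf16_ctx


def mla_cache_bytes_per_token(kv_lora_rank: int, qk_rope_head_dim: int,
                              bytes_per_elem: int, num_layers: int = 1) -> int:
    """Bytes of MLA cache one token occupies (hypothetical layout)."""
    return num_layers * (kv_lora_rank + qk_rope_head_dim) * bytes_per_elem


# Placeholder dimensions, NOT taken from the model in this issue.
bf16_bytes = mla_cache_bytes_per_token(512, 64, bytes_per_elem=2)
fp8_bytes = mla_cache_bytes_per_token(512, 64, bytes_per_elem=1)
dtype_gain = bf16_bytes / fp8_bytes

print(f"measured gain: {measured_gain:.2f}x, "
      f"gain from element size alone: {dtype_gain:.1f}x")
```

Halving the element size accounts for a 2× gain regardless of the dimensions chosen. The sketch does not model where the additional gain in the measured figure comes from.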

Workaround (not suitable for upstream)

I have a working patch that pre-dequantizes FP8 cache and FP8 query to BF16 before calling the existing Triton decode kernel:

# In forward_mqa, before the Triton kernel:
if self.kv_cache_dtype.startswith("fp8"):
    k_scale = layer._k_scale.float()
    kv_c_and_k_pe_cache = (kv_c_and_k_pe_cache.float() * k_scale).to(torch.bfloat16)

# Query may also be FP8 (from _decode_concat_quant_fp8_op):
if q.dtype.is_floating_point and q.element_size() == 1:
    q_scale = layer._q_scale.float()
    q = (q.float() * q_scale).to(torch.bfloat16)

This works but is suboptimal because:

  1. Full cache dequantization on every decode step: the entire KV cache tensor is cast FP8 → float32 → BF16, so the cost grows with the size of the cache and the temporary copies need extra memory
  2. float32 intermediary: PyTorch (2.10.0) doesn't support arithmetic directly on Float8_e4m3fn tensors, so the data has to be upcast before the scale can be applied
  3. Query FP8 detection is fragile: it uses an element_size() == 1 heuristic instead of an explicit contract

Suggested proper implementation

The ideal fix would be to handle FP8 inside the Triton decode kernel (decode_attention_fwd), similar to how FlashMLA handles it. The kernel already receives the cache tensor — it could accept a k_scale parameter and dequantize on-the-fly during the attention computation, avoiding the extra memory copy.
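To illustrate what "dequantize on-the-fly" means, here is a plain-Python model of the idea, not Triton code and not the actual decode_attention_fwd kernel. It decodes each FP8 E4M3FN byte as it is read and applies a per-tensor k_scale once per score, so no BF16 copy of the cache is ever materialized. The function names are hypothetical, and a single per-tensor scale is assumed.

```python
import math


def decode_e4m3fn(byte: int) -> float:
    """Decode one FP8 E4M3FN byte: 1 sign, 4 exponent (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return math.nan  # E4M3FN has no infinities; this pattern is NaN
    if exp == 0:
        return sign * (mant / 8.0) * 2.0 ** -6  # subnormal range
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)


def scores_with_inline_dequant(q, k_cache_bytes, k_scale):
    """Dot q with each cached key row, decoding FP8 bytes as they are read."""
    scores = []
    for row in k_cache_bytes:
        acc = 0.0
        for q_i, b in zip(q, row):
            acc += q_i * decode_e4m3fn(b)
        # With a per-tensor scale, k_scale factors out of the sum.
        scores.append(acc * k_scale)
    return scores


# 0x38 -> 1.0, 0x40 -> 2.0, 0xB8 -> -1.0, 0x30 -> 0.5
cache = [[0x38, 0x40], [0xB8, 0x30]]
print(scores_with_inline_dequant([1.0, 2.0], cache, k_scale=0.5))
```

Because the scale is applied to the accumulated dot product rather than to every cache element, the extra work per decode step is one multiply per score on top of the loads the kernel already performs.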

Alternatively, a fused dequant kernel could convert the needed cache pages to BF16 before the attention kernel runs. This cannot be truly in-place, since BF16 elements are twice the size of FP8 ones, but it would still be much cheaper than the Python-level .float() * scale approach.

Environment

  • GPU: RTX 5080 (SM12.0, 16 GB)
  • vLLM: v0.16.1rc1.dev34 (commit 6283021)
  • PyTorch: 2.10.0+cu128
  • Model: GLM-4.7-Flash-REAP-23B-A3B quantized to NVFP4 (modelopt 0.41.0)
  • Backend: TRITON_MLA (only available option on SM12.0 for this model)
