Problem
The TRITON_MLA backend raises NotImplementedError when FP8 KV cache is requested:
```python
# triton_mla.py, __init__
if is_quantized_kv_cache(self.kv_cache_dtype):
    raise NotImplementedError("TritonMLA V1 with FP8 KV cache not yet supported")

# triton_mla.py, forward_mqa
if self.kv_cache_dtype.startswith("fp8"):
    raise NotImplementedError("FP8 Triton MLA not yet supported")
```
On SM12.0 (consumer Blackwell, e.g. RTX 5080/5090), TRITON_MLA is the only available MLA backend for these models:
- FLASHINFER_MLA requires qk_nope_head_dim=128, but models like GLM-4.7-Flash-REAP have qk_nope_head_dim=192
- CUTLASS_MLA requires SM10.x (datacenter Blackwell such as B100/B200)
This means FP8 KV cache is completely blocked on SM12.0 for these models, even though the base class MLACommonImpl already handles FP8 cache writes via concat_and_cache_mla (with _k_scale quantization in the CUDA kernel) and FlashMLA supports FP8 on other architectures.
Impact
For VRAM-constrained GPUs (e.g. RTX 5080 with 16 GB), FP8 KV cache more than doubles the usable context length. Tested with GLM-4.7-Flash-REAP-23B-A3B NVFP4:
- BF16 KV cache: ~4,928 tokens max context
- FP8 KV cache: ~11,728 tokens max context (2.38×)
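The gain comes from per-token cache size. The back-of-the-envelope sketch below assumes the usual MLA cache layout: one compressed latent (kv_c, kv_lora_rank wide) plus one RoPE key (k_pe, qk_rope_head_dim wide) per token per layer, shared by all heads. The dimensions and the 1 GiB budget are illustrative placeholders, not GLM-4.7-Flash-REAP's real config, and the helper names are hypothetical.

```python
def mla_kv_cache_bytes_per_token(
    num_layers: int, kv_lora_rank: int, qk_rope_head_dim: int, bytes_per_elem: int
) -> int:
    # One [kv_c | k_pe] row per token per layer; MLA shares it across all heads.
    return num_layers * (kv_lora_rank + qk_rope_head_dim) * bytes_per_elem


def max_cached_tokens(cache_budget_bytes: int, bytes_per_token: int) -> int:
    # Whole tokens that fit in a fixed KV cache budget.
    return cache_budget_bytes // bytes_per_token


# Placeholder dimensions for illustration only (NOT the real model config).
LAYERS, KV_LORA_RANK, ROPE_DIM = 48, 512, 64
BF16_BYTES, FP8_BYTES = 2, 1

if __name__ == "__main__":
    budget = 1 << 30  # hypothetical 1 GiB left over for the KV cache
    bf16 = mla_kv_cache_bytes_per_token(LAYERS, KV_LORA_RANK, ROPE_DIM, BF16_BYTES)
    fp8 = mla_kv_cache_bytes_per_token(LAYERS, KV_LORA_RANK, ROPE_DIM, FP8_BYTES)
    print(bf16, fp8, max_cached_tokens(budget, bf16), max_cached_tokens(budget, fp8))
```

Halving the bytes per element halves the bytes per token, so a fixed budget holds twice as many tokens.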
Workaround (not suitable for upstream)
I have a working patch that pre-dequantizes FP8 cache and FP8 query to BF16 before calling the existing Triton decode kernel:
```python
# In forward_mqa, before the Triton kernel:
if self.kv_cache_dtype.startswith("fp8"):
    k_scale = layer._k_scale.float()
    kv_c_and_k_pe_cache = (kv_c_and_k_pe_cache.float() * k_scale).to(torch.bfloat16)
    # Query may also be FP8 (from _decode_concat_quant_fp8_op):
    if q.dtype.is_floating_point and q.element_size() == 1:
        q_scale = layer._q_scale.float()
        q = (q.float() * q_scale).to(torch.bfloat16)
```
This works but is suboptimal because:
- Full cache dequantization on every decode step: the entire KV cache tensor passed to forward_mqa is cast FP8 → float32 → BF16, materializing temporary float32 and BF16 copies, so the cost grows with the size of the cache rather than with the work the attention kernel actually does
- A float32 intermediate is required: PyTorch (2.10.0) doesn't support .to() or arithmetic directly on Float8_e4m3fn tensors
- Query FP8 detection is fragile: it uses an element_size() == 1 heuristic instead of an explicit contract
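For reference, the per-element dequantization itself is cheap: an FP8 E4M3FN byte (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities) decodes to a float with a few bit operations and one multiply by the scale, which is what an on-the-fly kernel would do at load time. This is a pure-Python sketch of the format for illustration; decode_e4m3fn and dequantize are hypothetical helpers, not vLLM or PyTorch APIs.

```python
import math
from typing import Iterable, List


def decode_e4m3fn(byte: int) -> float:
    """Decode one FP8 E4M3FN byte to a Python float."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("expected a single byte")
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0x0F
    mantissa = byte & 0x07
    if exponent == 0x0F and mantissa == 0x07:
        return math.nan  # E4M3FN has no inf; only S.1111.111 encodes NaN
    if exponent == 0:
        return sign * (mantissa / 8.0) * 2.0 ** -6  # subnormal range
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)  # normal range


def dequantize(fp8_bytes: Iterable[int], scale: float) -> List[float]:
    """Decode each stored byte, then apply the per-tensor scale."""
    return [decode_e4m3fn(b) * scale for b in fp8_bytes]
```

The largest finite value is 0x7E = 448.0, which is why a per-tensor scale such as _k_scale is needed to map real activations into range when writing the cache.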
Suggested proper implementation
The ideal fix would be to handle FP8 inside the Triton decode kernel (decode_attention_fwd), similar to how FlashMLA handles it. The kernel already receives the cache tensor — it could accept a k_scale parameter and dequantize on-the-fly during the attention computation, avoiding the extra memory copy.
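If _k_scale and _q_scale are single per-tensor values (an assumption, consistent with the workaround multiplying whole tensors by them), the kernel does not even need to rescale every element: q·(s_k·k) = s_k·(q·k), so the scales can be folded into the softmax scale, and because MLA's V is the kv_c slice of the same cache row, the output only needs one final multiply by s_k. The pure-Python reference below (not Triton; all function names hypothetical) checks that folding matches the pre-dequantize workaround.

```python
import math
from typing import List, Sequence

Row = Sequence[float]


def _softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)  # subtract max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def mla_decode_ref(q: Row, cache: Sequence[Row], sm_scale: float, v_dim: int) -> List[float]:
    """Single-head decode over rows of [kv_c | k_pe].

    Scores use the full row; the output is the probability-weighted sum
    of the kv_c slice (the first v_dim entries of each row).
    """
    scores = [sm_scale * sum(qi * ki for qi, ki in zip(q, row)) for row in cache]
    probs = _softmax(scores)
    return [sum(p * row[j] for p, row in zip(probs, cache)) for j in range(v_dim)]


def decode_predequant(q_q, cache_q, q_scale, k_scale, sm_scale, v_dim):
    """Mirrors the workaround: build dequantized copies, then attend."""
    q = [x * q_scale for x in q_q]
    cache = [[x * k_scale for x in row] for row in cache_q]
    return mla_decode_ref(q, cache, sm_scale, v_dim)


def decode_folded(q_q, cache_q, q_scale, k_scale, sm_scale, v_dim):
    """No dequantized copy: fold scales into the scores and rescale the output."""
    out = mla_decode_ref(q_q, cache_q, sm_scale * q_scale * k_scale, v_dim)
    return [k_scale * o for o in out]
```

A real kernel would still convert each loaded FP8 value to float32 in registers; the folding only removes the per-element multiplies and any extra memory traffic. With per-channel or per-block scales it would not apply as written.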
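Alternatively, a fused dequant kernel could convert the needed cache pages into a BF16 buffer before the attention kernel (a literal in-place conversion does not fit, since BF16 needs twice the storage of FP8), which would still be much cheaper than the Python-level .float() * scale approach.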
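One way such a pre-pass could bound its work is to dequantize only the blocks referenced by the current batch's block tables into a compact scratch buffer and remap the tables to point into it. The pure-Python sketch below shows only that bookkeeping; gather_and_dequant and the list-of-blocks layout are hypothetical, not vLLM's actual cache layout or API.

```python
from typing import List, Sequence, Tuple

Block = List[List[float]]  # block_size rows of stored (quantized) values


def gather_and_dequant(
    cache_q: Sequence[Block],
    block_tables: Sequence[Sequence[int]],
    k_scale: float,
) -> Tuple[List[Block], List[List[int]]]:
    """Dequantize only referenced blocks; return scratch buffer + remapped tables."""
    used = sorted({b for table in block_tables for b in table})
    remap = {old: new for new, old in enumerate(used)}
    scratch = [[[x * k_scale for x in row] for row in cache_q[b]] for b in used]
    new_tables = [[remap[b] for b in table] for table in block_tables]
    return scratch, new_tables
```

The attention kernel would then read the scratch buffer through the remapped tables, so the conversion cost tracks the blocks the batch actually touches rather than the whole allocated cache.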
Environment
- GPU: RTX 5080 (SM12.0, 16 GB)
- vLLM: v0.16.1rc1.dev34 (commit 6283021)
- PyTorch: 2.10.0+cu128
- Model: GLM-4.7-Flash-REAP-23B-A3B quantized to NVFP4 (modelopt 0.41.0)
- Backend: TRITON_MLA (only available option on SM12.0 for this model)