[Feature] Add FP4 KV cache support for SM120 GPUs#21601
[Feature] Add FP4 KV cache support for SM120 GPUs#21601samuellees wants to merge 6 commits intosgl-project:mainfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements NVFP4 KV cache support for Blackwell GPUs (SM100/SM120), featuring a two-level scaling scheme to reduce memory overhead. The implementation includes a new quantization strategy pattern, FlashInfer and TRT-LLM XQA kernel integration, and memory pool updates. Review feedback identifies a critical regression in backend dispatching, potential silent failures in CUDA kernels on incompatible hardware, and several code quality improvements, including the removal of magic numbers and outdated TODO comments.
| return triton_w8a8_block_fp8_linear | ||
|
|
There was a problem hiding this comment.
The function _dispatch_auto_backend now unconditionally returns triton_w8a8_block_fp8_linear, making the subsequent backend selection logic for DeepGEMM, FlashInfer, etc., unreachable. This appears to be a temporary change for debugging and will break the intended automatic backend dispatching. This should be removed.
| #if HAS_FP8_SUPPORT | ||
| const float scale_0 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_0)); | ||
| const float scale_1 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_1)); | ||
| #else | ||
| const float scale_0 = 1.0f; | ||
| const float scale_1 = 1.0f; | ||
| #endif |
There was a problem hiding this comment.
The CUDA kernel nvfp4_dequant_vectorized_kernel has a fallback for when HAS_FP8_SUPPORT is false, which sets the scales to 1.0f. This will produce incorrect dequantization results silently on hardware that doesn't support FP8, as it ignores the block scales. The kernel should instead fail with an error or there should be a compile-time assertion if FP8 support is required but not available.
| #if HAS_FP8_SUPPORT | |
| const float scale_0 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_0)); | |
| const float scale_1 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_1)); | |
| #else | |
| const float scale_0 = 1.0f; | |
| const float scale_1 = 1.0f; | |
| #endif | |
| #if !HAS_FP8_SUPPORT | |
| #error "This kernel requires FP8 support, which is not available on this architecture." | |
| #endif | |
| const float scale_0 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_0)); | |
| const float scale_1 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_1)); |
| #else | ||
| return 0; | ||
| #endif |
There was a problem hiding this comment.
The fp32_vec_to_e2m1 device function returns 0 if __CUDA_ARCH__ < 1000. This will lead to incorrect quantization (all zeros) on older architectures without any warning. Since this function is critical for the SM100+ quantization kernel, it should assert that the architecture is supported.
| #else | |
| return 0; | |
| #endif | |
| #else | |
| static_assert(__CUDA_ARCH__ >= 1000, "This function requires SM100 or newer architecture."); | |
| return 0; | |
| #endif |
| NVFP4 KV cache support is on the `nvfp4-kvcache-sm120-v2` branch. Clone from the fork and install in editable mode: | ||
|
|
||
| ```bash | ||
| git clone -b nvfp4-kvcache-sm120-v2 https://github.com/samuellees/sglang.git |
There was a problem hiding this comment.
The git clone command points to a personal fork (github.com/samuellees/sglang.git). For official documentation, this should be updated to point to the main project repository's branch before this pull request is merged.
| git clone -b nvfp4-kvcache-sm120-v2 https://github.com/samuellees/sglang.git | |
| git clone -b nvfp4-kvcache-sm120-v2 https://github.com/sgl-project/sglang.git |
| if sum(paged_seq_lens_cpu) > 0: | ||
| # [prefix_len, 256] -> [padded_prefix_len, 256] -> sum_tokens -> token_indices[page_size, ..., padde_prefix_len + 256 + page_size] | ||
| paged_seq_lens_cpu.append(256) | ||
| import numpy as np |
|
|
||
| # logger.debug(f"[KERNEL DEBUG] ====== End of Parameter Dump ======") | ||
|
|
||
| # TODO(Sam): NVFP4 kv cache is not supported or MTP. Because draft extend will invoke this api, it needs nvfp4 kv cache support. |
There was a problem hiding this comment.
This TODO comment appears to be outdated or confusing. It states that NVFP4 KV cache is not supported for MTP, but this pull request seems to add this support by routing draft_extend and target_verify through the XQA decode kernel. Please remove or update this comment to reflect the current implementation.
| if self.sm_version == 100: | ||
| k_scale *= 6.0 | ||
| v_scale *= 6.0 |
There was a problem hiding this comment.
The code applies a hardcoded scaling factor of 6.0 for SM100 GPUs. While the comment explains this is for hardware alignment, such "magic numbers" are hard to maintain. It would be better to define this as a named constant with a more detailed explanation, and ideally a reference to NVIDIA's documentation if available.
| if self.sm_version == 100: | |
| k_scale *= 6.0 | |
| v_scale *= 6.0 | |
| # SM100 requires a 6x adjustment to align FP4 range with hardware expectations. | |
| # See [link to NVIDIA doc or further explanation if possible]. | |
| SM100_FP4_SCALE_ADJUSTMENT = 6.0 | |
| if self.sm_version == 100: | |
| k_scale *= SM100_FP4_SCALE_ADJUSTMENT | |
| v_scale *= SM100_FP4_SCALE_ADJUSTMENT |
- Enable NVFP4 KV Cache for SM100 (B200) and SM120 (RTX PRO 6000) - Extract KVCacheQuantMethod ABC with NoneMethod/NVFP4Method/MXFP4Method subclasses - Migrate quantize kernel to flashinfer fp4_quantize, keep custom CUDA dequant - Add MTP (speculative decoding) support for NVFP4/FP8 KV cache - XQA decode kernel integration with proper scale factor handling - Centralize FP4 buffer creation, quantize/dequant, and scale management
a5dd741 to
4473ed7
Compare
samuellees
left a comment
There was a problem hiding this comment.
Review for the first round
When page_size > 1 (forced to 64 by trtllm_mha decode backend), the dq_page_table used page-aligned lengths for kv_indptr, causing flashinfer's causal offset to be page_align(seq_len) - q_len instead of seq_len - q_len. This leaked up to page_size-1 future/padding tokens into each query's attention, degrading GSM8K 1319q accuracy from ~91% to ~85-87%. Fix: build dq_page_table indices that skip padding gaps, and use actual (non-padded) seq_lens for dq_paged_kernel_lens so kv_indptr reflects exact token counts. Verified: GSM8K 1319q with chunk=2048 now achieves 90.1% (was 85-87%).
Update comment to reflect actual reason: piecewise capture invokes trtllm_batch_decode_with_kv_cache with kv_cache_sf which needs flashinfer >= 0.6.7, not yet available via pip.
Resolved conflicts in: - flashinfer_backend.py: merged NVFP4 dq_page_table with piecewise extra_kv allocation - trtllm_mha_backend.py: added skip_softmax_threshold_scale_factor, kept NVFP4 paths - memory_pool.py: kept both quant_method and start_layer params in HybridLinearKVPool - model_runner_kv_cache_mixin.py: kept both params in HybridLinearKVPool construction
…esults Remove the NVFP4 piecewise CUDA graph disable rule from server_args.py. Piecewise CUDA graph is now confirmed working with NVFP4 on SM120. Update docs/nvfp4_kv_cache.md with GSM8K/GPQA/LongBenchV2/AIME25 results.
…he_mixin.py from PR branch
Refactoring Roadmap
This PR is being split into 4 smaller PRs for easier review:
kv_cache_quant_method.py,kvfp4_tensor.pymemory_pool.py,model_runner_kv_cache_mixin.pyflashinfer_backend.py,trtllm_mha_backend.pyeagle_worker.py,server_args.py, docsDependency chain: PR1 → PR2 → PR3 → PR4. Each PR is self-contained and does not break existing BF16/FP8 functionality.
See roadmap for details.
Rebased from #18314
Summary
Add NVFP4 (FP4 E2M1) KV cache quantization support for Blackwell GPUs, reducing KV cache memory by ~2x compared to FP8 with no accuracy loss on GSM8K.
Key changes
KVCacheQuantMethodABC withNoneMethod,NVFP4Method,MXFP4Methodsubclasses (kv_cache_quant_method.py). Adding a new FP4 scheme only requiresimplementing one subclass and registering it.
alongside FP4 KV data.
fp4_quantize: Replaces custom JIT CUDA kernel with flashinfer's optimizedimplementation.
target_verify/draft_extendroute through XQA decode kernelwith causal masking (
--speculative-attention-mode decode).models (e.g., Qwen3.5-35B-A3B).
Usage
python3 -m sglang.launch_server \ --model-path <model_path> \ --kv-cache-dtype fp4_e2m1 \ --prefill-attention-backend flashinfer \ --decode-attention-backend trtllm_mha \ --disable-radix-cacheBenchmark (Qwen3.5-35B-A3B, GSM8K 100q, SM120)
Throughput
Requirements
Changed files
│ File │ Change │
│ kv_cache_quant_method.py │ New — Strategy pattern ABC + NVFP4/MXFP4 subclasses │
│ kvfp4_tensor.py │ FP4 quantize/dequantize kernels, flashinfer wrapper │
│ trtllm_mha_backend.py │ XQA decode for FP4, MTP target_verify/draft_extend with causal mask │
│ flashinfer_backend.py │ NVFP4 dequant state init, FP4→FP8 prefill path │
│ memory_pool.py │ Pool integration with quant_method, scale buffer management │
│ model_runner_kv_cache_mixin.py │ NVFP4Method creation and scale loading │
│ attention_registry.py │ Split prefill/decode backend validation for Blackwell │
│ docs/nvfp4_kv_cache.md │ New — Documentation with usage, benchmarks, architecture │
Test plan
Checklist
Review and Merge Process
/tag-and-rerun-ci,/tag-run-ci-label,/rerun-failed-ci