
[Feature] Add FP4 KV cache support for SM120 GPUs#21601

Open
samuellees wants to merge 6 commits into sgl-project:main from samuellees:nvfp4-kvcache-sm120

Conversation

samuellees (Contributor) commented on Mar 28, 2026

Refactoring Roadmap

This PR is being split into 4 smaller PRs for easier review:

| PR | Scope | Status |
|---|---|---|
| [1/4] Quantization strategy abstraction + kernels | kv_cache_quant_method.py, kvfp4_tensor.py | samuellees/sglang#nvfp4-kv-pr-4-1 |
| [2/4] Memory pool refactoring | memory_pool.py, model_runner_kv_cache_mixin.py | Pending |
| [3/4] Attention backend integration | flashinfer_backend.py, trtllm_mha_backend.py | Pending |
| [4/4] MTP support + config + docs | hybrid backends, eagle_worker.py, server_args.py, docs | Pending |

Dependency chain: PR1 → PR2 → PR3 → PR4. Each PR is self-contained and does not break existing BF16/FP8 functionality.

See roadmap for details.


Rebased from #18314

  • Enable NVFP4 KV Cache for SM120
  • Extract KVCacheQuantMethod ABC with NoneMethod/NVFP4Method/MXFP4Method subclasses
  • Migrate quantize kernel to flashinfer fp4_quantize, keep custom CUDA dequant
  • Add MTP (speculative decoding) support for NVFP4/FP8 KV cache
  • XQA decode kernel integration with proper scale factor handling
  • Centralize FP4 buffer creation, quantize/dequant, and scale management

Summary

Add NVFP4 (FP4 E2M1) KV cache quantization support for Blackwell GPUs, reducing KV cache memory by ~2x compared to FP8 with no accuracy loss on GSM8K.
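
For intuition on the "~2x" figure, here is a rough per-element accounting (a minimal sketch; the 16-value scaling block is the usual NVFP4 convention and is an assumption here, and the per-tensor FP32 global scale is ignored since it amortizes to almost nothing):

```python
# Approximate KV cache bits per element (illustrative; assumes 16-element NVFP4 scaling blocks).
fp8_bits = 8.0                      # FP8 E4M3: one byte per element
nvfp4_bits = 4.0 + 8.0 / 16         # 4-bit E2M1 value + one FP8 E4M3 block scale shared by 16 values
print(fp8_bits / nvfp4_bits)        # ~1.78, i.e. roughly 2x smaller than FP8
```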

Key changes

  • Strategy pattern for KV cache quantization: Introduce a KVCacheQuantMethod ABC with NoneMethod,
    NVFP4Method, and MXFP4Method subclasses (kv_cache_quant_method.py). Adding a new FP4 scheme only requires
    implementing one subclass and registering it (see the sketch after this list).
  • NVFP4 two-level scaling: Per-tensor FP32 global scale + per-block FP8 E4M3 scale factors, stored
    alongside FP4 KV data.
  • Kernel dispatch:
    • Prefill: FlashInfer dequantizes FP4→FP8, then runs standard FP8 prefill kernel
    • Decode: TRT-LLM XQA kernel reads FP4 natively with two-level scales
  • Quantize via flashinfer fp4_quantize: Replaces custom JIT CUDA kernel with flashinfer's optimized
    implementation.
  • MTP (Multi-Token Prediction) support: target_verify / draft_extend route through XQA decode kernel
    with causal masking (--speculative-attention-mode decode).
  • Hybrid model support: Mamba state update works correctly under speculative attention mode for hybrid
    models (e.g., Qwen3.5-35B-A3B).
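
As referenced in the first bullet, here is a minimal sketch of what such a strategy interface could look like. The class names (KVCacheQuantMethod, NoneMethod, NVFP4Method) come from this PR; the method names, signatures, and buffer layout below are illustrative assumptions, not the PR's actual API:

```python
from abc import ABC, abstractmethod

import torch


class KVCacheQuantMethod(ABC):
    """Strategy interface for KV cache quantization (illustrative sketch only)."""

    @abstractmethod
    def create_buffers(self, shape, device):
        """Allocate the per-layer KV data buffer plus any scale buffers."""

    @abstractmethod
    def quantize(self, kv, buffers):
        """Quantize incoming K/V tokens into the pool buffers."""


class NoneMethod(KVCacheQuantMethod):
    """BF16/FP16 path: store values as-is, no scale buffers."""

    def create_buffers(self, shape, device):
        return {"data": torch.empty(shape, dtype=torch.bfloat16, device=device)}

    def quantize(self, kv, buffers):
        buffers["data"].copy_(kv)


class NVFP4Method(KVCacheQuantMethod):
    """FP4 E2M1 data (two values packed per byte) with two-level scaling:
    per-block FP8 E4M3 scale factors plus a per-tensor FP32 global scale."""

    BLOCK = 16  # assumed scaling-block size

    def create_buffers(self, shape, device):
        *lead, head_dim = shape
        return {
            "data": torch.empty((*lead, head_dim // 2), dtype=torch.uint8, device=device),
            "block_scales": torch.empty((*lead, head_dim // self.BLOCK), dtype=torch.uint8, device=device),
            "global_scale": torch.ones((), dtype=torch.float32, device=device),
        }

    def quantize(self, kv, buffers):
        # The real PR delegates this to an FP4 quantize kernel (flashinfer fp4_quantize);
        # the kernel call is omitted in this sketch.
        raise NotImplementedError
```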

Usage

```bash
python3 -m sglang.launch_server \
    --model-path <model_path> \
    --kv-cache-dtype fp4_e2m1 \
    --prefill-attention-backend flashinfer \
    --decode-attention-backend trtllm_mha \
    --disable-radix-cache
```

Benchmark (Qwen3.5-35B-A3B, GSM8K 100q, SM120)

| KV Cache | MTP | Accuracy |
|---|---|---|
| FP8 (fp8_e4m3) | Yes | 96.6% |
| FP4 (fp4_e2m1) | Yes | 97.1% |

Throughput

| Metric | FP8 KV Cache | NVFP4 KV Cache | Speedup (NVFP4 vs FP8) |
|---|---|---|---|
| Prefill Latency (160K) | 8757 ms | 8792 ms | 0.996x |
| Prefill Latency (1M) | 142143 ms | 142325 ms | 0.998x |
| Decode Latency (1M) | 8.4 ms | 7.1 ms | 1.18x |

Requirements

  • Blackwell GPU (SM120)
  • CUDA 13.0+, PyTorch 2.9.1+
  • FlashInfer >= 0.6.3 (built from source)

Changed files

| File | Change |
|---|---|
| kv_cache_quant_method.py | New — Strategy pattern ABC + NVFP4/MXFP4 subclasses |
| kvfp4_tensor.py | FP4 quantize/dequantize kernels, flashinfer wrapper |
| trtllm_mha_backend.py | XQA decode for FP4, MTP target_verify/draft_extend with causal mask |
| flashinfer_backend.py | NVFP4 dequant state init, FP4→FP8 prefill path |
| memory_pool.py | Pool integration with quant_method, scale buffer management |
| model_runner_kv_cache_mixin.py | NVFP4Method creation and scale loading |
| attention_registry.py | Split prefill/decode backend validation for Blackwell |
| docs/nvfp4_kv_cache.md | New — Documentation with usage, benchmarks, architecture |

Test plan

  • GSM8K accuracy check (1319 questions, FP4 vs FP8 baseline)
  • MTP + FP4 end-to-end serving test
  • Verify no regression on FP8 KV cache path
  • Verify BF16 KV cache path unaffected (NoneMethod)

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

gemini-code-assist (Bot) left a comment


Code Review

This pull request implements NVFP4 KV cache support for Blackwell GPUs (SM100/SM120), featuring a two-level scaling scheme to reduce memory overhead. The implementation includes a new quantization strategy pattern, FlashInfer and TRT-LLM XQA kernel integration, and memory pool updates. Review feedback identifies a critical regression in backend dispatching, potential silent failures in CUDA kernels on incompatible hardware, and several code quality improvements, including the removal of magic numbers and outdated TODO comments.

Comment on lines +438 to +439:

```python
return triton_w8a8_block_fp8_linear
```

critical

The function _dispatch_auto_backend now unconditionally returns triton_w8a8_block_fp8_linear, making the subsequent backend selection logic for DeepGEMM, FlashInfer, etc., unreachable. This appears to be a temporary change for debugging and will break the intended automatic backend dispatching. This should be removed.
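
For illustration only, the shape of the flagged problem in a hypothetical, simplified dispatcher (not the actual SGLang code):

```python
def _dispatch_auto_backend_simplified(deepgemm_ok: bool, flashinfer_ok: bool) -> str:
    # Leftover debug override: returns unconditionally, so every branch below is dead code.
    return "triton_w8a8_block_fp8_linear"

    if deepgemm_ok:        # unreachable
        return "deepgemm"
    if flashinfer_ok:      # unreachable
        return "flashinfer"
```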

Comment on lines +562 to +568:

```cpp
#if HAS_FP8_SUPPORT
const float scale_0 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_0));
const float scale_1 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_1));
#else
const float scale_0 = 1.0f;
const float scale_1 = 1.0f;
#endif
```

high

The CUDA kernel nvfp4_dequant_vectorized_kernel has a fallback for when HAS_FP8_SUPPORT is false, which sets the scales to 1.0f. This will produce incorrect dequantization results silently on hardware that doesn't support FP8, as it ignores the block scales. The kernel should instead fail with an error or there should be a compile-time assertion if FP8 support is required but not available.

Suggested change:

```cpp
// Current:
#if HAS_FP8_SUPPORT
const float scale_0 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_0));
const float scale_1 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_1));
#else
const float scale_0 = 1.0f;
const float scale_1 = 1.0f;
#endif

// Suggested:
#if !HAS_FP8_SUPPORT
#error "This kernel requires FP8 support, which is not available on this architecture."
#endif
const float scale_0 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_0));
const float scale_1 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_1));
```

Comment on lines +771 to +773:

```cpp
#else
return 0;
#endif
```

high

The fp32_vec_to_e2m1 device function returns 0 if __CUDA_ARCH__ < 1000. This will lead to incorrect quantization (all zeros) on older architectures without any warning. Since this function is critical for the SM100+ quantization kernel, it should assert that the architecture is supported.

Suggested change:

```cpp
// Current:
#else
return 0;
#endif

// Suggested:
#else
static_assert(__CUDA_ARCH__ >= 1000, "This function requires SM100 or newer architecture.");
return 0;
#endif
```

Comment thread on docs/nvfp4_kv_cache.md:

NVFP4 KV cache support is on the `nvfp4-kvcache-sm120-v2` branch. Clone from the fork and install in editable mode:

```bash
git clone -b nvfp4-kvcache-sm120-v2 https://github.com/samuellees/sglang.git
```

medium

The git clone command points to a personal fork (github.com/samuellees/sglang.git). For official documentation, this should be updated to point to the main project repository's branch before this pull request is merged.

Suggested change:

```bash
# Current:
git clone -b nvfp4-kvcache-sm120-v2 https://github.com/samuellees/sglang.git
# Suggested:
git clone -b nvfp4-kvcache-sm120-v2 https://github.com/sgl-project/sglang.git
```

```python
if sum(paged_seq_lens_cpu) > 0:
    # [prefix_len, 256] -> [padded_prefix_len, 256] -> sum_tokens -> token_indices[page_size, ..., padde_prefix_len + 256 + page_size]
    paged_seq_lens_cpu.append(256)
    import numpy as np
```

medium

The import numpy as np statement is inside the _prepare_nvfp4_metadata_for_extend_base method. According to PEP 8, imports should be at the top of the file. This also avoids repeated import overhead if the method is called multiple times.
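
A minimal sketch of the suggested cleanup (the function body is trimmed; only the import placement matters here):

```python
# At module top level, next to the other imports:
import numpy as np


def _prepare_nvfp4_metadata_for_extend_base(paged_seq_lens_cpu):
    # np is resolved once at import time instead of on every call.
    if sum(paged_seq_lens_cpu) > 0:
        paged_seq_lens_cpu.append(256)
        ...
```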


```python
# logger.debug(f"[KERNEL DEBUG] ====== End of Parameter Dump ======")

# TODO(Sam): NVFP4 kv cache is not supported or MTP. Because draft extend will invoke this api, it needs nvfp4 kv cache support.
```

medium

This TODO comment appears to be outdated or confusing. It states that NVFP4 KV cache is not supported for MTP, but this pull request seems to add this support by routing draft_extend and target_verify through the XQA decode kernel. Please remove or update this comment to reflect the current implementation.

Comment on lines +213 to +215:

```python
if self.sm_version == 100:
    k_scale *= 6.0
    v_scale *= 6.0
```

medium

The code applies a hardcoded scaling factor of 6.0 for SM100 GPUs. While the comment explains this is for hardware alignment, such "magic numbers" are hard to maintain. It would be better to define this as a named constant with a more detailed explanation, and ideally a reference to NVIDIA's documentation if available.

Suggested change:

```python
# Current:
if self.sm_version == 100:
    k_scale *= 6.0
    v_scale *= 6.0

# Suggested:
# SM100 requires a 6x adjustment to align FP4 range with hardware expectations.
# See [link to NVIDIA doc or further explanation if possible].
SM100_FP4_SCALE_ADJUSTMENT = 6.0
if self.sm_version == 100:
    k_scale *= SM100_FP4_SCALE_ADJUSTMENT
    v_scale *= SM100_FP4_SCALE_ADJUSTMENT
```

samuellees mentioned this pull request on Mar 28, 2026.
- Enable NVFP4 KV Cache for SM100 (B200) and SM120 (RTX PRO 6000)
- Extract KVCacheQuantMethod ABC with NoneMethod/NVFP4Method/MXFP4Method subclasses
- Migrate quantize kernel to flashinfer fp4_quantize, keep custom CUDA dequant
- Add MTP (speculative decoding) support for NVFP4/FP8 KV cache
- XQA decode kernel integration with proper scale factor handling
- Centralize FP4 buffer creation, quantize/dequant, and scale management
samuellees force-pushed the nvfp4-kvcache-sm120 branch from a5dd741 to 4473ed7 on March 28, 2026 13:14
samuellees (Contributor, Author) left a comment

Review for the first round

When page_size > 1 (forced to 64 by trtllm_mha decode backend), the
dq_page_table used page-aligned lengths for kv_indptr, causing flashinfer's
causal offset to be page_align(seq_len) - q_len instead of seq_len - q_len.
This leaked up to page_size-1 future/padding tokens into each query's
attention, degrading GSM8K 1319q accuracy from ~91% to ~85-87%.

Fix: build dq_page_table indices that skip padding gaps, and use actual
(non-padded) seq_lens for dq_paged_kernel_lens so kv_indptr reflects exact
token counts.

Verified: GSM8K 1319q with chunk=2048 now achieves 90.1% (was 85-87%).
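
A minimal numeric illustration of the failure mode described above (the example values are hypothetical; the point is how building kv_indptr from page-aligned lengths inflates the causal offset):

```python
def page_align(n: int, page_size: int) -> int:
    return ((n + page_size - 1) // page_size) * page_size

page_size, seq_len, q_len = 64, 130, 8

# Intended causal offset derived from kv_indptr: seq_len - q_len
correct_offset = seq_len - q_len                        # 122
# Offset when kv_indptr was built from page-aligned lengths:
buggy_offset = page_align(seq_len, page_size) - q_len   # 184

# Up to page_size - 1 padding/future positions become visible to each query token.
print(buggy_offset - correct_offset)                     # 62
```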
Update comment to reflect actual reason: piecewise capture invokes
trtllm_batch_decode_with_kv_cache with kv_cache_sf which needs
flashinfer >= 0.6.7, not yet available via pip.
Resolved conflicts in:
- flashinfer_backend.py: merged NVFP4 dq_page_table with piecewise extra_kv allocation
- trtllm_mha_backend.py: added skip_softmax_threshold_scale_factor, kept NVFP4 paths
- memory_pool.py: kept both quant_method and start_layer params in HybridLinearKVPool
- model_runner_kv_cache_mixin.py: kept both params in HybridLinearKVPool construction
…esults

Remove the NVFP4 piecewise CUDA graph disable rule from server_args.py.
Piecewise CUDA graph is now confirmed working with NVFP4 on SM120.
Update docs/nvfp4_kv_cache.md with GSM8K/GPQA/LongBenchV2/AIME25 results.
voipmonitor added a commit to voipmonitor/sglang that referenced this pull request Apr 1, 2026
voipmonitor pushed a commit to voipmonitor/sglang that referenced this pull request Apr 12, 2026
voipmonitor pushed a commit to voipmonitor/sglang that referenced this pull request Apr 13, 2026
voipmonitor pushed a commit to voipmonitor/sglang that referenced this pull request Apr 13, 2026

Labels

blackwell (SM100/SM120), documentation (Improvements or additions to documentation), quant (LLM Quantization)
