
[1/4] NVFP4 KV cache: quantization strategy abstraction and kernel #21954

Merged: Fridge003 merged 17 commits into sgl-project:main from samuellees:nvfp4-kv-pr-4-1, Apr 29, 2026

Conversation

samuellees (Contributor) commented Apr 2, 2026

Summary

Part 1 of 4 for NVFP4 KV cache support on SM120 GPUs. Split from #21601 for easier review.

Roadmap: PR1 (this) → PR2 (memory pool) → PR3 (attention backends) → PR4 (MTP + config + docs)

This PR introduces the FP4 KV cache quantization strategy abstraction and quantize/dequantize utilities. Pure additions — no existing code paths are modified.

Changes

  • fp4_kv_cache_quant_method.py (NEW): FP4KVCacheQuantMethod ABC with strategy pattern

    • NVFP4KVMethod — two-level scaling (per-tensor FP32 + per-block FP8 E4M3), SM100/SM120
    • BlockFP4KVMethod — block-wise single-level scaling (similar to MXFP4 but block_size=16)
    • FP4_KV_CACHE_QUANT_REGISTRY — maps recipe names to method classes
    • Methods are prefill-only (not in CUDA graph path); decode reads raw FP4 buffers via XQA kernel
  • kvfp4_tensor.py (extended):

    • NVFP4KVQuantizeUtil — quantize via FlashInfer nvfp4_kv_quantize (fallback fp4_quantize), dequantize via nvfp4_kv_dequantize (fallback PyTorch E2M1 LUT)
    • BlockFP4KVQuantizeUtil — block-wise FP4 with block_size=16 (renamed from KVFP4QuantizeUtil, backward-compat alias kept)
  • Unit tests: Registry, buffer shapes, quantize→dequantize roundtrip for NVFP4 (CUDA) and BlockFP4 (CPU)

Design

FP4KVCacheQuantMethod (pure compute) → Pool (buffer + batch dequant) → Backend (view adaptation)

The quant method owns quantize/dequantize logic. The Pool (PR2) owns buffers and orchestrates batch operations. Backends (PR3) only do view/reshape.
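
To make the layering concrete, here is a minimal sketch of the strategy interface and registry using the names this PR introduces; method signatures and bodies are illustrative only, not the actual diff:

```python
from abc import ABC, abstractmethod

import torch


class FP4KVCacheQuantMethod(ABC):
    """Pure-compute strategy: owns quantize/dequantize, never buffers."""

    @abstractmethod
    def quantize_and_store(self, kv: torch.Tensor, out: torch.Tensor) -> None:
        """Quantize incoming KV and write it into the pool-owned FP4 buffer."""

    @abstractmethod
    def dequantize_prev_kv(self, buf: torch.Tensor) -> torch.Tensor:
        """Dequantize previously stored KV (to FP8) for the prefill kernel."""


class NVFP4KVMethod(FP4KVCacheQuantMethod):
    name = "nvfp4"  # two-level scaling: per-tensor FP32 + per-block FP8 E4M3

    def quantize_and_store(self, kv, out):
        raise NotImplementedError  # delegates to NVFP4KVQuantizeUtil in the PR

    def dequantize_prev_kv(self, buf):
        raise NotImplementedError


class BlockFP4KVMethod(FP4KVCacheQuantMethod):
    name = "blockfp4"  # single-level block scaling, block_size=16

    def quantize_and_store(self, kv, out):
        raise NotImplementedError  # delegates to BlockFP4KVQuantizeUtil

    def dequantize_prev_kv(self, buf):
        raise NotImplementedError


# Registry maps recipe names to method classes, as described above.
FP4_KV_CACHE_QUANT_REGISTRY = {
    NVFP4KVMethod.name: NVFP4KVMethod,
    BlockFP4KVMethod.name: BlockFP4KVMethod,
}


def get_fp4_kv_cache_quant_method(recipe: str) -> FP4KVCacheQuantMethod:
    return FP4_KV_CACHE_QUANT_REGISTRY[recipe]()
```

With this shape, the pool (PR2) and backends (PR3) depend only on the ABC, so adding a new recipe is one subclass plus one registry entry.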

Naming Convention

| Class | Purpose |
|---|---|
| FP4KVCacheQuantMethod | ABC base |
| NVFP4KVMethod | Two-level scaling (global FP32 + block FP8) |
| BlockFP4KVMethod | Single-level block scaling (like MXFP4, block=16) |
| NVFP4KVQuantizeUtil | NVFP4 quantize/dequantize kernels |
| BlockFP4KVQuantizeUtil | Block FP4 quantize/dequantize (alias: KVFP4QuantizeUtil) |

Test plan

  • pytest test/registered/unit/layers/quantization/test_fp4_kv_cache_quant_method.py -v
  • Verify no existing tests broken (pure additions, backward-compat alias for KVFP4QuantizeUtil)
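
For flavor, a roundtrip check along the lines of the registered test might look like the following; NVFP4KVQuantizeUtil.quantize/dequantize exist per this PR, but the argument lists and tolerances here are assumptions, not the test's actual code:

```python
import torch

from sglang.srt.layers.quantization.kvfp4_tensor import NVFP4KVQuantizeUtil

# Hypothetical roundtrip; the real test lives at the pytest path above.
x = torch.randn(2, 8, 64, device="cuda", dtype=torch.bfloat16)
packed, scales = NVFP4KVQuantizeUtil.quantize(x)    # assumed signature
y = NVFP4KVQuantizeUtil.dequantize(packed, scales)  # assumed signature
# FP4 is lossy, so only coarse agreement is expected.
assert torch.allclose(x.float(), y.float(), atol=0.5, rtol=0.25)
```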

NVFP4 KV Cache Performance

Benchmark (Qwen3.5-35B-A3B, GSM8K, SM120)

| KV Cache | MTP | Accuracy |
|---|---|---|
| FP8 (fp8_e4m3) | Yes | 96.6% |
| FP4 (fp4_e2m1) | Yes | 97.1% |

Throughput

| Metric | FP8 KV Cache | NVFP4 KV Cache | Speedup (NVFP4 vs FP8) |
|---|---|---|---|
| Prefill Latency (160K) | 8757 ms | 8792 ms | 0.996x |
| Prefill Latency (1M) | 142143 ms | 142325 ms | 0.998x |
| Decode Latency (1M) | 8.4 ms | 7.1 ms | 1.18x |

Introduce a strategy pattern abstraction for KV cache quantization:
- KVCacheQuantMethod ABC with create_buffers(), quantize_and_store(),
  dequantize_prev_kv(), and scale loading interfaces
- NoneMethod (identity for BF16/FP8), NVFP4Method (two-level scaling
  with per-tensor FP32 + per-block FP8 E4M3), MXFP4Method (single-level)
- KV_CACHE_QUANT_REGISTRY for dtype-to-method mapping

Add NVFP4QuantizeUtil with vectorized CUDA dequantization kernel
(E2M1 LUT-based), FlashInfer fp4_quantize integration, and SM100
quantization kernel. Extend E2M1_VALUES to 16 entries.

This is PR1 of the NVFP4 KV cache refactoring. Pure additions,
no existing code paths are modified.

Tests cover:
- KV_CACHE_QUANT_REGISTRY factory and lookup
- NoneMethod identity behavior
- NVFP4Method buffer shapes, cell size, scale initialization
- NVFP4Method quantize->dequantize roundtrip (CUDA)
- MXFP4Method buffer shapes and roundtrip (CPU)
- KVFP4QuantizeUtil existing MXFP4 roundtrip
- FP4KVCacheRecipe enum values
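
The E2M1-LUT dequantization mentioned in the commit above can be sketched in pure PyTorch as follows; lut_dequantize is a hypothetical helper, and the nibble ordering and scale layout are assumptions rather than the PR's actual kernel:

```python
import torch

# The 16 representable FP4 E2M1 values; the sign bit is the high bit of
# the 4-bit code, so codes 8..15 mirror codes 0..7 with negated sign.
E2M1_VALUES = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)


def lut_dequantize(packed: torch.Tensor, scales: torch.Tensor, block: int = 16) -> torch.Tensor:
    """Hypothetical helper: unpack two FP4 codes per byte, look them up in
    the LUT, and apply one scale per `block` contiguous elements."""
    lut = E2M1_VALUES.to(packed.device)
    lo = (packed & 0x0F).long()  # even-indexed elements (nibble order assumed)
    hi = (packed >> 4).long()    # odd-indexed elements
    vals = torch.stack((lut[lo], lut[hi]), dim=-1).flatten(-2)
    vals = vals.unflatten(-1, (-1, block)) * scales.float().unsqueeze(-1)
    return vals.flatten(-2)
```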
github-actions bot added the quant (LLM Quantization) and blackwell (SM100/SM120) labels on Apr 2, 2026
gemini-code-assist bot left a comment:

Code Review

This pull request introduces a strategy pattern for KV cache quantization, defining a base KVCacheQuantMethod class and specific implementations for NVFP4 and MXFP4 schemes. It adds NVFP4QuantizeUtil containing CUDA-optimized kernels for quantization and dequantization on SM100/SM120 architectures, and includes unit tests for the new functionality. A review comment suggested marking compute_cell_size as an abstract method in the base class to enforce its implementation in subclasses.

```python
    self, head_num: int, head_dim: int, num_layers: int, kv_size: int
) -> int:
    """Per-token memory footprint in bytes (for capacity estimation)."""
    raise NotImplementedError
```
gemini-code-assist bot commented on this hunk (severity: medium):

The method compute_cell_size raises NotImplementedError, which is appropriate for an abstract base class, but it should be marked with @abstractmethod to enforce implementation in subclasses, consistent with the other abstract methods in this class.

Suggested change:

```diff
-        raise NotImplementedError
+    @abstractmethod
+    def compute_cell_size(
+        self, head_num: int, head_dim: int, num_layers: int, kv_size: int
+    ) -> int:
+        """Per-token memory footprint in bytes (for capacity estimation)."""
```

Remove embedded CUDA kernels (cuda_nvfp4_dequantize, cuda_nvfp4_quantize_blackwell)
and PyTorch fallback quantize (batched_quantize). Keep only:
- fi_nvfp4_quantize: delegates to FlashInfer fp4_quantize
- dequantize: pure PyTorch E2M1 LUT lookup (no JIT CUDA compilation)
- Remove instruction comments from module docstring
- Mark compute_cell_size as @abstractmethod, implement in NoneMethod
- Remove GPU names from NVFP4Method docstring
- Rename scale_block_size to SCALE_BLOCK_SIZE (constant naming)
- Add comment explaining max_global_id resize for hybrid models
- Use E2M1_MAX constant instead of hardcoded 6.0 for SM100 scale
- Use flashinfer nvfp4_kv_quantize/nvfp4_kv_dequantize APIs with
  fallback to fp4_quantize and PyTorch dequant for older versions
- Rename fi_nvfp4_quantize -> quantize, keep dequantize
- kv_cache_quant_method.py -> fp4_kv_cache_quant_method.py
- KVCacheQuantMethod -> FP4KVCacheQuantMethod
- KV_CACHE_QUANT_REGISTRY -> FP4_KV_CACHE_QUANT_REGISTRY
- get_kv_cache_quant_method -> get_fp4_kv_cache_quant_method
- Update all test references
- Move test from test/manual/quant/ to test/registered/unit/layers/quantization/
- Add CI registration (register_cpu_ci)
- Use CustomTestCase base class
- Rename test file to match source: test_fp4_kv_cache_quant_method.py
samuellees (author) left a comment:

Do these methods support CUDA graph (if they need to support it)?

- Add docstring note that quant methods are prefill-only, not in CUDA graph path
- Clarify dequantize_prev_kv returns FP8 dtype for FlashInfer prefill kernel
- Rename KVFP4QuantizeUtil → BlockFP4KVQuantizeUtil (block-wise FP4, similar to
  MXFP4 but with block_size=16)
- Rename NVFP4QuantizeUtil → NVFP4KVQuantizeUtil (two-level scaling)
- Keep KVFP4QuantizeUtil as backward-compatible alias
- Update fp4_kv_cache_quant_method.py and tests to use new names
- Remove NoneMethod: BF16/FP8 paths should use None, not a dummy FP4 method
- Rename MXFP4Method → BlockFP4Method with name="blockfp4" to match
  BlockFP4KVQuantizeUtil naming and clarify it's not standard MXFP4
- Update registry key from "mxfp4" to "blockfp4"
- Remove NoneMethod tests

Note: These methods are called during prefill (extend) only, not during
CUDA graph capture or decode. Decode reads raw FP4 buffers directly via
the XQA kernel. Therefore CUDA graph compatibility is not required here.
A collaborator commented:

Question: won't this be incompatible with piecewise CUDA graph then?

b8zhong (Collaborator) commented Apr 11, 2026

/tag-and-rerun-ci

- Fix misleading CUDA-graph note: operations ARE CUDA-graph compatible
  (FlashInfer kernels + pure tensor ops), not prefill-only as stated
- Improve SM100 E2M1_MAX comment: explain checkpoint convention difference
  (not a hardware difference — FP4 data type is identical on SM100/SM120)
- Remove backward-compat alias KVFP4QuantizeUtil (not public API)
- Simplify NVFP4KVQuantizeUtil: use fp4_quantize directly instead of
  try/except fallback chain with nvfp4_kv_quantize/nvfp4_kv_dequantize
- Dequantize uses pure PyTorch E2M1 LUT (no FlashInfer dependency)
- quantize: nvfp4_kv_quantize (SM100+) with fp4_quantize fallback (SM90),
  assert SM90+ minimum
- dequantize: nvfp4_kv_dequantize (FlashInfer kernel), assert SM100+
- SM100 E2M1_MAX comment: add TRT-LLM reference explaining that calibrated
  KV scales are amax/(6*448) but XQA kernels expect amax/448
  Ref: NVIDIA/TensorRT-LLM FP8QDQLinearMethod
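
In code, the convention fix described here is a single rescale at load time; a minimal sketch with invented variable names:

```python
E2M1_MAX = 6.0        # largest magnitude representable in FP4 E2M1
FP8_E4M3_MAX = 448.0  # largest magnitude representable in FP8 E4M3

amax = 12.0  # example calibrated KV amax
checkpoint_kv_scale = amax / (FP8_E4M3_MAX * E2M1_MAX)  # TRT-LLM checkpoint convention
xqa_kv_scale = checkpoint_kv_scale * E2M1_MAX           # XQA convention: amax / 448
```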
Replace try/except with explicit is_sm100_supported() check to select
between nvfp4_kv_quantize (SM100+) and fp4_quantize (SM90) fallback.
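
A minimal sketch of that dispatch, assuming the existing is_sm100_supported() helper from sglang.srt.utils and the FlashInfer entry points named in this PR (argument lists abbreviated here; see the diff for the real call sites):

```python
import torch

from sglang.srt.utils import is_sm100_supported


def quantize_kv_fp4(kv: torch.Tensor, global_scale: torch.Tensor):
    """Pick the FlashInfer quantize kernel by compute capability."""
    if is_sm100_supported():
        # SM100+: KV-cache-specific NVFP4 kernel.
        from flashinfer import nvfp4_kv_quantize
        return nvfp4_kv_quantize(kv, global_scale)
    # SM90 fallback: generic FP4 quantize path.
    from flashinfer import fp4_quantize
    return fp4_quantize(kv, global_scale)
```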
samuellees (author) commented Apr 14, 2026

/tag-and-rerun-ci ++++++

Fridge003 merged commit 73e93be into sgl-project:main on Apr 29, 2026
479 of 556 checks passed
"""Block-wise FP4 single-level scaling (similar to MXFP4 but block_size=16)."""

name = "blockfp4"
SCALE_BLOCK_SIZE = 16
A reviewer commented:

Hi @samuellees, as you know, the block size of OCP MXFP4 is 32. Why is it 16 here? Does the hardware not support the block_size=32 format?

samuellees (author) replied on May 6, 2026.

Labels: blackwell SM100/SM120, quant LLM Quantization, run-ci

5 participants