[1/4] NVFP4 KV cache: quantization strategy abstraction and kernel #21954
Merged by Fridge003 (17 commits into sgl-project:main)
Conversation
Introduce a strategy pattern abstraction for KV cache quantization:
- KVCacheQuantMethod ABC with create_buffers(), quantize_and_store(), dequantize_prev_kv(), and scale loading interfaces
- NoneMethod (identity for BF16/FP8), NVFP4Method (two-level scaling with per-tensor FP32 + per-block FP8 E4M3), MXFP4Method (single-level)
- KV_CACHE_QUANT_REGISTRY for dtype-to-method mapping

Add NVFP4QuantizeUtil with a vectorized CUDA dequantization kernel (E2M1 LUT-based), FlashInfer fp4_quantize integration, and an SM100 quantization kernel. Extend E2M1_VALUES to 16 entries.

This is PR 1 of the NVFP4 KV cache refactoring. Pure additions; no existing code paths are modified.
Tests cover:
- KV_CACHE_QUANT_REGISTRY factory and lookup
- NoneMethod identity behavior
- NVFP4Method buffer shapes, cell size, scale initialization
- NVFP4Method quantize->dequantize roundtrip (CUDA)
- MXFP4Method buffer shapes and roundtrip (CPU)
- KVFP4QuantizeUtil existing MXFP4 roundtrip
- FP4KVCacheRecipe enum values
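The strategy abstraction described in the PR summary can be sketched roughly as follows. This is an illustrative sketch, not the actual sglang code: the class and registry names come from the description above, but the method signatures and the NoneMethod bodies are assumptions.

```python
from abc import ABC, abstractmethod

# Illustrative sketch of the strategy pattern described above.
# Names follow the PR description; signatures are assumptions.
class KVCacheQuantMethod(ABC):
    @abstractmethod
    def create_buffers(self, head_num: int, head_dim: int, num_layers: int):
        """Allocate KV buffers (and any scale buffers) for this scheme."""

    @abstractmethod
    def quantize_and_store(self, kv, buffers, loc):
        """Quantize new KV entries and write them into the cache buffers."""

    @abstractmethod
    def dequantize_prev_kv(self, buffers, loc):
        """Dequantize previously cached KV for the prefill kernel."""

class NoneMethod(KVCacheQuantMethod):
    """Identity pass-through for BF16/FP8 caches (no extra scaling)."""
    def create_buffers(self, head_num, head_dim, num_layers):
        return {}
    def quantize_and_store(self, kv, buffers, loc):
        return kv
    def dequantize_prev_kv(self, buffers, loc):
        return buffers

# Registry mapping a recipe name to its method class.
KV_CACHE_QUANT_REGISTRY = {"none": NoneMethod}

def get_kv_cache_quant_method(name: str) -> KVCacheQuantMethod:
    return KV_CACHE_QUANT_REGISTRY[name]()
```

The point of the pattern is that callers (later, the memory pool in PR2) only hold a `KVCacheQuantMethod` reference and never branch on the concrete scheme.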
Code Review
This pull request introduces a strategy pattern for KV cache quantization, defining a base KVCacheQuantMethod class and specific implementations for NVFP4 and MXFP4 schemes. It adds NVFP4QuantizeUtil containing CUDA-optimized kernels for quantization and dequantization on SM100/SM120 architectures, and includes unit tests for the new functionality. A review comment suggested marking compute_cell_size as an abstract method in the base class to enforce its implementation in subclasses.
```python
    self, head_num: int, head_dim: int, num_layers: int, kv_size: int
) -> int:
    """Per-token memory footprint in bytes (for capacity estimation)."""
    raise NotImplementedError
```
The method compute_cell_size raises NotImplementedError, which is appropriate for an abstract base class, but it should be marked with @abstractmethod to enforce implementation in subclasses, consistent with the other abstract methods in this class.
Suggested change:
```python
@abstractmethod
def compute_cell_size(
    self, head_num: int, head_dim: int, num_layers: int, kv_size: int
) -> int:
    """Per-token memory footprint in bytes (for capacity estimation)."""
```
Remove embedded CUDA kernels (cuda_nvfp4_dequantize, cuda_nvfp4_quantize_blackwell) and the PyTorch fallback quantize (batched_quantize). Keep only:
- fi_nvfp4_quantize: delegates to FlashInfer fp4_quantize
- dequantize: pure PyTorch E2M1 LUT lookup (no JIT CUDA compilation)
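The idea behind an E2M1 LUT-based dequantize is small enough to sketch in plain Python. The 16-entry table below is the standard E2M1 (FP4) value set; the nibble packing order (low nibble first) is an assumption for illustration, and the real code works on PyTorch tensors rather than byte lists:

```python
# Sketch of LUT-based FP4 (E2M1) dequantization.
# E2M1: 1 sign bit, 2 exponent bits, 1 mantissa bit -> 16 codes.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
               -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def dequantize_fp4(packed_bytes, scale):
    """Unpack two FP4 codes per byte and look each up in the LUT.

    Low-nibble-first order is assumed here for illustration.
    """
    out = []
    for b in packed_bytes:
        lo = b & 0x0F          # first element in the pair
        hi = (b >> 4) & 0x0F   # second element in the pair
        out.append(E2M1_VALUES[lo] * scale)
        out.append(E2M1_VALUES[hi] * scale)
    return out
```

Because dequantization is a pure table lookup plus a scale multiply, it vectorizes trivially as tensor indexing, which is why no JIT CUDA compilation is needed for this path.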
- Remove instruction comments from module docstring
- Mark compute_cell_size as @abstractmethod, implement in NoneMethod
- Remove GPU names from NVFP4Method docstring
- Rename scale_block_size to SCALE_BLOCK_SIZE (constant naming)
- Add comment explaining max_global_id resize for hybrid models
- Use E2M1_MAX constant instead of hardcoded 6.0 for SM100 scale
- Use FlashInfer nvfp4_kv_quantize/nvfp4_kv_dequantize APIs with fallback to fp4_quantize and PyTorch dequant for older versions
- Rename fi_nvfp4_quantize -> quantize, keep dequantize
- kv_cache_quant_method.py -> fp4_kv_cache_quant_method.py
- KVCacheQuantMethod -> FP4KVCacheQuantMethod
- KV_CACHE_QUANT_REGISTRY -> FP4_KV_CACHE_QUANT_REGISTRY
- get_kv_cache_quant_method -> get_fp4_kv_cache_quant_method
- Update all test references
- Move test from test/manual/quant/ to test/registered/unit/layers/quantization/
- Add CI registration (register_cpu_ci)
- Use CustomTestCase base class
- Rename test file to match source: test_fp4_kv_cache_quant_method.py
samuellees left a comment
Do these methods support CUDA graph (if they need to support it)?
- Add docstring note that quant methods are prefill-only, not in the CUDA graph path
- Clarify dequantize_prev_kv returns FP8 dtype for the FlashInfer prefill kernel
- Rename KVFP4QuantizeUtil -> BlockFP4KVQuantizeUtil (block-wise FP4, similar to MXFP4 but with block_size=16)
- Rename NVFP4QuantizeUtil -> NVFP4KVQuantizeUtil (two-level scaling)
- Keep KVFP4QuantizeUtil as a backward-compatible alias
- Update fp4_kv_cache_quant_method.py and tests to use the new names
- Remove NoneMethod: BF16/FP8 paths should use None, not a dummy FP4 method
- Rename MXFP4Method -> BlockFP4Method with name="blockfp4" to match BlockFP4KVQuantizeUtil naming and clarify it is not standard MXFP4
- Update registry key from "mxfp4" to "blockfp4"
- Remove NoneMethod tests
```python
# Note: These methods are called during prefill (extend) only, not during
# CUDA graph capture or decode. Decode reads raw FP4 buffers directly via
# the XQA kernel. Therefore CUDA graph compatibility is not required here.
```
Question: won't this be incompatible with piecewise CUDA graph, then?
/tag-and-rerun-ci
- Fix misleading CUDA-graph note: the operations ARE CUDA-graph compatible (FlashInfer kernels + pure tensor ops), not prefill-only as stated
- Improve SM100 E2M1_MAX comment: explain checkpoint convention difference (not a hardware difference; the FP4 data type is identical on SM100/SM120)
- Remove backward-compat alias KVFP4QuantizeUtil (not public API)
- Simplify NVFP4KVQuantizeUtil: use fp4_quantize directly instead of a try/except fallback chain with nvfp4_kv_quantize/nvfp4_kv_dequantize
- Dequantize uses pure PyTorch E2M1 LUT (no FlashInfer dependency)
- quantize: nvfp4_kv_quantize (SM100+) with fp4_quantize fallback (SM90), assert SM90+ minimum
- dequantize: nvfp4_kv_dequantize (FlashInfer kernel), assert SM100+
- SM100 E2M1_MAX comment: add TRT-LLM reference explaining that calibrated KV scales are amax/(6*448) but XQA kernels expect amax/448

Ref: NVIDIA/TensorRT-LLM FP8QDQLinearMethod
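The scale-convention point above can be shown with a small numeric sketch. The constants (E2M1 max of 6.0, FP8 E4M3 max of 448) come from the commit message; the conversion function is an illustration of how the two conventions relate, not the sglang code:

```python
# The two KV-scale conventions mentioned above.
E2M1_MAX = 6.0    # largest representable FP4 (E2M1) magnitude
E4M3_MAX = 448.0  # largest representable FP8 (E4M3) magnitude

def calibrated_scale(amax: float) -> float:
    # Convention used by calibrated checkpoints: amax / (6 * 448)
    return amax / (E2M1_MAX * E4M3_MAX)

def xqa_scale(amax: float) -> float:
    # Convention expected by the XQA kernels: amax / 448
    return amax / E4M3_MAX

def to_xqa(checkpoint_scale: float) -> float:
    # Converting between the conventions is a multiply by E2M1_MAX,
    # which is why the code uses the named constant, not a bare 6.0.
    return checkpoint_scale * E2M1_MAX
```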
Replace try/except with explicit is_sm100_supported() check to select between nvfp4_kv_quantize (SM100+) and fp4_quantize (SM90) fallback.
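A minimal sketch of that capability-based dispatch follows. In sglang, is_sm100_supported checks the real GPU; here it is modeled as a plain predicate over a compute-capability tuple so the logic is testable anywhere, and the returned kernel names are stand-ins for the FlashInfer calls:

```python
# Explicit capability dispatch instead of a try/except fallback chain.
def is_sm100_supported(capability: tuple) -> bool:
    # SM100 corresponds to compute capability 10.0
    return capability >= (10, 0)

def select_quantize_kernel(capability: tuple) -> str:
    # SM90 (compute capability 9.0) is the asserted minimum.
    assert capability >= (9, 0), "NVFP4 KV quantize requires SM90+"
    if is_sm100_supported(capability):
        return "nvfp4_kv_quantize"   # FlashInfer SM100+ kernel
    return "fp4_quantize"            # SM90 fallback
```

Making the branch explicit also means an unexpected FlashInfer import error surfaces immediately instead of being swallowed by a broad except clause.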
…PyTorch LUT (SM90)
/tag-and-rerun-ci
| """Block-wise FP4 single-level scaling (similar to MXFP4 but block_size=16).""" | ||
|
|
||
| name = "blockfp4" | ||
| SCALE_BLOCK_SIZE = 16 |
Hi @samuellees, as you know, the block size of OCP MXFP4 is 32. Why is it changed to 16 here? Does the hardware not support the block_size=16 format?
Hi @DehuaTang, good question, but I have no answer for that. Your question is not in the scope of this PR; please see here: https://github.com/sgl-project/sglang/pull/21954/changes/BASE..942c7d6f54a56197435cea46003301a2c53fc833#diff-87c2ca9d8198ad03e3239bfef35ed25550321d88a3e97bd5820115743cd7d0a4L46
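For context, single-level block-wise FP4 scaling of the kind discussed in this thread can be sketched as follows. The block_size of 16 matches the snippet above; the scale formula and round-to-nearest step are illustrative assumptions, not the sglang implementation:

```python
# Illustrative single-level block-wise FP4 quantization (block_size=16):
# each block of 16 values shares one scale; each value is snapped to the
# nearest representable E2M1 magnitude.
E2M1_MAX = 6.0
E2M1_POS = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_block(values):
    assert len(values) == 16, "block_size=16"
    amax = max(abs(v) for v in values) or 1.0  # avoid a zero scale
    scale = amax / E2M1_MAX                    # one scale per block
    codes = []
    for v in values:
        mag = abs(v) / scale
        q = min(E2M1_POS, key=lambda p: abs(p - mag))  # nearest E2M1 value
        codes.append(-q if v < 0 else q)
    return codes, scale

def dequantize_block(codes, scale):
    return [c * scale for c in codes]
```

This differs from the NVFP4 scheme only in having a single level of scaling; a smaller block of 16 halves the values sharing a scale relative to OCP MXFP4's 32, trading scale storage for finer granularity.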
Summary
Part 1 of 4 for NVFP4 KV cache support on SM120 GPUs. Split from #21601 for easier review.
Roadmap: PR1 (this) → PR2 (memory pool) → PR3 (attention backends) → PR4 (MTP + config + docs)
This PR introduces the FP4 KV cache quantization strategy abstraction and quantize/dequantize utilities. Pure additions — no existing code paths are modified.
Changes
- fp4_kv_cache_quant_method.py (NEW):
  - FP4KVCacheQuantMethod ABC with strategy pattern
  - NVFP4KVMethod: two-level scaling (per-tensor FP32 + per-block FP8 E4M3), SM100/SM120
  - BlockFP4KVMethod: block-wise single-level scaling (similar to MXFP4 but block_size=16)
  - FP4_KV_CACHE_QUANT_REGISTRY: maps recipe names to method classes
- kvfp4_tensor.py (extended):
  - NVFP4KVQuantizeUtil: quantize via FlashInfer nvfp4_kv_quantize (fallback fp4_quantize), dequantize via nvfp4_kv_dequantize (fallback PyTorch E2M1 LUT)
  - BlockFP4KVQuantizeUtil: block-wise FP4 with block_size=16 (renamed from KVFP4QuantizeUtil, backward-compat alias kept)
- Unit tests: registry, buffer shapes, quantize→dequantize roundtrip for NVFP4 (CUDA) and BlockFP4 (CPU)
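The two-level scaling scheme (per-tensor FP32 scale plus per-block FP8 E4M3 scales) can be sketched in plain Python. The block size of 16 and the scale formulas are illustrative, derived from the description above rather than from the sglang kernels:

```python
# Illustrative two-level NVFP4-style scaling:
#   level 1: one global FP32 scale for the whole tensor
#   level 2: one FP8 E4M3 scale per block of 16 values
E2M1_MAX = 6.0    # largest FP4 magnitude
E4M3_MAX = 448.0  # largest FP8 E4M3 magnitude

def two_level_scales(values, block_size=16):
    amax = max(abs(v) for v in values)
    # Global scale maps the tensor amax into the representable range.
    global_scale = amax / (E2M1_MAX * E4M3_MAX)
    block_scales = []
    for i in range(0, len(values), block_size):
        block = values[i:i + block_size]
        block_amax = max(abs(v) for v in block)
        # Per-block scale (stored as FP8 E4M3), relative to the global scale;
        # the block containing amax gets exactly E4M3_MAX, so it fits in FP8.
        block_scales.append(block_amax / (E2M1_MAX * global_scale))
    return global_scale, block_scales
```

The effective per-block dequant factor is global_scale * block_scale = block_amax / E2M1_MAX, so each block's largest value maps onto the FP4 maximum of 6.0.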
Design
FP4KVCacheQuantMethod (pure compute) → Pool (buffer + batch dequant) → Backend (view adaptation)
The quant method owns quantize/dequantize logic. The Pool (PR2) owns buffers and orchestrates batch operations. Backends (PR3) only do view/reshape.
Naming Convention
- FP4KVCacheQuantMethod
- NVFP4KVMethod
- BlockFP4KVMethod
- NVFP4KVQuantizeUtil
- BlockFP4KVQuantizeUtil (renamed from KVFP4QuantizeUtil)
Test plan
- pytest test/registered/unit/layers/quantization/test_fp4_kv_cache_quant_method.py -v
- Covers the backward-compat alias (KVFP4QuantizeUtil)
NVFP4 KV Cache Performance
Benchmark (Qwen3.5-35B-A3B, GSM8K, SM120)
Throughput