
feat: Add FP4 (E2M1) KV Cache Support for MHA#12612

Merged
Fridge003 merged 1 commit into sgl-project:main from bytedance-iaas:horenc/kv4_mha_on_main_release
Nov 15, 2025

Conversation

@JackChuang
Contributor

Summary

This PR introduces support for FP4 (float4_e2m1fn_x2) KV caching in Multi-Headed Attention (MHA) models, e.g., Qwen and GPT-OSS. See #10083, points 1-2, for more context.

Co-authored-by: @yicwang (Yichen Wang <yichen.wang@bytedance.com>)

Usage

$ python3 -m sglang.launch_server --kv-cache-dtype fp4_e2m1 ... 

Motivation and Benefits

Large models often face GPU memory constraints when storing KV cache.
By introducing FP4 quantization with scale buffers, this PR significantly reduces KV memory usage and improves efficiency:

  • Supports significantly more tokens than KV8 (≈1.78×) and KV16 (≈3.56×), thanks to FP4 quantization with block_size = 16.
  • Improves scalability for longer context windows and throughput for large-batch requests.
  • Enables inference of larger models or longer context windows on memory-limited GPUs.
  • Integrates seamlessly with existing inference pipelines without breaking KV16/KV8 workflows.
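The ≈1.78× and ≈3.56× figures follow from simple bytes-per-value accounting. The sketch below reproduces them under one assumption not spelled out in the PR: each 16-element block carries a single 1-byte scale.

```python
# Back-of-envelope KV cache capacity ratios.
# Assumption (illustrative): one 1-byte scale per 16-element block.
BLOCK_SIZE = 16
bytes_kv16 = 2.0                         # fp16: 2 bytes per value
bytes_kv8 = 1.0                          # fp8_e4m3: 1 byte per value
bytes_kv4 = 0.5 + 1.0 / BLOCK_SIZE       # packed E2M1 nibble + amortized scale

gain_vs_kv8 = bytes_kv8 / bytes_kv4      # ≈ 1.78
gain_vs_kv16 = bytes_kv16 / bytes_kv4    # ≈ 3.56
print(round(gain_vs_kv8, 2), round(gain_vs_kv16, 2))
```

Note that the scale overhead is why the gain over KV8 is 1.78× rather than a clean 2×.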

Key Changes

  • MHATokenToKVPool
    • Added FP4 KV cache support with uint8 storage format.
    • Introduced k_scale_buffer and v_scale_buffer for per-block scaling factors.
    • Integrated batched quantization (on update) and dequantization (on access) using KVFP4QuantizeUtil.
  • ModelRunner
    • Updated GPU memory estimation logic to account for FP4 cache and scale buffers.
  • Compatibility
    • Preserves existing FP16/FP8 KV cache behavior without changes.
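To make the update/access path concrete, here is a minimal pure-Python sketch of per-block E2M1 quantization with an absmax scale, mirroring the idea behind KVFP4QuantizeUtil. Names and details are illustrative, not the actual SGLang implementation (which stores packed uint8 codes and runs batched on GPU).

```python
# E2M1 has 8 representable magnitudes; a sign bit completes the 4-bit code.
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK_SIZE = 16

def quantize_block(values):
    """Quantize one block to signed grid indices plus one shared scale."""
    absmax = max(abs(v) for v in values) or 1.0
    scale = absmax / 6.0  # map the block's range onto the grid (max magnitude 6)
    codes = []
    for v in values:
        mag = abs(v) / scale
        idx = min(range(len(E2M1_GRID)), key=lambda i: abs(E2M1_GRID[i] - mag))
        codes.append((idx, v < 0))
    return codes, scale

def dequantize_block(codes, scale):
    """Reconstruct approximate values from codes and the block scale."""
    return [(-1.0 if neg else 1.0) * E2M1_GRID[idx] * scale
            for idx, neg in codes]

block = [0.1 * i - 0.8 for i in range(BLOCK_SIZE)]
codes, scale = quantize_block(block)
restored = dequantize_block(codes, scale)
```

The real pool stores the codes in uint8 (two per byte) and the scales in k_scale_buffer / v_scale_buffer; quantization happens on cache update and dequantization on access.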

Accuracy tests for KV4 MHA

  • FP4 KV cache is well-suited for large-scale models, providing memory savings with minimal accuracy impact.
  • For smaller models, careful evaluation is needed to balance memory efficiency and accuracy.

Qwen3-235B-A22B

| Model | Dataset | Metric | Subset | Num | Score | Cat.0 |
| --- | --- | --- | --- | --- | --- | --- |
| KV4 (fp4_e2m1) | gsm8k | mean_acc | main | 6595 | 0.9186 | default |
| KV4 | aime25 | mean_acc | OVERALL | 150 | 0.6 | - |
| KV4 | gpqa_diamond | mean_acc | default | 990 | 0.6778 | default |
| KV8 (fp8_e4m3) | gsm8k | mean_acc | main | 6595 | 0.9181 | default |
| KV8 | aime25 | mean_acc | OVERALL | 150 | 0.7333 | - |
| KV8 | gpqa_diamond | mean_acc | default | 990 | 0.6899 | default |
| KV16 | gsm8k | mean_acc | main | 6595 | 0.9168 | default |
| KV16 | aime25 | mean_acc | OVERALL | 150 | 0.7733 | - |
| KV16 | gpqa_diamond | mean_acc | default | 990 | 0.701 | default |

gpt-oss-120b

| Model | Dataset | Metric | Subset | Num | Score | Cat.0 |
| --- | --- | --- | --- | --- | --- | --- |
| KV4 (fp4_e2m1) | aime25 | mean_acc | OVERALL | 150 | 0.3533 | - |
| KV4 | gsm8k | mean_acc | main | 6595 | 0.9152 | default |
| KV4 | gpqa_diamond | mean_acc | default | 990 | 0.3202 | default |
| KV8 (fp8_e4m3) | aime25 | mean_acc | OVERALL | 150 | 0.7667 | - |
| KV8 | gsm8k | mean_acc | main | 6595 | 0.9163 | default |
| KV8 | gpqa_diamond | mean_acc | default | 990 | 0.5434 | default |
| KV16 | aime25 | mean_acc | OVERALL | 150 | 0.7533 | - |
| KV16 | gsm8k | mean_acc | main | 6595 | 0.9161 | default |
| KV16 | gpqa_diamond | mean_acc | default | 990 | 0.5081 | default |

Observation:  

  • On large models (Qwen3-235B-A22B), FP4 maintains accuracy close to FP8/FP16.  
  • On smaller models (gpt-oss-120b), FP4 shows more pronounced accuracy drops on difficult datasets.  
  • Trend: Accuracy degradation is more significant in smaller models.

Performance Results

Although speed is not the main goal of this PR (it will be addressed in #10083, point 3-2), we ran throughput tests using torch_native as a reference:

  • Reason for torch_native:
      - Other backends (e.g., trtllm_mha, Triton attention) have fused kernels for FP8 only, which makes FP8 faster there.
      - KV8 lacks a fused kernel on torch_native, so KV4 and KV8 are measured on an equal footing on the same backend.

    Note: KV8 previously could not run with the torch_native attention backend; this was fixed in #12596 (Support kv8 (FP8) with torch_native attention backend).

  • Test configuration:  
      - --num-prompts: 100–400  
      - --max-concurrency: 50–200  
      - Unit: Output token throughput (tok/s)
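The configuration above maps onto sglang's serving benchmark. The invocation below is an illustrative sketch, not a command taken from the PR: the model path is assumed, and flags should be adapted to your sglang version.

```shell
# Start a server with the FP4 KV cache on the torch_native backend
# (model path is illustrative).
python3 -m sglang.launch_server --model-path Qwen/Qwen3-235B-A22B \
    --kv-cache-dtype fp4_e2m1 --attention-backend torch_native &

# Measure output token throughput at one of the tested operating points.
python3 -m sglang.bench_serving --backend sglang \
    --num-prompts 400 --max-concurrency 200
```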

| Num Prompts | Concurrency | KV8 (tok/s) | KV4 (tok/s) | Gain | TTFT (ms) | TPOT (ms) |
| --- | --- | --- | --- | --- | --- | --- |
| 100 | 50 | 62.43 | 60.35 | -3.33% | 5323 | 798 |
| 200 | 100 | 67.34 | 68.02 | +1.0% | 9378 | 1480 |
| 300 | 150 | 68.81 | 71.63 | +4.1% | 13500 | 2172 |
| 400 | 200 | 69.75 | 74.19 | +6.36% | 19595 | 2685 |

Observation:

  • KV4 starts slightly behind KV8 at low concurrency (-3.33% at 100 prompts / 50 concurrency) but pulls ahead as load grows, reaching +6.36% at 400 prompts / 200 concurrency.

Checklist

Based on PR sgl-project#10078, this patch
- introduces FP4 KV cache support in MHATokenToKVPool with uint8 storage.
- adds k_scale_buffer and v_scale_buffer to store FP4 scaling factors.
- implements batched quantization on cache update and dequantization on access.
- updates ModelRunner memory estimation to account for FP4 scale buffers.
- maintains backward compatibility with FP16/FP8 KV cache.

Signed-off-by: Ho-Ren (Jack) Chuang <horenchuang@bytedance.com>
Co-authored-by: Yichen Wang <yichen.wang@bytedance.com>

@JackChuang
Contributor Author

Hi @Fridge003 @AniZpZ @zhyncs,
Thank you very much for helping review and merge the PR for MLA KV4 (#10078).
Could you please help review this PR for MHA KV4? Thank you!

@rainj-me rainj-me added the run-ci label Nov 7, 2025
@JackChuang
Contributor Author

Hi @zhyncs @AniZpZ @Fridge003,
Would really appreciate it if someone could take a quick look at this PR when you have a moment. Thanks!

Comment thread on python/sglang/srt/model_executor/model_runner.py:

```python
    )
    for _ in range(self.layer_num)
]
if is_float4_e2m1fn_x2(self.dtype):
```
Collaborator


Is it possible to overload the MHATokenToKVPool with a MHATokenToKVPoolFP4?
There are too many if-else branches here.
I feel the same change needs to be applied to MLA FP4 pool

Contributor Author


@Fridge003 It’s feasible. How about this — since I have to update the MLA and submit a new PR anyway, let me fix this issue in that PR as well. If you agree with this, I’ll start working on the PR.

Collaborator


Yes, let's open a new PR for it

Contributor Author


@Fridge003 Great! I'll create a new PR for the code refactoring of both the MLA & MHA FP4 token pools.
Meanwhile, could you please help me merge this PR first? It covers the functionality, and I think it's better to separate the code refactoring from the new features. Thanks~

Contributor Author


New PR #13547 has been submitted to refactor the FP4 token pools for both MLA and MHA.

@JackChuang
Contributor Author

Hi @Fridge003,
Per our discussion above, could you please help approve and merge this PR so that I can work directly on the main branch for the code refactoring of both MHA and MLA token pool fp4? Thank you~

@Fridge003
Collaborator

Fridge003 commented Nov 15, 2025

NV tests all passed
https://github.com/sgl-project/sglang/actions/runs/19177871109/job/55448358397?pr=12612

@Fridge003 Fridge003 merged commit 6d5e16f into sgl-project:main Nov 15, 2025
141 of 161 checks passed
@HanHan009527 HanHan009527 deleted the horenc/kv4_mha_on_main_release branch December 16, 2025 16:19
