feat: Add FP4 (E2M1) KV Cache Support for MHA #12612
Fridge003 merged 1 commit into sgl-project:main
Conversation
Based on PR sgl-project#10078, this patch:
- introduces FP4 KV cache support in MHATokenToKVPool with uint8 storage;
- adds k_scale_buffer and v_scale_buffer to store FP4 scaling factors;
- implements batched quantization on cache update and dequantization on access;
- updates ModelRunner memory estimation to account for FP4 scale buffers;
- maintains backward compatibility with FP16/FP8 KV cache.

Signed-off-by: Ho-Ren (Jack) Chuang <horenchuang@bytedance.com>
Co-authored-by: Yichen Wang <yichen.wang@bytedance.com>
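A rough conceptual sketch of the mechanism these bullets describe, with hypothetical names (FP4KVPoolSketch, quantize_fp4, dequantize_fp4) and per-token, per-head absmax scaling assumed purely for illustration; this is not the PR's actual classes, kernels, or scaling granularity:

```python
import torch

# Magnitudes representable by FP4 E2M1 (1 sign bit, 2 exponent bits, 1 mantissa bit).
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_fp4(x: torch.Tensor):
    """Absmax-scale the last dim, round to the E2M1 grid, pack two codes per uint8."""
    x = x.float()
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 6.0
    scaled = x / scale
    idx = (scaled.abs().unsqueeze(-1) - E2M1_VALUES).abs().argmin(dim=-1).to(torch.uint8)
    codes = idx | ((scaled < 0).to(torch.uint8) << 3)       # sign lives in bit 3
    packed = codes[..., 0::2] | (codes[..., 1::2] << 4)     # two FP4 values per byte
    return packed, scale.squeeze(-1)

def dequantize_fp4(packed: torch.Tensor, scale: torch.Tensor):
    codes = torch.stack([packed & 0x0F, packed >> 4], dim=-1).flatten(-2)
    mags = E2M1_VALUES[(codes & 0x7).long()]
    signs = 1.0 - 2.0 * ((codes & 0x8) != 0).float()
    return mags * signs * scale.unsqueeze(-1)

class FP4KVPoolSketch:
    """Toy MHA KV pool: uint8 storage for packed FP4 values plus scale buffers (CPU only)."""

    def __init__(self, size: int, head_num: int, head_dim: int):
        # Two FP4 values per uint8, so the value buffers are half as wide as head_dim.
        self.k_buffer = torch.zeros(size, head_num, head_dim // 2, dtype=torch.uint8)
        self.v_buffer = torch.zeros(size, head_num, head_dim // 2, dtype=torch.uint8)
        # One scale per token per head, analogous in spirit to k_scale_buffer / v_scale_buffer.
        self.k_scale_buffer = torch.ones(size, head_num, dtype=torch.float32)
        self.v_scale_buffer = torch.ones(size, head_num, dtype=torch.float32)

    def set_kv(self, loc: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # Batched quantization on cache update.
        self.k_buffer[loc], self.k_scale_buffer[loc] = quantize_fp4(k)
        self.v_buffer[loc], self.v_scale_buffer[loc] = quantize_fp4(v)

    def get_kv(self, loc: torch.Tensor):
        # Dequantization on access.
        k = dequantize_fp4(self.k_buffer[loc], self.k_scale_buffer[loc])
        v = dequantize_fp4(self.v_buffer[loc], self.v_scale_buffer[loc])
        return k, v
```

For example, `FP4KVPoolSketch(1024, 8, 128)` holds K and V in half the bytes of an FP8 pool of the same shape, plus the small scale buffers.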
        )
        for _ in range(self.layer_num)
    ]
    if is_float4_e2m1fn_x2(self.dtype):
Is it possible to overload the MHATokenToKVPool with a MHATokenToKVPoolFP4?
There are too many if-else branches here.
I feel the same change needs to be applied to the MLA FP4 pool.
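For illustration only, one possible shape of that refactor (a hypothetical skeleton; the import path and method names approximate the existing pool interface):

```python
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool

class MHATokenToKVPoolFP4(MHATokenToKVPool):
    """Hypothetical subclass that keeps all FP4-specific logic out of the base class."""

    def _create_buffers(self):
        # Allocate uint8 value buffers (two FP4 values per byte) plus
        # k_scale_buffer / v_scale_buffer, instead of if/else branches in the shared path.
        ...

    def set_kv_buffer(self, layer, loc, cache_k, cache_v, *args, **kwargs):
        ...  # quantize to FP4 on write

    def get_key_buffer(self, layer_id):
        ...  # dequantize on read
```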
@Fridge003 It’s feasible. How about this — since I have to update the MLA and submit a new PR anyway, let me fix this issue in that PR as well. If you agree with this, I’ll start working on the PR.
Yes, let's open a new PR for it
@Fridge003 Great! I'll create a new PR for the code refactoring of both the MLA and MHA FP4 token pools.
Meanwhile, could you please help me merge this PR first? This one is for the functionality. I think it’s better to separate the code refactoring from the new features. Thanks~
New PR #13547 has been submitted to refactor the FP4 token pools for both MLA and MHA.
Summary
This PR introduces support for FP4 (float4_e2m1fn_x2) KV caching in Multi-Head Attention (MHA) models, e.g., Qwen and GPT-OSS. See #10083, points 1-2, for more context.
Co-authored-by: Yichen Wang (@yicwang) <yichen.wang@bytedance.com>
Usage
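In essence, the feature is enabled by selecting the FP4 dtype for the KV cache when creating the engine or launching the server. A hypothetical sketch (the exact kv_cache_dtype string for FP4 and the other argument values are assumptions, not taken from this PR):

```python
import sglang as sgl

# Hypothetical example: the FP4 value accepted by kv_cache_dtype is a placeholder here.
llm = sgl.Engine(
    model_path="Qwen/Qwen3-235B-A22B",   # one of the models used in the accuracy tests below
    kv_cache_dtype="fp4_e2m1",           # placeholder; see this PR's server args for the real string
    attention_backend="torch_native",    # backend used for the reference benchmarks below
)
print(llm.generate("The capital of France is", {"max_new_tokens": 8}))
```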
Motivation and Benefits
Large models often face GPU memory constraints when storing the KV cache.
By introducing FP4 quantization with scale buffers, this PR substantially reduces KV cache memory usage: FP4 stores two values per byte, roughly halving the KV cache relative to FP8 and quartering it relative to FP16, at the cost of a small scale-buffer overhead.
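As a back-of-the-envelope illustration of the saving (the head count, head dim, layer count, and scale granularity below are assumed example values, not taken from this PR):

```python
# Rough per-token KV cache size for one hypothetical model configuration.
num_kv_heads, head_dim, num_layers = 8, 128, 48  # assumed example shapes

def kv_bytes_per_token(bytes_per_value: float, scale_bytes: float = 0.0) -> int:
    # K and V, per layer: values plus (for FP4) per-token scale entries.
    per_layer = 2 * num_kv_heads * head_dim * bytes_per_value + 2 * scale_bytes
    return int(per_layer * num_layers)

fp16 = kv_bytes_per_token(2.0)
fp8 = kv_bytes_per_token(1.0)
# FP4 packs two values per byte; assume one fp32 scale per head per token for K and V.
fp4 = kv_bytes_per_token(0.5, scale_bytes=num_kv_heads * 4)
print(fp16, fp8, fp4)  # 196608 98304 52224 -> FP4 is ~27% of FP16 and ~53% of FP8 here
```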
Key Changes
Accuracy tests for KV4 MHA
Qwen3-235B-A22B
gpt-oss-120b
Observation:
Performance Results
Although speed is not the main goal of this PR (it will be addressed in #10083, point 3-2), we ran throughput tests using `torch_native` to provide a reference.

Reason for `torch_native`:
- Other backends (e.g., `trtllm_mha`, Triton attention) have fused kernels for FP8 only, making FP8 faster there.
- KV8 lacks a fused kernel on `torch_native`, so both KV4 and KV8 are measured on the same backend.

Test configuration:
- `--num-prompts`: 100–400
- `--max-concurrency`: 50–200
- Unit: output token throughput (tok/s)
Observation:
Checklist