support fp8 kvcache for hybrid attn backend on GPT-OSS #9783
rainj-me merged 10 commits into sgl-project:main
Conversation
Summary of Changes
Hello @rainj-me, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request enhances the system's ability to handle larger batch sizes and improve GPU utilization by integrating FP8 KV cache support. It achieves this through a strategic shift to a hybrid attention mechanism, addressing current hardware and kernel limitations to unlock greater throughput and efficiency in large language model inference.
Highlights
- **FP8 KV Cache Support:** This pull request introduces support for FP8 KV cache by implementing a hybrid attention backend. This is crucial for optimizing performance on B200/GB200 GPUs, where KV cache volume can bottleneck batch size, especially since the `trtllm-mha` CUDA kernel does not support Q(bf16), KV(fp8), O(bf16) during the prefill phase.
- **Hybrid Attention Backend:** To circumvent the `trtllm-mha` limitation, the solution employs a hybrid attention backend: Triton is used for the prefill phase, and `trtllm-mha` is used for the decode phase. This allows efficient utilization of the FP8 KV cache.
- **Conditional Fused KV Buffer Enabling:** The `_enable_fused_set_kv_buffer` function has been updated to conditionally enable fused KV buffer operations. It now activates only when the device is CUDA and the KV cache pool data type is `bfloat16`, ensuring compatibility and optimal performance based on the data type (see the sketch after this list).
- **Performance Improvements:** Benchmarking results demonstrate significant performance gains. With FP8 KV cache, a single B200 GPU can handle a batch size of 768 without queuing requests, a substantial improvement over the 630 batch size achievable with a `bfloat16` KV cache.
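For readers skimming the conversation, here is a minimal sketch of the conditional check described in the third highlight. The simplified signature and the `kv_cache_dtype` parameter are assumptions for illustration; the actual sglang helper reads these values off the forward batch.

```python
import torch

def _enable_fused_set_kv_buffer(device: torch.device, kv_cache_dtype: torch.dtype) -> bool:
    """Sketch: fuse the KV-buffer write only on CUDA with a bfloat16 KV
    pool; an fp8 pool falls back to the unfused path, which performs the
    quantizing store."""
    return device.type == "cuda" and kv_cache_dtype == torch.bfloat16

# An fp8_e4m3 KV pool disables the fused path:
print(_enable_fused_set_kv_buffer(torch.device("cuda"), torch.float8_e4m3fn))  # False
```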
Code Review
This pull request adds support for fp8 KV cache in the hybrid attention backend by conditionally disabling the fused set_kv_buffer kernel. The logic is updated in _enable_fused_set_kv_buffer to check if the KV cache dtype is bfloat16. The changes are clear, correct, and well-contained. My feedback includes a minor suggestion to improve code documentation for better maintainability.
will this work on H100?
On Hopper, it may only work with the triton (maybe fa3) backend, since trtllm-mha only works on sm100.
@rainj-me Can you please post the accuracy results on the gpqa dataset? A sample command could be
Model: GPT-OSS 120B

OPENAI_API_KEY=dummy python -m gpt_oss.evals --base-url http://localhost:28000/v1 --model /data01/models/gpt-oss-120b --reasoning-effort medium --eval gpqa --n-threads 1000

[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-medium_temp1.0_20250830_024839', 'metric': 0.711489898989899}]
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-medium_temp1.0_20250830_025437', 'metric': 0.6559343434343434}]
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-medium_temp1.0_20250830_023139', 'metric': 0.398989898989899}]
What's the difference between the last two cases? @rainj-me
Sorry, I forgot to update the title of the second one with the 0.6559 score, which is the hybrid attn backend: prefill (triton attn backend), decode (trtllm-mha attn backend). Let me update the title.
@rainj-me It seems accuracy with the fp8 kv cache on the trtllm-mha backend drops a lot. Probably due to the modifications in this PR?
Yes, it's because the trtllm-mha context ops only offer QKV(fp8) O(bf16) and QKVO(bf16) in the prefill phase, so I have to quantize the query to FP8 with trtllm-mha, which degrades accuracy significantly. I'm asking the flashinfer community for a Q(bf16), KV(fp8), O(bf16) context op. How about I revert the query quantization and keep only the hybrid attn backend?
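To make the accuracy concern concrete, here is an illustrative, standalone round-trip of a bf16 query tensor through fp8; the e4m3 format and the tensor shapes are assumptions, not taken from the kernel.

```python
import torch

# Quantize a bf16 query tensor to fp8 (e4m3) and measure the round-trip
# error, mimicking what a QKV(fp8)-only prefill op forces on Q.
q = torch.randn(8, 64, 128, dtype=torch.bfloat16)     # (heads, tokens, head_dim), arbitrary
q_rt = q.to(torch.float8_e4m3fn).to(torch.bfloat16)   # what the kernel effectively sees

rel_err = (q - q_rt).abs().mean() / q.abs().mean()
print(f"mean relative error from fp8 Q quantization: {rel_err.item():.4f}")
```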
This reverts commit 22c204c.
Sure, you can support fp8 for the hybrid backend in this PR, and leave trtllm-mha for future work.
GPQA dropping from 0.71 to 0.66 for medium reasoning is significant. Could you run with reasoning high for both GPQA and AIME25? High could be more sensitive.
Sure, I will run the tests with reasoning effort set to high, and the evals will include both GPQA and AIME25.
For GPQA high reasoning, the results are
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_030309', 'metric': 0.7803030303030303}]
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_045631', 'metric': 0.8011363636363636}]
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_062255', 'metric': 0.7897727272727273}]
For AIME25 high reasoning, the results are
[{'eval_name': 'aime25', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_173633', 'metric': 0.9125}]
[{'eval_name': 'aime25', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_165807', 'metric': 0.9041666666666667}]
[{'eval_name': 'aime25', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_152103', 'metric': 0.8875}]
Motivation
#9782
On B200/GB200, KV cache volume actually limits the batch size, which is the bottleneck for GPT-OSS performance, and the trtllm-mha CUDA kernel cannot support Q(bf16), KV(fp8), O(bf16) in the prefill phase. That's why I created this PR to use a hybrid attn backend (prefill: triton, decode: trtllm-mha) for fp8 kv cache support. I will continue to seek a solution for FP8 kv cache support in the trtllm-mha kernel during the prefill phase.
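Conceptually, the hybrid backend is a thin dispatcher; the sketch below uses illustrative names (`HybridAttnBackend`, `forward_batch.is_decode`) rather than sglang's actual API.

```python
class HybridAttnBackend:
    """Sketch: route prefill (extend) batches to a backend that supports
    Q(bf16)/KV(fp8) prefill (triton here), and decode batches to
    trtllm-mha, which handles the fp8 KV cache on sm100."""

    def __init__(self, prefill_backend, decode_backend):
        self.prefill_backend = prefill_backend  # e.g. a Triton attention backend
        self.decode_backend = decode_backend    # e.g. a trtllm-mha attention backend

    def forward(self, q, k, v, layer, forward_batch):
        backend = (
            self.decode_backend
            if forward_batch.is_decode  # hypothetical batch-mode flag
            else self.prefill_backend
        )
        return backend.forward(q, k, v, layer, forward_batch)
```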
Modifications
Accuracy Tests
Testing scripts
triton with bf16 kv cache

[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_030309', 'metric': 0.7803030303030303}]
[{'eval_name': 'aime25', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_173633', 'metric': 0.9125}]

trtllm-mha with bf16 kv cache

[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_045631', 'metric': 0.8011363636363636}]
[{'eval_name': 'aime25', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_165807', 'metric': 0.9041666666666667}]

hybrid attn backend with fp8 kv cache

[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_062255', 'metric': 0.7897727272727273}]
[{'eval_name': 'aime25', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_152103', 'metric': 0.8875}]

Benchmarking and Profiling
With the fp8 kv cache, a single B200 GPU can run a batch size of 768 without any queued requests, well above the 630 achievable with the bfloat16 kv cache.
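As a back-of-envelope check on why fp8 frees batch headroom: the dimensions below are placeholders, not GPT-OSS 120B's actual config.

```python
# KV cache bytes per sequence: K and V each store
# num_kv_heads * head_dim values per token per layer.
num_layers, num_kv_heads, head_dim, context_len = 36, 8, 64, 4096  # placeholder dims

def kv_bytes_per_seq(dtype_bytes: int) -> int:
    return 2 * num_layers * num_kv_heads * head_dim * context_len * dtype_bytes

bf16, fp8 = kv_bytes_per_seq(2), kv_bytes_per_seq(1)
print(f"bf16: {bf16 / 2**20:.0f} MiB/seq, fp8: {fp8 / 2**20:.0f} MiB/seq")
# Halving bytes per KV element frees pool capacity for more concurrent
# sequences in the same GPU memory, which is where the extra batch-size
# headroom (768 vs 630 here) comes from.
```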
Hybrid attn backend with fp8 kv cache result
Checklist