
support fp8 kvcache for hybrid attn backend on GPT-OSS #9783

Merged
rainj-me merged 10 commits into sgl-project:main from bytedance-iaas:dev/gpt-oss-fp8-kv-cache on Sep 1, 2025

Conversation

@rainj-me (Collaborator) commented Aug 29, 2025

Motivation

#9782

On B200/GB200, the KV cache footprint limits the achievable batch size, which is the main bottleneck for GPT-OSS performance. The trtllm-mha CUDA kernel cannot support Q(bf16), KV(fp8), O(bf16) in the prefill phase, so this PR uses a hybrid attn backend (prefill: triton, decode: trtllm-mha) to enable the fp8 KV cache. I will keep looking for a solution that supports the fp8 KV cache on the trtllm-mha kernel during the prefill phase.
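
To illustrate the idea (a conceptual sketch only, not sglang's actual classes or method signatures): a hybrid backend routes prefill/extend batches to one attention implementation and decode batches to another, so each phase can use a kernel that matches the KV cache dtype it supports.

```python
# Conceptual sketch; class, method, and attribute names are assumptions for illustration.
class HybridAttnBackend:
    """Route prefill (extend) batches to one backend and decode batches to another."""

    def __init__(self, prefill_backend, decode_backend):
        self.prefill_backend = prefill_backend  # e.g. a Triton-based backend (bf16 Q against fp8 KV)
        self.decode_backend = decode_backend    # e.g. a trtllm-mha-based backend (fp8 KV)

    def forward(self, q, k, v, layer, forward_batch):
        # Hypothetical predicate: decode batches go to the decode kernel,
        # everything else (prefill/extend) goes to the prefill kernel.
        if forward_batch.is_decode:
            return self.decode_backend.forward(q, k, v, layer, forward_batch)
        return self.prefill_backend.forward(q, k, v, layer, forward_batch)
```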

Modifications

  • Update _enable_fused_set_kv_buffer so that the fused set-KV-buffer path is only enabled when the KV cache pool dtype is bfloat16 (a minimal sketch of this check follows below).
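
A minimal sketch of that gating logic, under the assumption that the KV pool exposes its dtype on the forward batch (the helper name and attribute path are illustrative, not necessarily the PR's exact code):

```python
import torch

_is_cuda = torch.cuda.is_available()  # stand-in for the platform check

def _enable_fused_set_kv_buffer(forward_batch) -> bool:
    # Fusing the "write K/V into the cache" step into the attention op is only
    # known-safe when the KV cache pool stores bfloat16; with an fp8 pool the
    # unfused set_kv_buffer path (which handles quantization) is used instead.
    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
```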

Accuracy Tests

Testing scripts

  • GPQA high reasoning
OPENAI_API_KEY=dummy python -m gpt_oss.evals --base-url http://127.0.0.1:28000/v1 --model /data01/models/gpt-oss-120b --reasoning-effort high --eval gpqa --n-threads 1000
  • AIME25 high reasoning
OPENAI_API_KEY=dummy python -m gpt_oss.evals --base-url http://127.0.0.1:28000/v1 --model /data01/models/gpt-oss-120b --reasoning-effort high --eval aime25 --n-threads 1000

triton with bf16 kv cache

[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_030309', 'metric': 0.7803030303030303}]
[{'eval_name': 'aime25', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_173633', 'metric': 0.9125}]

trtllm-mha with bf16 kv cache

[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_045631', 'metric': 0.8011363636363636}]
[{'eval_name': 'aime25', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_165807', 'metric': 0.9041666666666667}]

hybrid attn backend with fp8 kv cache

[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_062255', 'metric': 0.7897727272727273}]
[{'eval_name': 'aime25', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_152103', 'metric': 0.8875}]

Benchmarking and Profiling

With the fp8 KV cache, a single B200 GPU can run a batch size of 768 without any queued requests, up from 630 with the bfloat16 KV cache.
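
As a rough back-of-the-envelope check on why the fp8 cache helps (the layer/head numbers below are placeholder assumptions, not the actual GPT-OSS-120B config), the KV cache footprint per token scales linearly with the element size, so moving from bf16 (2 bytes) to fp8 (1 byte) halves it:

```python
# Placeholder model dimensions; substitute the real values from the model config.
num_layers, num_kv_heads, head_dim = 36, 8, 64

def kv_bytes_per_token(elem_bytes: int) -> int:
    # K and V tensors, per KV head, per layer.
    return 2 * num_kv_heads * head_dim * num_layers * elem_bytes

print(f"bf16: {kv_bytes_per_token(2) / 1024:.0f} KiB/token")
print(f"fp8:  {kv_bytes_per_token(1) / 1024:.0f} KiB/token")
```

The observed batch-size gain (768 vs. 630) is smaller than 2x, presumably because model weights and activations also occupy a large share of GPU memory.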

Hybrid attn backend with fp8 kv cache result

Burstiness factor: 1.0 (Poisson process)
Maximum request concurrency: 768
100%|█████████████████████████████████████████| 1536/1536 [02:50<00:00,  8.98it/s]
============ Serving Benchmark Result ============
Successful requests:                     1536
Maximum request concurrency:             768
Request rate configured (RPS):           15.00
Benchmark duration (s):                  170.98
Total input tokens:                      5373013
Total generated tokens:                  2304000
Request throughput (req/s):              8.98
Request goodput (req/s):                 4.84
Output token throughput (tok/s):         13475.56
Total Token throughput (tok/s):          44901.07
---------------Time to First Token----------------
Mean TTFT (ms):                          837.91
Median TTFT (ms):                        368.59
P50 TTFT (ms):                           368.59
P90 TTFT (ms):                           1533.57
P95 TTFT (ms):                           5856.49
P99 TTFT (ms):                           6346.45
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          46.27
Median TPOT (ms):                        49.28
P50 TPOT (ms):                           49.28
P90 TPOT (ms):                           52.58
P95 TPOT (ms):                           52.63
P99 TPOT (ms):                           52.69
---------------Inter-token Latency----------------
Mean ITL (ms):                           46.27
Median ITL (ms):                         28.54
P50 ITL (ms):                            28.54
P90 ITL (ms):                            48.34
P95 ITL (ms):                            111.59
P99 ITL (ms):                            368.92
==================================================

Checklist

@gemini-code-assist (Bot, Contributor) left a comment

Summary of Changes

Hello @rainj-me, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the system's ability to handle larger batch sizes and improve GPU utilization by integrating FP8 KV cache support. It achieves this through a strategic shift to a hybrid attention mechanism, addressing current hardware and kernel limitations to unlock greater throughput and efficiency in large language model inference.

Highlights

  • FP8 KV Cache Support: This pull request introduces support for FP8 KV cache by implementing a hybrid attention backend. This is crucial for optimizing performance on B200/GB200 GPUs where KV cache volume can bottleneck batch size, especially since the trtllm-mha CUDA kernel does not support Q(bf16), KV(fp8), O(bf16) during the prefill phase.
  • Hybrid Attention Backend: To circumvent the trtllm-mha limitation, the solution employs a hybrid attention backend: Triton is used for the prefill phase, and trtllm-mha is used for the decode phase. This allows for the efficient utilization of FP8 KV cache.
  • Conditional Fused KV Buffer Enabling: The _enable_fused_set_kv_buffer function has been updated to conditionally enable fused KV buffer operations. It now only activates when the device is CUDA and the KV cache pool data type is bfloat16, ensuring compatibility and optimal performance based on the data type.
  • Performance Improvements: Benchmarking results demonstrate significant performance gains. With FP8 KV cache, a single B200 GPU can handle a batch size of 768 without queuing requests, a substantial improvement over the 630 batch size achievable with bfloat16 KV cache.

@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request adds support for fp8 KV cache in the hybrid attention backend by conditionally disabling the fused set_kv_buffer kernel. The logic is updated in _enable_fused_set_kv_buffer to check if the KV cache dtype is bfloat16. The changes are clear, correct, and well-contained. My feedback includes a minor suggestion to improve code documentation for better maintainability.

Comment thread on python/sglang/srt/models/gpt_oss.py (outdated)
rainj-me and others added 2 commits on August 28, 2025
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@Swipe4057 (Contributor)

will this work on H100?

@rainj-me (Collaborator, Author) commented Aug 29, 2025

will this work on H100?

With Hopper, it may only work with the triton (and possibly fa3) backend, since trtllm-mha only works on sm100.

@rainj-me changed the title from "support fp8 kvcache for hybrid attn backend" to "support fp8 kvcache for hybrid attn backend and trtllm_mha backend only" on Aug 29, 2025
@Fridge003 (Collaborator) commented Aug 30, 2025

@rainj-me Can you please post the accuracy result of gpqa dataset?

A sample command could be

OPENAI_API_KEY=dummy python -m gpt_oss.evals --base-url http://localhost:40010/v1 --model dummy --reasoning-effort medium --eval gpqa --n-threads 1000

@rainj-me (Collaborator, Author) commented Aug 30, 2025

@rainj-me Can you please post the accuracy result of gpqa dataset?

A sample command could be

OPENAI_API_KEY=dummy python -m gpt_oss.evals --base-url http://localhost:40010/v1 --model dummy --reasoning-effort medium --eval gpqa --n-threads 1000

Model: GPT-OSS 120B
Test script:

OPENAI_API_KEY=dummy python -m gpt_oss.evals --base-url http://localhost:28000/v1 --model /data01/models/gpt-oss-120b --reasoning-effort medium --eval gpqa --n-threads 1000
  • Baseline: bf16 KV Cache with trtllm-mha attn backend
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-medium_temp1.0_20250830_024839', 'metric': 0.711489898989899}]
  • Fp8 KV Cache with hybrid attn backend (prefill: triton, decode: trtllm-mha)
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-medium_temp1.0_20250830_025437', 'metric': 0.6559343434343434}]
  • Fp8 KV Cache with trtllm-mha attn backend
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-medium_temp1.0_20250830_023139', 'metric': 0.398989898989899}]

@Qiaolin-Yu (Collaborator)

  • Fp8 KV Cache with trtllm-mha attn backend
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-medium_temp1.0_20250830_025437', 'metric': 0.6559343434343434}]
  • Fp8 KV Cache with trtllm-mha attn backend
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-medium_temp1.0_20250830_023139', 'metric': 0.398989898989899}]

What's the difference between the last two cases? @rainj-me

@rainj-me (Collaborator, Author)

  • Fp8 KV Cache with trtllm-mha attn backend
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-medium_temp1.0_20250830_025437', 'metric': 0.6559343434343434}]
  • Fp8 KV Cache with trtllm-mha attn backend
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-medium_temp1.0_20250830_023139', 'metric': 0.398989898989899}]

What's the difference between the last two cases? @rainj-me

Sorry, I forgot to update the title of the second one with the 0.6559 score, which is the hybrid attn backend: prefill (triton attn backend), decode (trtllm-mha attn backend). Let me update the title.

@Fridge003 (Collaborator)

@rainj-me Seems accuracy on fp8 kv cache with trtllm-mha backend will drop a lot. Probably due to the modification of this PR?

@rainj-me (Collaborator, Author) commented Aug 31, 2025

@rainj-me Seems accuracy on fp8 kv cache with trtllm-mha backend will drop a lot. Probably due to the modification of this PR?

Yes, it's because the trtllm-mha context ops only provide QKV(fp8)/O(bf16) and QKVO(bf16) variants in the prefill phase, so I had to quantize the query to FP8 for trtllm-mha, which caused the significant accuracy drop. I'm asking the flashinfer community for a Q(bf16), KV(FP8), O(FP16) context op. How about I revert the query quantization and keep the hybrid attn backend change?
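
For intuition on why quantizing the query hurts (a toy illustration, not the kernel's actual quantization scheme), a bf16 query round-tripped through fp8 e4m3 with a per-tensor scale loses mantissa precision, and that error then propagates into the attention scores:

```python
import torch  # requires a PyTorch version with torch.float8_e4m3fn (>= 2.1)

# Toy shapes: [tokens, heads, head_dim]
q = torch.randn(8, 64, 64, dtype=torch.bfloat16)

# Per-tensor scale so the largest magnitude maps to the fp8 e4m3 max value.
scale = q.abs().amax().float() / torch.finfo(torch.float8_e4m3fn).max
q_fp8 = (q.float() / scale).to(torch.float8_e4m3fn)  # lossy cast to fp8
q_deq = q_fp8.float() * scale                        # dequantize back

err = (q.float() - q_deq).abs().max()
print(f"max abs round-trip error: {err.item():.4f}")
```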

@rainj-me (Collaborator, Author)

@rainj-me Seems accuracy on fp8 kv cache with trtllm-mha backend will drop a lot. Probably due to the modification of this PR?

I reverted the trtllm-mha fp8 kv cache commit, but kept the hybrid attn backend change. Let's track the trtllm-mha attn backend in the issue

#9782

@Fridge003 (Collaborator)

@rainj-me Seems accuracy on fp8 kv cache with trtllm-mha backend will drop a lot. Probably due to the modification of this PR?

Yes, it due to the trtllm-mha context ops only have QKV(fp8)O(bf16) and QKVO(bf16) in prefill phase, so I have to quantize the Query to FP8 with trtllm-mha which lead to accuracy significantly dropped. I'm trying to ask a Q(bf16) KV(FP8) O(FP16) context ops in flashinfer community. How about I revert the query quantization and keep the hybrid attn backend one ?

Sure, you can support fp8 for the hybrid backend in this PR and leave trtllm-mha for future work.

@hlu1 (Collaborator) commented Aug 31, 2025

GPQA dropping from 0.71 to 0.66 for medium reasoning is significant. Could you run with reasoning high for both GPQA and AIME25? High could be more sensitive.

@rainj-me (Collaborator, Author) commented Sep 1, 2025

GPQA dropping from 0.71 to 0.66 for medium reasoning is significant. Could you run with reasoning high for both GPQA and AIME25? High could be more sensitive.

Sure, I will run the tests with

  1. Triton only with BF16 kv cache
  2. Trtllm-mha only with BF16 cache
  3. Hybrid attn backend with FP8 kv cache

And the evals include

  1. GPQA
  2. AIME 25

@rainj-me (Collaborator, Author) commented Sep 1, 2025

GPQA dropping from 0.71 to 0.66 for medium reasoning is significant. Could you run with reasoning high for both GPQA and AIME25? High could be more sensitive.

For GPQA high reasoning, the results are

  • Triton attn backend only with BF16 kv cache
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_030309', 'metric': 0.7803030303030303}]
  • Trtllm-mha attn backend only with BF16 kv cache
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_045631', 'metric': 0.8011363636363636}]
  • Hybrid attn backend (prefill: triton, decode trtllm-mha) with FP8 kv cache
[{'eval_name': 'gpqa', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_062255', 'metric': 0.7897727272727273}]

@Fridge003 (Collaborator) left a comment

LGTM

@Fridge003 changed the title from "support fp8 kvcache for hybrid attn backend and trtllm_mha backend only" to "support fp8 kvcache for hybrid attn backend on GPT-OSS" on Sep 1, 2025
@rainj-me (Collaborator, Author) commented Sep 1, 2025

GPQA dropping from 0.71 to 0.66 for medium reasoning is significant. Could you run with reasoning high for both GPQA and AIME25? High could be more sensitive.

Sure, I will run the tests with

  1. Triton only with BF16 kv cache
  2. Trtllm-mha only with BF16 cache
  3. Hybrid attn backend with FP8 kv cache

And the evals include

  1. GPQA
  2. AIME 25

For AIME25 high reasoning, the results are

  • Triton attn backend only with BF16 kv cache
[{'eval_name': 'aime25', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_173633', 'metric': 0.9125}]
  • Trtllm-mha attn backend only with BF16 kv cache
[{'eval_name': 'aime25', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_165807', 'metric': 0.9041666666666667}]
  • Hybrid attn backend (prefill: triton, decode trtllm-mha) with FP8 kv cache
[{'eval_name': 'aime25', 'model_name': '__data01__models__gpt-oss-120b-high_temp1.0_20250901_152103', 'metric': 0.8875}]

@rainj-me rainj-me enabled auto-merge (squash) September 1, 2025 18:46
@rainj-me rainj-me merged commit 9db8025 into sgl-project:main Sep 1, 2025
62 of 71 checks passed
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
