
Support kv8 (FP8) with torch_native attention backend #12596

Merged
Fridge003 merged 2 commits into sgl-project:main from bytedance-iaas:horenc/torch_native_kv8_support_on_main_release on Dec 28, 2025

Conversation

@JackChuang
Contributor

@JackChuang JackChuang commented Nov 4, 2025

This patch fixes the issue where KV8 could not run when the attention backend was set to torch_native.

Motivation

Currently, when using --attention-backend torch_native, the --kv-cache-dtype fp8_e4m3 option is not supported, causing the FP8 KV cache to fail at runtime. This patch fixes the issue by ensuring that the query, key, and value tensors are cast to the same dtype before calling scaled_dot_product_attention.

Modifications

  • Modified TorchNativeAttnBackend in torch_native_backend.py
  • Added dtype casting for per_req_key and per_req_value to match per_req_query
  • Ensures scaled_dot_product_attention works correctly with FP8 KV cache

Accuracy Tests

Tested in another PR #12612

Benchmarking and Profiling

Tested in another PR #12612



@Fridge003
Collaborator

@JackChuang Please update this doc
https://github.com/sgl-project/sglang/blob/main/docs/advanced_features/attention_backend.md?plain=1#L22

@JackChuang JackChuang force-pushed the horenc/torch_native_kv8_support_on_main_release branch from 773c7a0 to 64ff639 on November 14, 2025 04:25
@github-actions github-actions Bot added the documentation Improvements or additions to documentation label Nov 14, 2025
@JackChuang
Contributor Author

@Fridge003 Thanks for your review and approval. Could someone help merge this PR? Thanks~

@Fridge003
Collaborator

@JackChuang Please fix conflict

@JackChuang JackChuang force-pushed the horenc/torch_native_kv8_support_on_main_release branch from 64ff639 to 5faf913 on November 14, 2025 22:39
@JackChuang
Contributor Author

@Fridge003 Could you please help merge this PR when you have free cycles? Thank you.

This patch fixes the issue where KV8 could not run when
the attention backend was set to torch_native.

Updates the attention backend support document.

Signed-off-by: Ho-Ren (Jack) Chuang <horenchuang@bytedance.com>
@JackChuang JackChuang force-pushed the horenc/torch_native_kv8_support_on_main_release branch from 5faf913 to 8ee8295 on December 12, 2025 12:29
@Fridge003
Collaborator

@JackChuang Do you have any example of accuracy benchmarking when enabling fp8 kv cache with the torch native backend?

@JackChuang
Contributor Author

> @JackChuang Do you have any example of accuracy benchmarking when enabling fp8 kv cache with the torch native backend?

I tested performance but not accuracy. I'll run the accuracy tests and update here.

@JackChuang
Contributor Author

JackChuang commented Dec 16, 2025

@Fridge003 Using torch_native with KV8, the precision is essentially lossless.

[KV16]
Accuracy: 0.947
Invalid: 0.000
Latency: 2783.572 s
Output throughput: 73.740 token/s

$ CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m sglang.launch_server --model-path /data02/models/Qwen3-235B-A22B --tp 4 --trust-remote-code --port 8042 --attention-backend torch_native --disable-radix-cache --enable-torch-compile

$ python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319 --port 8042

[KV8]
Accuracy: 0.949
Invalid: 0.001
Latency: 2984.291 s
Output throughput: 68.772 token/s

$ CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m sglang.launch_server --model-path /data02/models/Qwen3-235B-A22B --tp 4 --trust-remote-code --port 8041 --kv-cache-dtype fp8_e4m3 --disable-radix-cache --enable-torch-compile --attention-backend torch_native

$ python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319 --port 8041

Experiments were run on B200 GPUs.

@Fridge003
Collaborator

@JackChuang Please merge the main branch.

@Fridge003 Fridge003 merged commit 349ce2d into sgl-project:main Dec 28, 2025
256 of 272 checks passed
YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026

Labels

documentation Improvements or additions to documentation run-ci
