
[RL] [DSv32] [GLM-5] Add --nsa-topk-backend and integrate FlashInfer and pytorch topk#22851

Open
zianglih wants to merge 8 commits into sgl-project:main from zianglih:torch-topk

Conversation

@zianglih
Contributor

@zianglih zianglih commented Apr 15, 2026

Motivation


Add --nsa-topk-backend for configurable topk backend implementation selection.

torch.topk is used by GLM-5 for RL.
FlashInfer topk offers determinism and a configurable tie break (flashinfer-ai/flashinfer#3095), as well as better long-context performance.
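As background on why a configurable tie break matters: when several scores are exactly equal, different top-k implementations may return different index sets, which breaks run-to-run reproducibility. A minimal pure-Python sketch of the two common tie-break policies (illustrative only, not the FlashInfer, sgl-kernel, or torch CUDA implementations):

```python
def topk_with_tie_break(scores, k, prefer_lower_index=True):
    """Top-k indices with an explicit tie-break policy.

    Pure-Python sketch for illustration only; the real backends
    (sgl-kernel, flashinfer, torch.topk) are CUDA kernels.
    """
    if prefer_lower_index:
        key = lambda i: (scores[i], -i)  # on ties, the smaller index wins
    else:
        key = lambda i: (scores[i], i)   # on ties, the larger index wins
    return sorted(range(len(scores)), key=key, reverse=True)[:k]

scores = [0.9, 0.5, 0.9, 0.1]  # indices 0 and 2 tie
print(topk_with_tie_break(scores, 2, prefer_lower_index=True))   # [0, 2]
print(topk_with_tie_break(scores, 2, prefer_lower_index=False))  # [2, 0]
```

A deterministic, documented tie break makes token selection reproducible across runs, which is what matters for RL training.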

Modifications

  • Add --nsa-topk-backend, defaulting to the existing sgl-kernel backend
  • Integrate flashinfer and torch topk for the unfused code path
  • Integrate flashinfer topk for the fused code path
  • Add SGLANG_NSA_TOPK_FLASHINFER_TIE_BREAK and SGLANG_NSA_TOPK_FLASHINFER_DETERMINISTIC
  • Add a new unit test
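For illustration, backend selection behind a flag like --nsa-topk-backend can be sketched as a simple registry. The names below are hypothetical and the implementations are placeholders; they do not mirror the actual SGLang code paths:

```python
from typing import Callable, Dict, List

def _reference_topk(scores: List[float], k: int) -> List[int]:
    # Reference implementation standing in for all backends in this sketch;
    # the real backends call the sgl-kernel CUDA kernel, flashinfer, or torch.topk.
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

# Hypothetical registry; the keys mirror the flag values, the values do not.
TOPK_BACKENDS: Dict[str, Callable[[List[float], int], List[int]]] = {
    "sgl-kernel": _reference_topk,
    "flashinfer": _reference_topk,
    "torch": _reference_topk,
}

def select_topk_backend(name: str = "sgl-kernel") -> Callable[[List[float], int], List[int]]:
    if name not in TOPK_BACKENDS:
        raise ValueError(f"unknown --nsa-topk-backend: {name!r}")
    return TOPK_BACKENDS[name]

topk = select_topk_backend("torch")
print(topk([0.2, 0.8, 0.5], 2))  # [1, 2]
```

A registry keyed by the flag value keeps the dispatch in one place and makes an invalid backend name fail fast at startup rather than deep in the forward pass.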

Accuracy Tests

The new unit test passed: `python3 -m pytest -q test/registered/kernels/test_nsa_indexer.py -k test_topk_unfused_backends_valid_selection`.

```
# sgl-kernel (default)
SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD=0 python3 -m sglang.launch_server --nsa-topk-backend sgl-kernel --kv-cache-dtype bf16 --model /data/models/ziangli_v32/DeepSeek-V3.2 --tp 8 --dp 8 --enable-dp-attention --moe-runner-backend flashinfer_trtllm_routed --attention-backend nsa --nsa-decode-backend flashmla_sparse --nsa-prefill-backend flashmla_sparse --page-size 64 --trust-remote-code
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.977
Invalid: 0.000
Latency: 13.146 s
Output throughput: 8566.746 token/s
Accuracy: 0.978
Invalid: 0.000
Latency: 12.749 s
Output throughput: 8878.813 token/s
Accuracy: 0.981
Invalid: 0.000
Latency: 17.272 s
Output throughput: 6584.294 token/s

# torch unfused
SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD=0 SGLANG_NSA_FUSE_TOPK=0 python3 -m sglang.launch_server --nsa-topk-backend torch --kv-cache-dtype bf16 --model /data/models/ziangli_v32/DeepSeek-V3.2 --tp 8 --dp 8 --enable-dp-attention --moe-runner-backend flashinfer_trtllm_routed --attention-backend nsa --nsa-decode-backend flashmla_sparse --nsa-prefill-backend flashmla_sparse --page-size 64 --trust-remote-code
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.982
Invalid: 0.000
Latency: 18.256 s
Output throughput: 6183.790 token/s
Accuracy: 0.983
Invalid: 0.000
Latency: 17.637 s
Output throughput: 6388.987 token/s
Accuracy: 0.980
Invalid: 0.000
Latency: 17.609 s
Output throughput: 6403.039 token/s

# flashinfer unfused
SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD=0 SGLANG_NSA_FUSE_TOPK=0 python3 -m sglang.launch_server --nsa-topk-backend flashinfer --kv-cache-dtype bf16 --model /data/models/ziangli_v32/DeepSeek-V3.2 --tp 8 --dp 8 --enable-dp-attention --moe-runner-backend flashinfer_trtllm_routed --attention-backend nsa --nsa-decode-backend flashmla_sparse --nsa-prefill-backend flashmla_sparse --page-size 64 --trust-remote-code
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.978
Invalid: 0.000
Latency: 20.846 s
Output throughput: 5413.876 token/s
Accuracy: 0.978
Invalid: 0.000
Latency: 24.896 s
Output throughput: 4557.003 token/s
Accuracy: 0.979
Invalid: 0.000
Latency: 21.313 s
Output throughput: 5292.839 token/s

# flashinfer fused
SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD=0 SGLANG_NSA_FUSE_TOPK=1 python3 -m sglang.launch_server --nsa-topk-backend flashinfer --kv-cache-dtype bf16 --model /data/models/ziangli_v32/DeepSeek-V3.2 --tp 8 --dp 8 --enable-dp-attention --moe-runner-backend flashinfer_trtllm_routed --attention-backend nsa --nsa-decode-backend flashmla_sparse --nsa-prefill-backend flashmla_sparse --page-size 64 --trust-remote-code
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.980
Invalid: 0.000
Latency: 13.531 s
Output throughput: 8320.213 token/s
Accuracy: 0.981
Invalid: 0.000
Latency: 12.771 s
Output throughput: 8832.274 token/s
Accuracy: 0.978
Invalid: 0.000
Latency: 12.121 s
Output throughput: 9267.255 token/s

# flashinfer fused with tie_break=1
SGLANG_NSA_TOPK_FLASHINFER_TIE_BREAK=1 SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD=0 SGLANG_NSA_FUSE_TOPK=1 python3 -m sglang.launch_server --nsa-topk-backend flashinfer --kv-cache-dtype bf16 --model /data/models/ziangli_v32/DeepSeek-V3.2 --tp 8 --dp 8 --enable-dp-attention --moe-runner-backend flashinfer_trtllm_routed --attention-backend nsa --nsa-decode-backend flashmla_sparse --nsa-prefill-backend flashmla_sparse --page-size 64 --trust-remote-code
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.978
Invalid: 0.000
Latency: 13.716 s
Output throughput: 8219.616 token/s
Accuracy: 0.980
Invalid: 0.000
Latency: 13.008 s
Output throughput: 8652.700 token/s
Accuracy: 0.978
Invalid: 0.000
Latency: 17.669 s
Output throughput: 6457.714 token/s

# flashinfer fused with tie_break=2
SGLANG_NSA_TOPK_FLASHINFER_TIE_BREAK=2 SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD=0 SGLANG_NSA_FUSE_TOPK=1 python3 -m sglang.launch_server --nsa-topk-backend flashinfer --kv-cache-dtype bf16 --model /data/models/ziangli_v32/DeepSeek-V3.2 --tp 8 --dp 8 --enable-dp-attention --moe-runner-backend flashinfer_trtllm_routed --attention-backend nsa --nsa-decode-backend flashmla_sparse --nsa-prefill-backend flashmla_sparse --page-size 64 --trust-remote-code
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.979
Invalid: 0.000
Latency: 13.370 s
Output throughput: 8438.633 token/s
Accuracy: 0.982
Invalid: 0.000
Latency: 13.129 s
Output throughput: 8628.890 token/s
Accuracy: 0.980
Invalid: 0.000
Latency: 12.498 s
Output throughput: 9047.713 token/s
```

Speed Tests and Profiling

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.


@github-actions github-actions Bot added the documentation Improvements or additions to documentation label Apr 15, 2026
@ziang-and ziang-and requested a review from wisclmy0611 as a code owner April 21, 2026 02:54
@zianglih zianglih changed the title [RL] [V3.2] [GLM-5] Add SGLANG_NSA_TORCH_TOPK [RL] [DSv32] [GLM-5] Add --nsa-topk-backend and integrate FlashInfer and pytorch topk Apr 21, 2026
@nvpohanh
Collaborator

cc @nvjullin

@DarkSharpness
Collaborator

qq: Does the flashinfer kernel support CUDA graph? I know flashinfer may dispatch to different algorithms based on the static sequence length, but is that safe under CUDA graph?

@zianglih
Contributor Author

Hi @DarkSharpness, thank you for calling this out. This is indeed a valid concern. FlashInfer's current dispatch heuristics use max_len, which is not CUDA-graph-safe in the current implementation. We are also working with the CCCL team on a graph-safe topk (flashinfer-ai/flashinfer#3091, etc.), which will be integrated into FlashInfer soon. For now, this PR can disallow CUDA graph when the flashinfer topk backend is used.
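The stopgap described here (rejecting CUDA graph when the flashinfer topk backend is selected) could be a simple startup check. This is a hypothetical sketch with illustrative argument names, not the actual SGLang validation code:

```python
def validate_cuda_graph_compat(nsa_topk_backend: str, enable_cuda_graph: bool) -> None:
    """Hypothetical guard: flashinfer's topk dispatch depends on max_len,
    which is not CUDA-graph-safe until a graph-safe mode lands
    (flashinfer-ai/flashinfer#3133)."""
    if enable_cuda_graph and nsa_topk_backend == "flashinfer":
        raise ValueError(
            "--nsa-topk-backend flashinfer is not CUDA-graph-safe yet; "
            "disable CUDA graph or use the sgl-kernel/torch topk backend."
        )

validate_cuda_graph_compat("flashinfer", enable_cuda_graph=False)  # OK
validate_cuda_graph_compat("torch", enable_cuda_graph=True)        # OK
```

Failing at server startup surfaces the incompatibility immediately, instead of capturing a graph whose kernel choice silently depends on the sequence lengths seen during capture.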

@zianglih
Contributor Author

zianglih commented Apr 21, 2026

Hold until flashinfer-ai/flashinfer#3133, which introduces a graph-safe mode.

kahyunnam pushed a commit to flashinfer-ai/flashinfer that referenced this pull request Apr 24, 2026

## 📌 Description
Parent PR: #3095
SGLang PR: sgl-project/sglang#22851

Add `row_starts` and `dsa_graph_safe` for SGLang DSA integration.

## 🔍 Related Issues
sgl-project/sglang#22851 (comment)


## 🚀 Pull Request Checklist


### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.


## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
* Added dsa_graph_safe flag to top-k APIs to opt into DSA-graph safe
execution.
* Added optional row_starts parameter to page-table and ragged top-k
transforms to support per-row score offsets.

* **Behavior**
* When dsa_graph_safe=True the optimized clusters fast-path is disabled
to ensure safe execution.

* **Tests**
* Added tests covering row_starts behavior for page-table and ragged
transforms.
@zianglih
Contributor Author

Hold until flashinfer v0.6.10 release.

