
[DeepSeek-V3.2] Fix TRT-LLM NSA in target_verify/draft_extend #17662

Merged
Fridge003 merged 3 commits into sgl-project:main from mmangkad-dev:fix-nsa-trtllm-target-verify on Jan 25, 2026

Conversation

mmangkad (Contributor) commented on Jan 23, 2026

Motivation

Quick follow‑up to #16758: speculative decoding still fails with --nsa-decode-backend trtllm.
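
For context, a repro has roughly the following shape (illustrative only: apart from --nsa-decode-backend trtllm, the model path and the other flags here are assumptions, not taken from this PR):

python3 -m sglang.launch_server \
  --model-path DeepSeek-V3.2 \
  --tp 8 \
  --speculative-algorithm EAGLE \
  --nsa-decode-backend trtllm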

Modifications

  • Add TRT‑LLM handling in the extend path for speculative modes, guarded by a forward‑mode assertion (target_verify/draft_extend only).
  • Require an explicit seq_lens argument for _forward_trtllm, passing the expanded nsa_cache_seqlens_int32 in speculative paths and cache_seqlens_int32 in normal decode (see the sketch below).
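
A rough illustration of the shape of this change (not the actual diff): every identifier other than those named above (forward_extend, _forward_trtllm, nsa_cache_seqlens_int32, cache_seqlens_int32, and the target_verify/draft_extend modes) is a placeholder.

import torch

class NSABackendSketch:
    def __init__(self, nsa_impl: str):
        self.nsa_impl = nsa_impl

    def forward_extend(self, q_nope, q_rope, metadata, forward_mode: str):
        if self.nsa_impl == "trtllm":
            # Guard: this extend path only supports the speculative modes.
            assert forward_mode in ("target_verify", "draft_extend"), (
                "TRT-LLM NSA extend only handles target_verify/draft_extend"
            )
            # Concatenate the nope/rope query halves before the kernel call.
            q_all = torch.cat([q_nope, q_rope], dim=-1)
            # Speculative paths pass the expanded per-token lengths.
            return self._forward_trtllm(
                q_all, metadata, seq_lens=metadata.nsa_cache_seqlens_int32
            )
        # Other NSA backends are omitted from this sketch.
        raise NotImplementedError(self.nsa_impl)

    def _forward_trtllm(self, q_all, metadata, seq_lens):
        # seq_lens is now required: normal decode passes
        # metadata.cache_seqlens_int32, while target_verify/draft_extend
        # pass the expanded metadata.nsa_cache_seqlens_int32.
        ...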

Accuracy Tests

gsm8k 20 shots

Accuracy: 0.952
Invalid: 0.000
Latency: 519.545 s
Output throughput: 247.709 token/s
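
(These fields match the output format of sglang's few-shot GSM8K test harness; as a sketch, assuming that harness was used, the run would be launched against the serving endpoint roughly as follows, with the exact flag values being assumptions:)

python3 -m sglang.test.few_shot_gsm8k \
  --num-shots 20 \
  --num-questions 1319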

gpqa-diamond

Repeat: 8, mean: 0.836
Scores: ['0.814', '0.819', '0.861', '0.835', '0.845', '0.830', '0.840', '0.845']

aime 2025

---------------------------------------- aime25 ----------------------------------------
evaluation_mode  | num_entries | avg_tokens | gen_seconds | symbolic_correct | no_answer
pass@1[avg-of-4] | 30          | 15002      | 973         | 92.11% ± 1.66%   | 0.00%
majority@4       | 30          | 15002      | 973         | 94.31%           | 0.00%
pass@4           | 30          | 15002      | 973         | 96.26%           | 0.00%

Benchmarking and Profiling

trtllm bf16 kv-cache

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 24
Successful requests:                     120
Benchmark duration (s):                  130.18
Total input tokens:                      44392
Total input text tokens:                 44392
Total generated tokens:                  26533
Total generated tokens (retokenized):    26392
Request throughput (req/s):              0.92
Input token throughput (tok/s):          341.02
Output token throughput (tok/s):         203.83
Peak output token throughput (tok/s):    1646.00
Peak concurrent requests:                28
Total token throughput (tok/s):          544.84
Concurrency:                             22.82
Accept length:                           2.78
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   24754.96
Median E2E Latency (ms):                 16999.09
P90 E2E Latency (ms):                    57919.79
P99 E2E Latency (ms):                    104408.16
---------------Time to First Token----------------
Mean TTFT (ms):                          2275.53
Median TTFT (ms):                        1985.65
P99 TTFT (ms):                           5471.25
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          170.95
Median TPOT (ms):                        129.97
P99 TPOT (ms):                           1681.90
---------------Inter-Token Latency----------------
Mean ITL (ms):                           102.18
Median ITL (ms):                         10.32
P95 ITL (ms):                            602.41
P99 ITL (ms):                            1204.27
Max ITL (ms):                            4088.89
==================================================

flashmla_sparse bf16 kv-cache

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 24
Successful requests:                     120
Benchmark duration (s):                  134.20
Total input tokens:                      44392
Total input text tokens:                 44392
Total generated tokens:                  26533
Total generated tokens (retokenized):    26174
Request throughput (req/s):              0.89
Input token throughput (tok/s):          330.78
Output token throughput (tok/s):         197.71
Peak output token throughput (tok/s):    1709.00
Peak concurrent requests:                28
Total token throughput (tok/s):          528.49
Concurrency:                             22.57
Accept length:                           2.81
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   25236.27
Median E2E Latency (ms):                 16229.20
P90 E2E Latency (ms):                    59051.35
P99 E2E Latency (ms):                    114857.01
---------------Time to First Token----------------
Mean TTFT (ms):                          2413.67
Median TTFT (ms):                        2223.97
P99 TTFT (ms):                           5172.87
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          160.83
Median TPOT (ms):                        127.63
P99 TPOT (ms):                           1389.06
---------------Inter-Token Latency----------------
Mean ITL (ms):                           103.81
Median ITL (ms):                         10.64
P95 ITL (ms):                            589.90
P99 ITL (ms):                            1093.23
Max ITL (ms):                            3179.75
==================================================

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

Copilot AI review requested due to automatic review settings on January 23, 2026 at 20:42
gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @mmangkad, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug in the speculative decoding mechanism when utilizing the TRT-LLM backend. The changes ensure that the extend operation correctly handles sequence lengths for per-token decoding by introducing specific trtllm logic and allowing flexible sequence length overrides in the underlying attention forward pass. This fix is essential for the stability and functionality of speculative decoding with TRT-LLM.

Highlights

  • TRT-LLM Speculative Decoding Fix: Resolves an issue where speculative decoding failed when using the trtllm backend, specifically in the target_verify/draft_extend path, ensuring correct functionality.
  • TRT-LLM Handling in forward_extend: Introduced a dedicated code path within the forward_extend function to correctly handle trtllm implementations, including proper q_all concatenation and _forward_trtllm invocation.
  • Flexible Sequence Lengths for _forward_trtllm: Modified the _forward_trtllm method to accept an optional seq_lens parameter, allowing callers to override the default sequence length behavior for specific use cases.
  • Expanded Sequence Lengths for Speculative Paths: Utilizes metadata.nsa_cache_seqlens_int32 as the seq_lens when calling _forward_trtllm in the trtllm speculative decoding path, which is crucial for accurate per-token decode operations.


Copilot AI (Contributor) left a comment:

Pull request overview

This PR wires the TRT-LLM NSA decode path into the forward_extend (speculative/target_verify) flow and adjusts the TRT-LLM helper to accept per-token sequence lengths appropriate for speculative decoding. It aligns the NSA+TRT-LLM behavior with other NSA backends and fixes failures when using --nsa-decode-backend trtllm with speculative decoding.

Changes:

  • Add an nsa_impl == "trtllm" branch in forward_extend to route speculative extend/target_verify through _forward_trtllm.
  • Allow _forward_trtllm to accept an optional seq_lens tensor instead of always using metadata.cache_seqlens_int32.
  • For speculative target_verify / draft_extend paths, pass metadata.nsa_cache_seqlens_int32 into _forward_trtllm so the TRT-LLM kernel sees per-token NSA-adjusted sequence lengths.


gemini-code-assist (Bot) left a comment:

Code Review

The pull request integrates TRT-LLM handling into the Native Sparse Attention backend for speculative decoding. It adds a new elif branch in forward_extend to handle trtllm implementations and modifies _forward_trtllm to accept an optional seq_lens parameter, defaulting to metadata.cache_seqlens_int32 if not provided. This ensures correct sequence length handling for per-token decode in target_verify and draft_extend paths.

(Review comment thread on python/sglang/srt/layers/attention/nsa_backend.py)
mmangkad (Contributor, Author) commented:

cc @Fridge003 PTAL

mmangkad changed the title from "Fix TRT-LLM NSA in target_verify/draft_extend" to "[DeepSeek-V3.2] Fix TRT-LLM NSA in target_verify/draft_extend" on Jan 24, 2026
Fridge003 (Collaborator) commented:

Thanks, I will take a look later

Fridge003 (Collaborator) commented:

@mmangkad Can shifting the decode kernel to the trtllm implementation gain any acceleration? Maybe you can post some benchmark/profile results.

Fridge003 (Collaborator) commented:

Also, please verify the aime and gpqa scores following the instructions here:
https://docs.sglang.io/basic_usage/deepseek_v32.html#accuracy-test-with-gpqa-diamond

mmangkad (Contributor, Author) commented:

> @mmangkad Can shifting the decode kernel to the trtllm implementation gain any acceleration? Maybe you can post some benchmark/profile results.

I saw that the original PR (#16758) established the acceleration, showing the trtllm kernel (~14 µs) is faster than flashmla (~26.5 µs), while overall serving throughput was similar. This PR simply unblocks the extend phase so speculative decoding can run without crashing.

Fridge003 (Collaborator) commented:

> @mmangkad Can shifting the decode kernel to the trtllm implementation gain any acceleration? Maybe you can post some benchmark/profile results.
>
> I saw that the original PR (#16758) established the acceleration, showing the trtllm kernel (~14 µs) is faster than flashmla (~26.5 µs), while overall serving throughput was similar. This PR simply unblocks the extend phase so speculative decoding can run without crashing.

Yes, I mean: will changing the decode kernel from flashmla to trtllm show any improvement when MTP is enabled?

mmangkad (Contributor, Author) commented:

> @mmangkad Can shifting the decode kernel to the trtllm implementation gain any acceleration? Maybe you can post some benchmark/profile results.
>
> I saw that the original PR (#16758) established the acceleration, showing the trtllm kernel (~14 µs) is faster than flashmla (~26.5 µs), while overall serving throughput was similar. This PR simply unblocks the extend phase so speculative decoding can run without crashing.
>
> Yes, I mean: will changing the decode kernel from flashmla to trtllm show any improvement when MTP is enabled?

I have updated the PR with full accuracy tests and benchmarks, which show a ~3-5% improvement over FlashMLA with MTP enabled.

Fridge003 (Collaborator) commented:

@mmangkad Thanks, what's your command for benchmarking?

(Review comment thread on python/sglang/srt/layers/attention/nsa_backend.py, marked outdated)
(Review comment thread on python/sglang/srt/layers/attention/nsa_backend.py)
mmangkad (Contributor, Author) commented:

> @mmangkad Thanks, what's your command for benchmarking?

I used:

python3 -m sglang.bench_serving \
  --backend sglang \
  --model DeepSeek-V3.2 \
  --num-prompts 120 \
  --random-input-len 2048 \
  --random-output-len 2048 \
  --random-range-ratio 0.5 \
  --max-concurrency 24 \
  --warmup-requests 5
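
For what it's worth, the --num-prompts 120 and --max-concurrency 24 flags line up with the "Successful requests: 120" and "Max request concurrency: 24" fields in both benchmark tables above.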

Fridge003 merged commit 1674b9e into sgl-project:main on Jan 25, 2026
62 of 69 checks passed
mmangkad deleted the fix-nsa-trtllm-target-verify branch on January 25, 2026 at 05:23
