[DeepSeek-V3.2] Fix TRT-LLM NSA in target_verify/draft_extend#17662
[DeepSeek-V3.2] Fix TRT-LLM NSA in target_verify/draft_extend#17662Fridge003 merged 3 commits intosgl-project:mainfrom
Conversation
Summary of ChangesHello @mmangkad, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a critical bug in the speculative decoding mechanism when utilizing the TRT-LLM backend. The changes ensure that the Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Pull request overview
This PR wires the TRT-LLM NSA decode path into the forward_extend (speculative/target_verify) flow and adjusts the TRT-LLM helper to accept per-token sequence lengths appropriate for speculative decoding. It aligns the NSA+TRT-LLM behavior with other NSA backends and fixes failures when using --nsa-decode-backend trtllm with speculative decoding.
Changes:
- Add an
nsa_impl == "trtllm"branch inforward_extendto route speculative extend/target_verify through_forward_trtllm. - Allow
_forward_trtllmto accept an optionalseq_lenstensor instead of always usingmetadata.cache_seqlens_int32. - For speculative
target_verify/draft_extendpaths, passmetadata.nsa_cache_seqlens_int32into_forward_trtllmso the TRT-LLM kernel sees per-token NSA-adjusted sequence lengths.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Code Review
The pull request integrates TRT-LLM handling into the Native Sparse Attention backend for speculative decoding. It adds a new elif branch in forward_extend to handle trtllm implementations and modifies _forward_trtllm to accept an optional seq_lens parameter, defaulting to metadata.cache_seqlens_int32 if not provided. This ensures correct sequence length handling for per-token decode in target_verify and draft_extend paths.
|
cc @Fridge003 PTAL |
|
Thanks, I will take a look later |
|
@mmangkad Can shifting decode kernel to trtllm implementation gain any acceleration? |
|
Also please verify the aime and gpqa scores as instruction here |
I saw that the original PR established the acceleration, showing the trtllm kernel (~14µs) is faster than flashmla (~26.5µs), while overall serving throughput was similar. This simply unblocks the extend phase so speculative decoding can run without crashing. |
Yes, I mean will changing decode kernel from flashmla to trtllm show any improvement when MTP is enabled |
I have updated the PR with full accuracy tests and benchmarks, which show a ~3-5% improvement over FlashMLA with MTP enabled. |
|
@mmangkad Thanks, what's your command for benchmarking |
I used: |
Motivation
Quick follow‑up to #16758: speculative decoding still fails with
--nsa-decode-backend trtllm.Modifications
target_verify/draft_extendonly).seq_lensfor_forward_trtllm, passing expandednsa_cache_seqlens_int32for speculative paths andcache_seqlens_int32in normal decode.Accuracy Tests
gsm8k 20 shots
gpqa-diamond
aime 2025
Benchmarking and Profiling
trtllm bf16 kv-cache
flashmla_sparse bf16 kv-cache
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci