
[NVIDIA] [GDN] Enable FlashInfer MTP verify on SM100+ (Blackwell)#23273

Merged: Fridge003 merged 6 commits into sgl-project:main from wenscarl:gdnmtp_decode on Jun 2, 2026

@wenscarl

@wenscarl wenscarl commented Apr 20, 2026

Collaborator

[GDN] Enable FlashInfer MTP verify on SM100+ (Blackwell)

co-authored by @YAMY1234 (main contributor)

Summary

Enables FlashInfer GDN MTP (speculative decoding) verify on SM100+ (Blackwell) hardware, which previously raised NotImplementedError. SM90 (Hopper) MTP was already supported; this PR completes SM100+ coverage.

Root cause: `target_verify` guarded on `use_state_pool`, blocking SM100+ even though the FlashInfer `gated_delta_rule_mtp` kernel already accepts `initial_state_indices` (the pool API), which is the same API used by the SM90 path.

Changes (2 files, ~15 lines):

  • `gdn_flashinfer.py`: remove the `use_state_pool` guard in `target_verify`; unify SM90 and SM100+ into a single pool-API path; add an `A_log.detach().float()` cast (matches the SM100+ decode path, no-op on SM90).
  • `server_args.py`: remove the `and self.speculative_algorithm is None` condition from the SM100+ FlashInfer auto-default, so FlashInfer is now the default on SM100+ regardless of whether MTP is enabled.
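The `server_args.py` change can be pictured with a minimal sketch. The function names, the `sm_major` parameter, and the `"triton"` fallback are hypothetical and chosen only for illustration; the real auto-default has other conditions (for example, the bf16 gate mentioned later in this thread) that are left out here.

```python
from typing import Optional


def default_backend_before(sm_major: int, speculative_algorithm: Optional[str]) -> str:
    """Hypothetical pre-PR rule: FlashInfer is auto-selected on SM100+ only when MTP is off."""
    if sm_major >= 10 and speculative_algorithm is None:
        return "flashinfer"
    return "triton"  # assumed fallback, for illustration only


def default_backend_after(sm_major: int, speculative_algorithm: Optional[str]) -> str:
    """Hypothetical post-PR rule: the speculative_algorithm condition is gone."""
    if sm_major >= 10:
        return "flashinfer"
    return "triton"  # assumed fallback, for illustration only


# With MTP enabled (NEXTN) on an SM100-class GPU, only the post-PR rule picks FlashInfer.
print(default_backend_before(10, "NEXTN"))
print(default_backend_after(10, "NEXTN"))
```

Under these assumptions the only input whose result changes is SM100+ with a speculative algorithm set, which is the case this PR targets.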

Accuracy (Qwen3.5-397B-A17B-NVFP4, B200)

GSM8K (200 examples, baseline threshold: 0.95)

SGLANG_ENABLE_SPEC_V2=1 python3 -m sglang.launch_server --model-path nvidia/Qwen3.5-397B-A17B-NVFP4 --tokenizer-path nvidia/Qwen3.5-397B-A17B-NVFP4 --trust-remote-code --host 0.0.0.0 --port 8000 --tp-size 4 --chunked-prefill-size 2048 --mamba-scheduler-strategy extra_buffer --mamba-track-interval 128 --mamba-ssm-dtype bfloat16 --max-running-requests 128 --reasoning-parser qwen3 --attention-backend trtllm_mha --quantization modelopt_fp4 --speculative-algorithm NEXTN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --mem-fraction-static 0.8 --model-loader-extra-config '{"enable_multithread_load": true,"num_threads": 64}'

python3 -m sglang.test.run_eval   --model nvidia/Qwen3.5-397B-A17B-NVFP4   --eval-name gsm8k   --num-shots 5   --num-examples 200   --max-tokens 16000   --num-threads 128   --repeat 1   --temperature 0.6   --top-p 0.95   --top-k 20   --base-url http://127.0.0.1:8000   --host http://127.0.0.1   --port 8000

| Backend | Score |
|---|---|
| Triton (decode + MTP) | 0.985 |
| FlashInfer (decode + MTP) | 0.980 |

GPQA Diamond (TODO: examples, repeat=8, temperature=0.6)

FlashInfer:
Total latency: 247.500 s
Score: 0.859
Output throughput: 6286.781 token/s

Triton:
Total latency: 253.352 s
Score: 0.854
Output throughput: 6196.159 token/s

| Backend | Score |
|---|---|
| Triton (decode + MTP) | 0.854 |
| FlashInfer (decode + MTP) | 0.859 |

Throughput Benchmark (GB200, Qwen3.5-397B-A17B-NVFP4, TP=4)

Focus: long output sequence length (OSL), where per-step GDN state-update cost is most significant.

Server settings:

  • --tp-size 4 --max-running-requests 128
  • --mamba-ssm-dtype bfloat16 --mamba-scheduler-strategy no_buffer --mamba-track-interval 128
  • --attention-backend trtllm_mha --linear-attn-decode-backend <triton|flashinfer>
  • --speculative-algorithm NEXTN (MTP runs)
  • --disable-radix-cache --quantization modelopt_fp4

Benchmark settings:

  • --dataset-name random --random-input-len 32 --random-output-len <512|1024|2048|4096>
  • --num-prompts <varied> --request-rate inf

Decode throughput (w/ MTP), output throughput (tok/s), ISL=32
Accept length (acc len): 3.13-3.22
num_prompts: 256

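| OSL | no-MTP | MTP | Speedup |
|---|---|---|---|
| 1024 | 2731.86 | 3682.65 | 1.35x |
| 2048 | 2937.99 | 4329.87 | 1.47x |
| 4096 | 2915.84 | 4831.15 | 1.66x |

| OSL | Triton | FlashInfer | Speedup |
|---|---|---|---|
| 1024 | 3645.60 | 3682.65 | 1.01x |
| 2048 | 4145.32 | 4329.87 | 1.04x |
| 4096 | 4707.04 | 4831.15 | 1.03x |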
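
As a quick check on the Speedup columns: throughput is higher-is-better, so the speedup is the new throughput divided by the baseline. The short sketch below recomputes the MTP vs. no-MTP column from the reported numbers; the helper name `throughput_speedup` is illustrative and not part of the PR.

```python
def throughput_speedup(new_tok_s: float, baseline_tok_s: float) -> float:
    """Ratio of output throughputs; values above 1.0 mean the new setup is faster."""
    return new_tok_s / baseline_tok_s


# (no-MTP tok/s, MTP tok/s) per output sequence length, copied from the results above.
mtp_rows = {
    1024: (2731.86, 3682.65),
    2048: (2937.99, 4329.87),
    4096: (2915.84, 4831.15),
}

for osl, (no_mtp, mtp) in mtp_rows.items():
    print(f"OSL={osl}: {throughput_speedup(mtp, no_mtp):.2f}x")
```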

Mean TPOT (ms/tok), ISL=32, OSL=512

| Concurrency | FlashInfer TPOT (ms) | Triton TPOT (ms) | Speedup (Triton / FlashInfer) | Winner |
|---|---|---|---|---|
| 1 | 3.23 | 3.34 | 1.034 | FlashInfer |
| 4 | 4.89 | 4.94 | 1.010 | FlashInfer |
| 16 | 10.21 | 10.29 | 1.008 | FlashInfer |
| 32 | 15.58 | 15.79 | 1.013 | FlashInfer |
| 64 | 23.29 | 24.64 | 1.058 | FlashInfer |
| 128 | 32.86 | 34.25 | 1.042 | FlashInfer |
| 256 | 31.14 | 31.62 | 1.015 | FlashInfer |
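
TPOT is lower-is-better, which is why the speedup here is Triton divided by FlashInfer: a ratio above 1.0 means FlashInfer needs less time per output token. A small sketch of that convention follows; the helper names are illustrative, and treating a tie as a Triton win is an arbitrary choice made only for this example.

```python
def tpot_speedup(triton_ms: float, flashinfer_ms: float) -> float:
    """Triton TPOT over FlashInfer TPOT; values above 1.0 favour FlashInfer."""
    return triton_ms / flashinfer_ms


def winner(triton_ms: float, flashinfer_ms: float) -> str:
    """Backend with the lower mean TPOT (a tie falls to Triton in this sketch)."""
    return "FlashInfer" if flashinfer_ms < triton_ms else "Triton"


# Concurrency 64 row from the TPOT results above.
print(f"{tpot_speedup(24.64, 23.29):.3f}", winner(24.64, 23.29))
```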

With flashinfer-ai/flashinfer#3147 and the following patch applied:

diff --git a/python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py b/python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py
index 55baf0d75..9e47ead09 100644
--- a/python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py
+++ b/python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py
@@ -335,6 +335,14 @@ class FlashInferGDNKernel(LinearAttnKernelBase):
         a_mtp = a.view(batch_size, draft_token_num, num_v_heads)
         b_mtp = b.view(batch_size, draft_token_num, num_v_heads)

+        # FlashInfer's MTP kernel treats `intermediate_states_buffer` as
+        # batch-scoped (indexed by per-call batch index i_n), not pool-scoped.
+        # The pool buffer is allocated with leading dim = spec_state_size + 1,
+        # so slice to exactly B rows. Downstream scatter reads dim 1 of the
+        # buffer by request index, so the sliced view stays consistent.
+        if intermediate_states_buffer is not None:
+            intermediate_states_buffer = intermediate_states_buffer[:batch_size]
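
The slice in the patch above relies on basic slicing producing a view of the same storage rather than a copy, so writes made through the first `batch_size` rows remain visible in the full pool buffer. The sketch below shows the same idea using only the standard library, with a flat `memoryview` standing in for the tensor. The sizes and the helper `first_rows_view` are hypothetical and not taken from the PR.

```python
from array import array


def first_rows_view(pool: memoryview, row_len: int, batch_size: int) -> memoryview:
    """Zero-copy view of the first `batch_size` rows of a flat, row-major buffer."""
    return pool[: batch_size * row_len]


# Illustrative sizes only: the pool has spec_state_size + 1 rows, the batch uses 3.
spec_state_size, row_len, batch_size = 7, 4, 3
storage = array("f", [0.0] * ((spec_state_size + 1) * row_len))
pool = memoryview(storage)

view = first_rows_view(pool, row_len, batch_size)
view[0] = 1.5  # write through the sliced view

print(len(pool), len(view), storage[0])
```

Because the view shares memory with `storage`, the write shows up in the underlying buffer without any copy back.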

| Concurrency | FlashInfer TPOT (ms) | Triton TPOT (ms) | Speedup (Triton / FI) | Winner |
|---|---|---|---|---|
| 1 | 2.95 | 2.94 | 0.997 | Triton |
| 4 | 3.87 | 3.86 | 0.997 | Triton |
| 16 | 6.04 | 6.02 | 0.997 | Triton |
| 32 | 7.79 | 7.83 | 1.005 | FlashInfer |
| 64 | 10.64 | 10.67 | 1.003 | FlashInfer |
| 128 | 13.48 | 13.53 | 1.004 | FlashInfer |

Requirements

The traces were collected at ISL=32, OSL=512, concurrency (CC)=64.

FlashInfer: [Screenshot 2026-04-23 at 3:21:58 PM]

Triton: [Screenshot 2026-04-23 at 3:26:48 PM]


CI States

Latest PR Test (Base): ❌ Run #26703045174
Latest PR Test (Extra): ❌ Run #26703045140

@gemini-code-assist

Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@wenscarl wenscarl marked this pull request as ready for review April 22, 2026 14:56

mmangkad added a commit to mmangkad/sglang that referenced this pull request Apr 28, 2026
…fy on SM100+ (Blackwell)

Resolved conflicts with PR sgl-project#22921:
- gdn_flashinfer.py: combined module and class docstrings to reflect that
  SM100+ now supports decode, prefill, and MTP verify.
- gdn_flashinfer.py target_verify: dropped the SM100+ NotImplementedError
  guard so the pool-API MTP path runs on both SM90 and SM100+.
- server_args.py: kept the bf16 gate from sgl-project#22921 and removed the
  speculative_algorithm gate now that MTP verify is supported on SM100+.
mmangkad added a commit to mmangkad/sglang that referenced this pull request Apr 28, 2026
PR sgl-project#22921 renamed the SM-gating attribute from use_state_pool to
is_sm100plus (updating all existing call sites). PR sgl-project#23273 was authored
against the older name and added a new reference in the bf16 MTP adapter
setup. The git auto-merge picked up sgl-project#22921's renames and sgl-project#23273's new
block, leaving a single dangling use_state_pool access that crashed at
FlashInferGDNKernel.__init__.

Rename the one remaining reference to is_sm100plus to match the rest of
the class.
willhu-jpg added a commit to modal-labs/sglang that referenced this pull request May 15, 2026
@nvpohanh

Collaborator

@wenscarl Could you rebase so that I can trigger CI? thanks

@nvpohanh

Collaborator

/tag-and-rerun-ci

@nvpohanh

Collaborator

/tag-and-rerun-ci

@nvpohanh

Collaborator
  File "/usr/local/lib/python3.10/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/_mlir/dialects/_cuda_ops_gen.py", line 567, in __init__
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
TypeError: __init__(): incompatible function arguments. The following argument types are supported:
    1. __init__(self, operation: object) -> None

Invoked with types: cutlass._mlir.dialects.cuda.KernelOp, str, tuple, NoneType, NoneType, kwargs = { attributes: dict, results: list, operands: list, successors: NoneType, regions: NoneType, loc: cutlass._mlir._mlir_libs._cutlass_ir._mlir.ir.Location, ip: NoneType }

✗ FAILED: /actions-runner/_work/sglang/sglang/test/registered/hicache/test_hicache_variants.py returned exit code 1

https://github.com/sgl-project/sglang/actions/runs/26224310402/job/77302047639?pr=23273

Seeing this repeatedly 🤔

@nvpohanh

Collaborator
This has been fixed by #25958

@wenscarl let's merge the latest main. Thanks

Comment thread python/sglang/srt/server_args.py
YAMY1234 added a commit to wenscarl/sglang that referenced this pull request May 29, 2026
…backend

Locks down the existing Qwen3.5 NVFP4 MTP test to Triton backend so the
Triton coverage is preserved after this PR removes the
`speculative_algorithm is None` guard from the SM100+ FlashInfer
auto-default, and adds a parallel test class that explicitly exercises
the new FlashInfer GDN MTP verify path.

Addresses reviewer comment on PR sgl-project#23273.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@YAMY1234 YAMY1234 self-requested a review May 29, 2026 16:46
@YAMY1234 YAMY1234 self-assigned this May 31, 2026
@YAMY1234

Collaborator

/rerun-failed-ci

@nvpohanh

nvpohanh commented Jun 1, 2026

Collaborator

/rerun-failed-ci

@nvpohanh

nvpohanh commented Jun 1, 2026

Collaborator

H20 failure is a known issue fixed by #26883

@Fridge003 could we merge this?

@Fridge003 Fridge003 merged commit 0574d2b into sgl-project:main Jun 2, 2026
348 of 401 checks passed
xjpang pushed a commit to xjpang/sglang that referenced this pull request Jun 2, 2026
…l-project#23273)

Co-authored-by: Yangmin Li <yangminl@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
mqhc2020 pushed a commit to mqhc2020/sglang that referenced this pull request Jun 2, 2026
…l-project#23273)

Co-authored-by: Yangmin Li <yangminl@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
hanming-lu pushed a commit that referenced this pull request Jun 3, 2026
…3273)

Co-authored-by: Yangmin Li <yangminl@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
willhu-jpg pushed a commit to modal-projects/sglang that referenced this pull request Jun 3, 2026
…l-project#23273)

Co-authored-by: Yangmin Li <yangminl@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
alphabetc1 pushed a commit to alphabetc1/sglang that referenced this pull request Jun 4, 2026
…l-project#23273)

Co-authored-by: Yangmin Li <yangminl@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
jeynmann pushed a commit to jeynmann/sglang that referenced this pull request Jun 4, 2026
…l-project#23273)

Co-authored-by: Yangmin Li <yangminl@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
edwingao28 pushed a commit to edwingao28/sglang that referenced this pull request Jun 7, 2026
…l-project#23273)

Co-authored-by: Yangmin Li <yangminl@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
monkeyLoveding pushed a commit to monkeyLoveding/sglang_open that referenced this pull request Jun 9, 2026
…l-project#23273)

Co-authored-by: Yangmin Li <yangminl@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
4 participants