
[NVIDIA] [GDN] Add FlashInfer prefill support for SM100+ (Blackwell)#22921

Merged
Fridge003 merged 4 commits into sgl-project:main from kaixih:add_flashinfer_gdn_prefill on May 27, 2026

Conversation

@kaixih

@kaixih kaixih commented Apr 16, 2026

Collaborator

[GDN] Add FlashInfer prefill support for SM100+ (Blackwell)

Summary

Extends FlashInfer GDN kernel support to cover the prefill/extend path on SM100+
(Blackwell) hardware, which previously raised NotImplementedError. SM90 (Hopper)
prefill was already supported; this PR completes SM100+ coverage.

Accuracy (Qwen3.5-397B-A17B-NVFP4, B200)

gsm8k (200 examples, baseline threshold: 0.95)

| Backend | Score |
| --- | --- |
| Triton (prefill + decode) | 0.985 |
| FlashInfer (prefill + decode) | 0.985 |

GPQA diamond (198 examples, repeat=8, temperature=0.6)

| Backend | Scores | Mean |
| --- | --- | --- |
| FlashInfer (prefill + decode) | 0.848, 0.879, 0.904, 0.879, 0.848, 0.864, 0.869, 0.869 | 0.870 |

Throughput Benchmark (B200, Qwen3.5-397B-A17B-NVFP4, TP=8)

More detailed perf numbers in the PR comments below.

Server settings:

  • --tp-size 8 --max-running-requests 256 --chunked-prefill-size 163840
  • --mamba-ssm-dtype bfloat16 --mamba-scheduler-strategy no_buffer --mamba-track-interval 128
  • --attention-backend trtllm_mha --linear-attn-decode-backend flashinfer
  • --linear-attn-prefill-backend <triton|flashinfer> (varied per run)
  • --disable-radix-cache --quantization modelopt_fp4

Benchmark settings:

  • --dataset-name random --random-input-len 8192 --random-output-len 128
  • --max-concurrency 256 --num-prompts 512

| Metric | Triton prefill | FlashInfer prefill | Speedup |
| --- | --- | --- | --- |
| Benchmark duration (s) | 53.27 | 50.87 | 1.05x |
| Input throughput (tok/s) | 78,734 | 82,445 | 1.05x |
| Total throughput (tok/s) | 79,964 | 83,733 | 1.05x |
| Mean TTFT (ms) | 12,742 | 12,042 | 1.06x |
| Mean TPOT (ms) | 109.08 | 105.14 | 1.04x |
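
The Speedup column can be reproduced from the two measured columns. The sketch below is only an illustration of that arithmetic, assuming the usual convention that lower-is-better metrics (duration, TTFT, TPOT) divide Triton by FlashInfer and throughput metrics divide FlashInfer by Triton; the function and variable names are hypothetical.

```python
# Values copied from the throughput benchmark table.
# Each row: (metric, triton_value, flashinfer_value, higher_is_better)
ROWS = [
    ("Benchmark duration (s)", 53.27, 50.87, False),
    ("Input throughput (tok/s)", 78734.0, 82445.0, True),
    ("Total throughput (tok/s)", 79964.0, 83733.0, True),
    ("Mean TTFT (ms)", 12742.0, 12042.0, False),
    ("Mean TPOT (ms)", 109.08, 105.14, False),
]


def speedup(triton: float, flashinfer: float, higher_is_better: bool) -> float:
    """Return the FlashInfer-over-Triton speedup for one metric."""
    if higher_is_better:
        return flashinfer / triton
    return triton / flashinfer


# Speedups rounded to two decimals, keyed by metric name.
SPEEDUPS = {name: round(speedup(t, f, hib), 2) for name, t, f, hib in ROWS}

if __name__ == "__main__":
    for name, value in SPEEDUPS.items():
        print(f"{name}: {value:.2f}x")
```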

Requirements

  • FlashInfer >= 0.6.8 (for chunk_gated_delta_rule SM100 path)
  • nvidia-cutlass-dsl[cu13] >= 4.4.2 (SM100+ only)
  • CUDA 13 (SM100+ path requires _cuda_major >= 13)
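
The requirements above amount to a small gate on compute capability, CUDA major version, and FlashInfer version. A minimal sketch of such a check follows. The function names and arguments are hypothetical and are not taken from the PR; only the thresholds come from the list above. The PR does not state a minimum FlashInfer version for SM90, so the SM90 branch below is an assumption.

```python
import re

# Threshold from the requirements list (SM100 chunk_gated_delta_rule path).
MIN_FLASHINFER_SM100 = (0, 6, 8)


def parse_version(text: str) -> tuple:
    """Parse the leading numeric components, e.g. '0.6.8rc1' -> (0, 6, 8)."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def flashinfer_gdn_prefill_ok(sm_major: int, cuda_major: int,
                              flashinfer_version: str) -> bool:
    """Hypothetical gate for the FlashInfer GDN prefill path."""
    if sm_major >= 10:
        # SM100+ (Blackwell): needs CUDA 13 and FlashInfer >= 0.6.8.
        return (cuda_major >= 13
                and parse_version(flashinfer_version) >= MIN_FLASHINFER_SM100)
    # SM90 (Hopper) prefill was already supported before this PR.
    # Assumption: no extra version check is modeled for SM90 here.
    return sm_major == 9


if __name__ == "__main__":
    print(flashinfer_gdn_prefill_ok(10, 13, "0.6.8"))
    print(flashinfer_gdn_prefill_ok(10, 12, "0.6.8"))
```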

CI States

Latest PR Test (Base): ✅ Run #26271103993
Latest PR Test (Extra): ❌ Run #26271103852

@gemini-code-assist

Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@kaixih

kaixih commented Apr 16, 2026

Collaborator Author

cc @hlu1 @YAMY1234 @wenscarl

@kaixih kaixih changed the title [NVIDIA] [GDN] Add FlashInfer prefill support for SM100+ (Blackwell) [Draft] [NVIDIA] [GDN] Add FlashInfer prefill support for SM100+ (Blackwell) Apr 16, 2026
@kaixih kaixih changed the title [Draft] [NVIDIA] [GDN] Add FlashInfer prefill support for SM100+ (Blackwell) [NVIDIA] [GDN] Add FlashInfer prefill support for SM100+ (Blackwell) Apr 16, 2026
@kaixih

kaixih commented Apr 16, 2026

Collaborator Author

The model has a repeated block pattern of 3× linear attention (GDN) + 1× full attention.
Profiling one such block during prefill:

| Backend | Block wall time | GDN prefill (3 layers) | GDN per layer | Kernels/layer |
| --- | --- | --- | --- | --- |
| Triton | 12,784 µs | 1,518 µs (506×3) | 506 µs | 12 |
| FlashInfer | 12,379 µs | 1,275 µs (425×3) | 425 µs | 11 |
| Speedup | 1.03x | 1.19x | 1.19x | |

The GDN kernel itself is ~19% faster with FlashInfer; the modest system-level gain (~5%)
reflects that GDN is a small fraction of the total forward pass (MoE GEMM, attention,
all-reduce account for the rest).
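
To make the "small fraction" argument concrete, the profile numbers can be plugged into Amdahl's law. This is a rough sketch using only the block wall time and GDN times from the table; it assumes everything outside the GDN layers is unchanged, which the profile does not guarantee, and the names are illustrative.

```python
def amdahl_speedup(fraction: float, component_speedup: float) -> float:
    """Overall speedup when `fraction` of the time is sped up by `component_speedup`."""
    return 1.0 / ((1.0 - fraction) + fraction / component_speedup)


# Numbers from the per-block profile (microseconds).
BLOCK_US = 12784.0          # Triton block wall time
GDN_TRITON_US = 1518.0      # 3 GDN layers with Triton
GDN_FLASHINFER_US = 1275.0  # 3 GDN layers with FlashInfer

GDN_FRACTION = GDN_TRITON_US / BLOCK_US              # share of the block spent in GDN
KERNEL_SPEEDUP = GDN_TRITON_US / GDN_FLASHINFER_US   # GDN-only speedup
PREDICTED_BLOCK_SPEEDUP = amdahl_speedup(GDN_FRACTION, KERNEL_SPEEDUP)

if __name__ == "__main__":
    print(f"GDN share of block: {GDN_FRACTION:.1%}")
    print(f"GDN kernel speedup: {KERNEL_SPEEDUP:.2f}x")
    print(f"Predicted block speedup: {PREDICTED_BLOCK_SPEEDUP:.3f}x")
```

With GDN at roughly 12% of the block, a 1.19x kernel speedup can only move the block time by a few percent, which is the same order as the measured 1.03x block speedup.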

FlashInfer GDN prefill — kernel breakdown (per layer, 11 launches)

| Kernel | Calls | Time |
| --- | --- | --- |
| GatedDeltaNetChunkedKernel (fused main) | 1 | 328.2 µs |
| elementwise_kernel (bf16 contiguity copy, packed QKV) | 3 | 58.2 µs (19.4 µs each) |
| l2norm_fwd_kernel | 2 | 7.5 µs (3.7 µs each) |
| index_elementwise_kernel (index_copy scatter) | 1 | 2.9 µs |
| vectorized_gather_kernel (state gather) | 1 | 2.5 µs |
| vectorized_elementwise_kernel (exp) | 1 | 2.4 µs |
| unrolled_elementwise_kernel (int64 cast for index_copy) | 1 | 2.2 µs |
| vectorized_elementwise_kernel (clamp) | 1 | 2.0 µs |
| Total | 11 | ≈406 µs (wall: 425 µs) |

Triton GDN prefill — kernel breakdown (per layer, 12 launches)

| Kernel | Calls | Time |
| --- | --- | --- |
| chunk_gated_delta_rule_fwd_kernel_h_blockdim64 (main recurrence) | 1 | 257.9 µs |
| chunk_fwd_kernel_o (output projection) | 1 | 63.5 µs |
| elementwise_kernel (bf16 contiguity copy, packed QKV) | 3 | 56.8 µs (18.9 µs each) |
| chunk_gated_delta_rule_fwd_kkt_solve_kernel | 1 | 42.2 µs |
| recompute_w_u_fwd_kernel | 1 | 34.2 µs |
| vectorized_elementwise_kernel (fill bf16) | 2 | 15.6 µs (7.8 µs each) |
| l2norm_fwd_kernel | 2 | 9.0 µs (4.5 µs each) |
| chunk_local_cumsum_scalar_kernel | 1 | 4.8 µs |
| Total | 12 | ≈484 µs (wall: 506 µs) |

The ~20 µs gap between summed kernel times and wall time (406 vs. 425 µs for
FlashInfer, 484 vs. 506 µs for Triton) reflects Python-level kernel launch
overhead (gaps between dispatches). The FlashInfer overhead items above
(packed QKV copies, gather/scatter, l2norm, exp, cast, clamp; ~78 µs in total)
are candidates for elimination via the upstream improvements listed above.
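
The totals in the two kernel tables can be re-derived by summing the rows. The sketch below copies the per-kernel times from the tables and computes, for each backend, the summed kernel time, the remaining gap to the per-layer wall time, and the time spent outside the main kernel. The dictionary keys are shortened labels chosen for this example.

```python
# Per-layer kernel times in microseconds, copied from the tables above.
FLASHINFER_US = {
    "GatedDeltaNetChunkedKernel": 328.2,
    "elementwise_kernel (contiguity copy)": 58.2,
    "l2norm_fwd_kernel": 7.5,
    "index_elementwise_kernel": 2.9,
    "vectorized_gather_kernel": 2.5,
    "vectorized_elementwise_kernel (exp)": 2.4,
    "unrolled_elementwise_kernel": 2.2,
    "vectorized_elementwise_kernel (clamp)": 2.0,
}
TRITON_US = {
    "chunk_gated_delta_rule_fwd_kernel_h_blockdim64": 257.9,
    "chunk_fwd_kernel_o": 63.5,
    "elementwise_kernel (contiguity copy)": 56.8,
    "chunk_gated_delta_rule_fwd_kkt_solve_kernel": 42.2,
    "recompute_w_u_fwd_kernel": 34.2,
    "vectorized_elementwise_kernel (fill)": 15.6,
    "l2norm_fwd_kernel": 9.0,
    "chunk_local_cumsum_scalar_kernel": 4.8,
}


def summarize(kernels: dict, wall_us: float, main_kernel: str) -> dict:
    """Sum kernel times and compare against the measured wall time."""
    total = sum(kernels.values())
    return {
        "total": round(total, 1),
        "gap_to_wall": round(wall_us - total, 1),
        "other_than_main": round(total - kernels[main_kernel], 1),
    }


FI_SUMMARY = summarize(FLASHINFER_US, 425.0, "GatedDeltaNetChunkedKernel")
TRITON_SUMMARY = summarize(
    TRITON_US, 506.0, "chunk_gated_delta_rule_fwd_kernel_h_blockdim64"
)

if __name__ == "__main__":
    print("FlashInfer:", FI_SUMMARY)
    print("Triton:", TRITON_SUMMARY)
```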

@kaixih

kaixih commented Apr 16, 2026

Collaborator Author

This PR is ready for review.

@hlu1

hlu1 commented Apr 16, 2026

Collaborator

The CuteDSL kernel performance is limited by low parallelism when batch size and number of heads are small, which is clearly shown by the kernel benchmark in flashinfer-ai/flashinfer#3001

Depending on how the prefill benchmark is configured, the e2e speedup will vary a lot. For example, with 1k or 8k ISL, --chunked-prefill-size 163840, and TP4, you get effective batch sizes of 160 and 20 respectively and will hit the higher end of the speedup. But if you set --chunked-prefill-size 8192, the effective batch size will be smaller and you will hit the lower end of the speedup. In practice, the real speedup will depend on the real ISL of the workloads, and we likely won't see much speedup for long-ISL workloads.
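
The batch sizes quoted in this comment follow from dividing the chunked prefill budget by the input sequence length. A minimal sketch of that arithmetic is below; it assumes the effective prefill batch is simply the number of whole sequences that fit in one chunk and ignores other limits such as --max-running-requests. The function name is hypothetical.

```python
def effective_prefill_batch(chunked_prefill_size: int, input_len: int) -> int:
    """Whole sequences of `input_len` tokens that fit in one prefill chunk.

    Assumption: a sequence longer than the chunk is still processed alone,
    so the result is never below 1.
    """
    return max(1, chunked_prefill_size // input_len)


if __name__ == "__main__":
    for chunk in (163840, 8192):
        for isl in (1024, 8192):
            batch = effective_prefill_batch(chunk, isl)
            print(f"chunk={chunk} isl={isl} -> batch={batch}")
```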

Comment thread python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py Outdated
Comment thread python/sglang/srt/server_args.py Outdated
Comment thread python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py
Comment on lines 194 to 195
q_fi = l2norm_fwd(q[0].contiguous())
k_fi = l2norm_fwd(k[0].contiguous())

Collaborator

We can modify the triton l2norm_fwd kernel to make it support strided inputs to eliminate the contiguous calls
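
To illustrate why the contiguous calls are needed today, the sketch below models tensor strides in plain Python. It assumes q is a column slice of a packed row-major QKV buffer, as the "packed QKV" label in the kernel breakdown suggests; the sizes are made up and the helpers are hypothetical, not PyTorch APIs. The check is simplified and does not special-case size-1 dimensions.

```python
def row_major_strides(shape: tuple) -> tuple:
    """Strides (in elements) of a densely packed row-major array."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def is_contiguous(shape: tuple, strides: tuple) -> bool:
    """True when the strides match a dense row-major layout for `shape`."""
    return tuple(strides) == row_major_strides(shape)


# Illustrative sizes: packed buffer of shape (tokens, q_dim + k_dim + v_dim).
TOKENS, Q_DIM, K_DIM, V_DIM = 4, 8, 8, 16
PACKED_SHAPE = (TOKENS, Q_DIM + K_DIM + V_DIM)
PACKED_STRIDES = row_major_strides(PACKED_SHAPE)

# Slicing out the q columns keeps the parent's row stride.
Q_VIEW_SHAPE = (TOKENS, Q_DIM)
Q_VIEW_STRIDES = PACKED_STRIDES

if __name__ == "__main__":
    print(PACKED_STRIDES)                                # row stride is the packed width
    print(is_contiguous(Q_VIEW_SHAPE, Q_VIEW_STRIDES))   # the view is strided
```

A kernel that accepts the row stride as an argument could read the view in place, which is the change suggested in this comment; a kernel that assumes dense rows needs the copy first.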

@ispobock

Collaborator

/tag-and-rerun-ci

@yuan-luo

Collaborator

/rerun-failed-ci

Comment thread python/sglang/srt/server_args.py
@yuan-luo

Collaborator

/rerun-failed-ci

@yuan-luo

Collaborator

/rerun-failed-ci

mmangkad added a commit to mmangkad/sglang that referenced this pull request Apr 28, 2026
…fy on SM100+ (Blackwell)

Resolved conflicts with PR sgl-project#22921:
- gdn_flashinfer.py: combined module and class docstrings to reflect that
  SM100+ now supports decode, prefill, and MTP verify.
- gdn_flashinfer.py target_verify: dropped the SM100+ NotImplementedError
  guard so the pool-API MTP path runs on both SM90 and SM100+.
- server_args.py: kept the bf16 gate from sgl-project#22921 and removed the
  speculative_algorithm gate now that MTP verify is supported on SM100+.
mmangkad added a commit to mmangkad/sglang that referenced this pull request Apr 28, 2026
PR sgl-project#22921 renamed the SM-gating attribute from use_state_pool to
is_sm100plus (updating all existing call sites). PR sgl-project#23273 was authored
against the older name and added a new reference in the bf16 MTP adapter
setup. The git auto-merge picked up sgl-project#22921's renames and sgl-project#23273's new
block, leaving a single dangling use_state_pool access that crashed at
FlashInferGDNKernel.__init__.

Rename the one remaining reference to is_sm100plus to match the rest of
the class.
@kaixih

kaixih commented May 7, 2026

Collaborator Author

ping @yizhang2077

@samuellees

samuellees commented May 11, 2026

Contributor

/rerun-failed-ci +

@kaixih kaixih force-pushed the add_flashinfer_gdn_prefill branch 2 times, most recently from 032c24c to 9496b9d Compare May 12, 2026 21:17
@yuan-luo

Collaborator

/rerun-failed-ci

@nvpohanh

Collaborator

@kaixih Please rebase and resolve the conflicts. thanks!

@kaixih kaixih force-pushed the add_flashinfer_gdn_prefill branch from 9496b9d to 241bc2c Compare May 22, 2026 00:02
@nvpohanh

Collaborator

/rerun-failed-ci

@nvpohanh

Collaborator
Traceback (most recent call last):
  File "/actions-runner/_work/sglang/sglang/test/run_suite.py", line 421, in <module>
    main()
  File "/actions-runner/_work/sglang/sglang/test/run_suite.py", line 416, in main
    exit_code = run_a_suite(args)
  File "/actions-runner/_work/sglang/sglang/test/run_suite.py", line 295, in run_a_suite
    validate_all_suites(all_tests)
  File "/actions-runner/_work/sglang/sglang/test/run_suite.py", line 171, in validate_all_suites
    raise ValueError("Tests registered to invalid suites:\n" + "\n".join(errors))
ValueError: Tests registered to invalid suites:
  /actions-runner/_work/sglang/sglang/test/registered/4-gpu-models/test_qwen35_fp4_flashinfer.py: backend=CUDA, suite='stage-c-test-4-gpu-b200'

https://github.com/sgl-project/sglang/actions/runs/26260336698/job/77301401879?pr=22921
@kaixih Could you fix this?

@kaixih

kaixih commented May 22, 2026

Collaborator Author

@nvpohanh Thanks for flagging this. The CI registration was using an invalid CUDA suite name, so I updated it to stage="base-c", runner_config="4-gpu-b200", which resolves to base-c-test-4-gpu-b200.

I also manually ran the target test on a 4x B200 node in the latest SGLang dev container; it passed with GSM8K accuracy 0.980 vs the 0.950 baseline. The new red checks look unrelated: the other B200 shards were fast-failed due to a root failure in base-c-test-4-gpu-h100 (0).
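
The fix described here can be pictured as a name-resolution step followed by validation. The sketch below is only an illustration: the `{stage}-test-{runner_config}` pattern is inferred from the one example in this comment, the set of valid suites contains just the two names that appear in this thread, and the function names are hypothetical rather than the actual run_suite.py code.

```python
# Illustrative subset: only suite names that appear in this thread.
VALID_SUITES = {"base-c-test-4-gpu-b200", "base-c-test-4-gpu-h100"}


def resolve_suite(stage: str, runner_config: str) -> str:
    """Build a suite name from a stage and runner config (assumed pattern)."""
    return f"{stage}-test-{runner_config}"


def validate_suite(suite: str) -> str:
    """Return the suite name if it is known, otherwise raise ValueError."""
    if suite not in VALID_SUITES:
        raise ValueError(f"Test registered to invalid suite: {suite!r}")
    return suite


if __name__ == "__main__":
    # The corrected registration resolves to a known suite.
    print(validate_suite(resolve_suite("base-c", "4-gpu-b200")))
    # The original name from the traceback is rejected.
    try:
        validate_suite("stage-c-test-4-gpu-b200")
    except ValueError as err:
        print(err)
```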

@nvpohanh

Collaborator

/tag-and-rerun-ci

@nvpohanh

Collaborator

All the NV CI has passed. @yuan-luo @Fridge003 could we merge this? Thanks!

@nvpohanh

Collaborator

@ispobock Could you also help to review this GDN PR? Thanks!

@Fridge003 Fridge003 merged commit ddf0627 into sgl-project:main May 27, 2026
227 of 254 checks passed

7 participants