[NVIDIA] [GDN] Add FlashInfer prefill support for SM100+ (Blackwell)#22921
Conversation
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
|
The model has a repeated block pattern of 3× linear attention (GDN) + 1× full attention.
The GDN kernel itself is ~19% faster with FlashInfer; the modest system-level gain (~5%) FlashInfer GDN prefill — kernel breakdown (per layer, 11 launches)
Triton GDN prefill — kernel breakdown (per layer, 12 launches)
The ~80 µs gap between summed kernel times and wall time reflects Python-level kernel |
|
This PR is ready for review. |
|
The CuteDSL kernel performance is limited by low parallelism when batch size and number of heads are small, which is clearly shown by the kernel benchmark in flashinfer-ai/flashinfer#3001 Depending on how the prefill benchmark is configured, the e2e speedup will vary a lot. For example, for 1k or 8k ISL and --chunked-prefill-size 163840, and TP4, you get effect batch size 160 and 20 and will hit the higher end of the speedup. But if you set --chunked-prefill-size 8192, the effective batch size will be smaller and will hit the lower end of the speedup. In practice, the real speedup will depend on the real ISL of the workloads, and we likely won't see much speedup for the long ISL workloads. |
23b04c0 to
b6c0d39
Compare
| q_fi = l2norm_fwd(q[0].contiguous()) | ||
| k_fi = l2norm_fwd(k[0].contiguous()) |
There was a problem hiding this comment.
We can modify the triton l2norm_fwd kernel to make it support strided inputs to eliminate the contiguous calls
|
/tag-and-rerun-ci |
|
/rerun-failed-ci |
|
/rerun-failed-ci |
|
/rerun-failed-ci |
…fy on SM100+ (Blackwell) Resolved conflicts with PR sgl-project#22921: - gdn_flashinfer.py: combined module and class docstrings to reflect that SM100+ now supports decode, prefill, and MTP verify. - gdn_flashinfer.py target_verify: dropped the SM100+ NotImplementedError guard so the pool-API MTP path runs on both SM90 and SM100+. - server_args.py: kept the bf16 gate from sgl-project#22921 and removed the speculative_algorithm gate now that MTP verify is supported on SM100+.
…fy on SM100+ (Blackwell) Resolved conflicts with PR sgl-project#22921: - gdn_flashinfer.py: combined module and class docstrings to reflect that SM100+ now supports decode, prefill, and MTP verify. - gdn_flashinfer.py target_verify: dropped the SM100+ NotImplementedError guard so the pool-API MTP path runs on both SM90 and SM100+. - server_args.py: kept the bf16 gate from sgl-project#22921 and removed the speculative_algorithm gate now that MTP verify is supported on SM100+.
PR sgl-project#22921 renamed the SM-gating attribute from use_state_pool to is_sm100plus (updating all existing call sites). PR sgl-project#23273 was authored against the older name and added a new reference in the bf16 MTP adapter setup. The git auto-merge picked up sgl-project#22921's renames and sgl-project#23273's new block, leaving a single dangling use_state_pool access that crashed at FlashInferGDNKernel.__init__. Rename the one remaining reference to is_sm100plus to match the rest of the class.
|
ping @yizhang2077 |
|
/rerun-failed-ci + |
032c24c to
9496b9d
Compare
|
/rerun-failed-ci |
|
@kaixih Please rebase and resolve the conflicts. thanks! |
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
9496b9d to
241bc2c
Compare
|
/rerun-failed-ci |
https://github.com/sgl-project/sglang/actions/runs/26260336698/job/77301401879?pr=22921 |
|
@nvpohanh Thanks for flagging this. The CI registration was using an invalid CUDA suite name, so I updated it to I also manually ran the target test on a 4x B200 node in the latest SGLang dev container; it passed with GSM8K accuracy 0.980 vs the 0.950 baseline. The new red checks look unrelated: the other B200 shards were fast-failed due to a root failure in |
|
/tag-and-rerun-ci |
|
All the NV CI has hassed. @yuan-luo @Fridge003 could we merge this? Thanks! |
|
@ispobock Could you also help to review this GDN PR? Thanks! |
[GDN] Add FlashInfer prefill support for SM100+ (Blackwell)
Summary
Extends FlashInfer GDN kernel support to cover the prefill/extend path on SM100+
(Blackwell) hardware, previously raising
NotImplementedError. SM90 (Hopper)prefill was already supported; this PR completes SM100+ coverage.
Accuracy (Qwen3.5-397B-A17B-NVFP4, B200)
gsm8k (200 examples, baseline threshold: 0.95)
GPQA diamond (198 examples, repeat=8, temperature=0.6)
Throughput Benchmark (B200, Qwen3.5-397B-A17B-NVFP4, TP=8)
More detailed perf numbers in the PR comments below.
Server settings:
--tp-size 8 --max-running-requests 256 --chunked-prefill-size 163840--mamba-ssm-dtype bfloat16 --mamba-scheduler-strategy no_buffer --mamba-track-interval 128--attention-backend trtllm_mha --linear-attn-decode-backend flashinfer--linear-attn-prefill-backend <triton|flashinfer>(varied per run)--disable-radix-cache --quantization modelopt_fp4Benchmark settings:
--dataset-name random --random-input-len 8192 --random-output-len 128--max-concurrency 256 --num-prompts 512Requirements
chunk_gated_delta_ruleSM100 path)nvidia-cutlass-dsl[cu13] >= 4.4.2(SM100+ only)_cuda_major >= 13)CI States
Latest PR Test (Base): ✅ Run #26271103993
Latest PR Test (Extra): ❌ Run #26271103852