
use flashinfer.sampling#18696

Merged
BBuf merged 5 commits into sgl-project:main from pansicheng:flashinfer-sampling on Feb 26, 2026

Conversation

@pansicheng
Collaborator

Motivation

#17865 moves (external) flashinfer/csrc/sampling.cu.
This PR calls flashinfer directly from Python, instead of compiling the operators into sgl_kernel.
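For readers unfamiliar with the operation these kernels implement, below is a minimal pure-Python sketch of joint top-k/top-p (nucleus) sampling. It is illustrative only: the function name and structure are mine, not flashinfer's; the real CUDA kernels operate on batched GPU tensors.

```python
import random

def top_k_top_p_sample(probs, k, p, rng=None):
    """Illustrative pure-Python reference for joint top-k/top-p sampling.

    Not the flashinfer CUDA implementation; a sketch of the same operation
    on a single probability distribution given as a Python list.
    """
    rng = rng or random.Random(0)
    # Keep the k highest-probability token ids.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = order[:k]
    # Within the top-k set, keep the smallest prefix whose mass reaches p.
    mass, nucleus = 0.0, []
    for i in kept:
        nucleus.append(i)
        mass += probs[i]
        if mass >= p:
            break
    # Renormalize over the nucleus and sample one token id.
    total = sum(probs[i] for i in nucleus)
    r, acc = rng.random() * total, 0.0
    for i in nucleus:
        acc += probs[i]
        if r <= acc:
            return i
    return nucleus[-1]
```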

Modifications

Accuracy Tests

Unit test: python -m pytest sgl-kernel/tests/test_sampling.py -s

gsm8k

python -m sglang.launch_server --model /data/Qwen3-8B/
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 128
100%|███████████████████████████████████████████████████████████████| 1319/1319 [01:08<00:00, 19.13it/s]
Accuracy: 0.904
Invalid: 0.000
Latency: 68.988 s
Output throughput: 2507.349 token/s

python -m sglang.launch_server --model /data/Qwen3-8B/ --disable-radix-cache
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 128
100%|███████████████████████████████████████████████████████████████| 1319/1319 [06:12<00:00,  3.54it/s]
Accuracy: 0.901
Invalid: 0.000
Latency: 372.805 s
Output throughput: 460.144 token/s

Benchmarking and Profiling

python benchmark/bench_top_k_top_p_sampling.py

this patch
============================================================
Starting performance benchmark...
top-k-top-p-joint-sampling-performance:
    batch_size  vocab_size    p  Torch Reference  SGL Kernel
0         16.0       111.0  0.1      3517.440081   34.816001
1         16.0       111.0  0.5      3519.488096   29.696001
2         16.0     32000.0  0.1      4187.648058  191.487998
3         16.0     32000.0  0.5      4232.192039  163.839996
4         64.0       111.0  0.1     14012.415886   38.911998
5         64.0       111.0  0.5     13886.464119   34.816001
6         64.0     32000.0  0.1     16578.559875  243.711993
7         64.0     32000.0  0.5     16712.703705  214.528002
8        128.0       111.0  0.1     27687.936783   47.104001
9        128.0       111.0  0.5     27639.808655   39.935999
10       128.0     32000.0  0.1     33045.503616  319.487989
11       128.0     32000.0  0.5     33159.679413  273.407996

main
============================================================
Starting performance benchmark...
top-k-top-p-joint-sampling-performance:
    batch_size  vocab_size    p  Torch Reference  SGL Kernel
0         16.0       111.0  0.1      3526.655912   33.792000
1         16.0       111.0  0.5      3526.143909   28.672000
2         16.0     32000.0  0.1      4183.552027  226.303995
3         16.0     32000.0  0.5      4174.335957  193.024002
4         64.0       111.0  0.1     13879.296303   37.888002
5         64.0       111.0  0.5     13943.807602   33.792000
6         64.0     32000.0  0.1     16728.063583  281.087995
7         64.0     32000.0  0.5     16711.679459  246.784002
8        128.0       111.0  0.1     27573.247910   45.056000
9        128.0       111.0  0.5     27688.959122   39.935999
10       128.0     32000.0  0.1     33324.544907  370.687991
11       128.0     32000.0  0.5     33060.352325  314.368010
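As a quick sanity check on the two tables above, the speedup of this patch over main in the SGL Kernel column (vocab_size=32000 rows) works out to roughly 1.15-1.18x:

```python
# SGL Kernel timings for the vocab_size=32000 rows, keyed by
# (batch_size, p), copied from the two benchmark tables above.
patch = {(16, 0.1): 191.487998, (16, 0.5): 163.839996,
         (64, 0.1): 243.711993, (64, 0.5): 214.528002,
         (128, 0.1): 319.487989, (128, 0.5): 273.407996}
main = {(16, 0.1): 226.303995, (16, 0.5): 193.024002,
        (64, 0.1): 281.087995, (64, 0.5): 246.784002,
        (128, 0.1): 370.687991, (128, 0.5): 314.368010}
# Speedup of this patch over main (higher is better).
speedup = {k: round(main[k] / patch[k], 3) for k in patch}
```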

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.


@BBuf
Collaborator

BBuf commented Feb 14, 2026

It's cool to see some performance improvement.

Collaborator

@BBuf BBuf left a comment


Good job.

@BBuf
Collaborator

BBuf commented Feb 14, 2026

/tag-and-rerun-ci

Reviewed diff context:

```python
if torch.any(torch.isnan(probs)):
    raise ValueError("Input probs contains NaN.")
return _top_p_sampling_from_probs_internal(
return get_sampling_module().top_p_sampling_from_probs(
```
Collaborator


I do not recommend directly getting the module from flashinfer (it's not a public API). You may take a look at my implementation in mini-sglang as a reference:
https://github.com/sgl-project/mini-sglang/blob/82722ad6dc85df766278c48061d768b4117a3bd4/python/minisgl/engine/sample.py#L24-L45
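A hedged sketch of the pattern the review suggests: import the public flashinfer.sampling entry point rather than fetching flashinfer's internal module, with a fallback when flashinfer is not installed. The fallback body below is my simplified pure-Python stand-in for illustration, not sgl-kernel's actual code; the real flashinfer function operates on CUDA tensors.

```python
import random

try:
    # Public API, as recommended in the review.
    from flashinfer.sampling import top_p_sampling_from_probs
except ImportError:
    # Illustrative pure-Python fallback (my stand-in, not sgl-kernel's code):
    # nucleus sampling over a single distribution given as a Python list.
    def top_p_sampling_from_probs(probs, top_p, rng=None):
        rng = rng or random.Random(0)
        order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        mass, nucleus = 0.0, []
        for i in order:
            nucleus.append(i)
            mass += probs[i]
            if mass >= top_p:
                break
        total = sum(probs[i] for i in nucleus)
        r, acc = rng.random() * total, 0.0
        for i in nucleus:
            acc += probs[i]
            if r <= acc:
                return i
        return nucleus[-1]
```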


Collaborator Author


Fixed, PTAL

@BBuf
Collaborator

BBuf commented Feb 24, 2026

@DarkSharpness Any other advice?

@DarkSharpness
Collaborator

Why don't we apply the flashinfer kernel directly at the use site? (Right now we are modifying sgl-kernel in place.)

@BBuf
Collaborator

BBuf commented Feb 25, 2026

Why don't we apply the flashinfer kernel directly at the use site? (Right now we are modifying sgl-kernel in place.)

I also think that's more appropriate. @pansicheng Can you make this change? Thanks!

@pansicheng
Collaborator Author

Why don't we apply the flashinfer kernel directly at the use site? (Right now we are modifying sgl-kernel in place.)

I also think that's more appropriate. @pansicheng Can you make this change? Thanks!

Fixed, PTAL

Collaborator

@BBuf BBuf left a comment


LGTM now. Waiting for CI.

@BBuf
Collaborator

BBuf commented Feb 25, 2026

/tag-and-rerun-ci

@BBuf
Collaborator

BBuf commented Feb 26, 2026

@BBuf BBuf merged commit 2ad475b into sgl-project:main Feb 26, 2026
240 of 267 checks passed
klhhhhh pushed a commit to klhhhhh/sglang that referenced this pull request Feb 26, 2026
magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
