use flashinfer.sampling #18696
Conversation
Force-pushed from cc8e318 to f3932e5
It's cool to see some performance improvement.
/tag-and-rerun-ci
```diff
  if torch.any(torch.isnan(probs)):
      raise ValueError("Input probs contains NaN.")
- return _top_p_sampling_from_probs_internal(
+ return get_sampling_module().top_p_sampling_from_probs(
```
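For context, the guard in the quoted diff rejects NaN probabilities before sampling. A minimal stand-alone sketch of that check (pure Python with `math.isnan` standing in for the tensor-level `torch.isnan`; the helper name is illustrative, not from the PR):

```python
import math

# Illustrative stand-in for the tensor-level guard in the diff above:
# torch.any(torch.isnan(probs)) over a tensor becomes a plain scan here.
def validate_probs(probs):
    """Raise if the probability list contains NaN, mirroring the diff's check."""
    if any(math.isnan(p) for p in probs):
        raise ValueError("Input probs contains NaN.")
    return probs
```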
I do not recommend directly getting the module from flashinfer (it's not a public API). You may take a look at my implementation in mini-sglang as a reference:
https://github.com/sgl-project/mini-sglang/blob/82722ad6dc85df766278c48061d768b4117a3bd4/python/minisgl/engine/sample.py#L24-L45
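The suggestion above — go through flashinfer's public `flashinfer.sampling` entry points (or a thin wrapper around them) instead of fetching internal modules — might be sketched roughly like this. The wrapper structure and the pure-Python fallback are illustrative, not the actual mini-sglang code:

```python
import random

# Hedged sketch: prefer the public flashinfer.sampling entry point when the
# package is installed; otherwise fall back to a slow reference sampler.
try:
    from flashinfer.sampling import top_p_sampling_from_probs  # public API
    _HAVE_FLASHINFER = True
except ImportError:
    _HAVE_FLASHINFER = False

def top_p_sample_reference(probs, top_p, rng=random):
    """Reference top-p (nucleus) sampling over a plain list of probabilities.

    Keeps the smallest set of highest-probability tokens whose cumulative
    mass reaches `top_p`, then samples proportionally within that set.
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    # Renormalize over the kept set and draw one index.
    r = rng.random() * sum(probs[i] for i in kept)
    for i in kept:
        r -= probs[i]
        if r <= 0:
            return i
    return kept[-1]
```

A real wrapper would dispatch to the flashinfer kernel on GPU tensors and keep a fallback like this only for environments without flashinfer.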
@DarkSharpness Any other advice?
Why don't we call the flashinfer kernel directly at the use site? (Right now we are modifying sgl-kernel in place.)
I also think that's more appropriate. @pansicheng Can you make this change? Thanks!
Fixed, PTAL |
/tag-and-rerun-ci
Merged with CI passed: https://github.com/sgl-project/sglang/actions/runs/22379938350/job/64925133127?pr=18696
Motivation

#17865: move (external) flashinfer/csrc/sampling.cu

Modifications

Call flashinfer directly from Python, instead of compiling the operators into sgl_kernel.
Accuracy Tests

unittest:
python -m pytest sgl-kernel/tests/test_sampling.py -s

gsm8k
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci