Replace _resolve_future_token_ids with JIT kernel + platform dispatch#20976
Merged
merrymercy merged 4 commits intomainfrom Mar 20, 2026
Merged
Replace _resolve_future_token_ids with JIT kernel + platform dispatch#20976merrymercy merged 4 commits intomainfrom
merrymercy merged 4 commits intomainfrom
Conversation
Use a CUDA JIT kernel for resolving future token IDs on CUDA devices, torch.compile on HIP/AMD, and plain torch on other platforms (NPU, CPU). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Contributor
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
Contributor
Author
|
/tag-and-rerun-ci |
- Use AlignedVector for 128-bit vectorized loads/stores (4x int32, 2x int64) - Replace per-element branching with branchless ternary (compiles to SELP) - Add scalar tail loop for non-aligned remainders - Add proper include comments per skill guide conventions - Add input validation in Python wrapper - Add benchmark (bench_resolve_future_token_ids.py) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Remove redundant is_cuda/dtype checks (TensorMatcher validates on kernel side) - Add torch.compile baseline to benchmark for comparison Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The scalar kernel already outperforms torch.compile (1.3x at 32K elements) without the added complexity of vectorized loads/stores. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
DarkSharpness
approved these changes
Mar 20, 2026
Wangzheee
pushed a commit
to Wangzheee/sglang
that referenced
this pull request
Mar 21, 2026
…sgl-project#20976) Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
0-693
pushed a commit
to 0-693/sglang
that referenced
this pull request
Mar 25, 2026
…sgl-project#20976) Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
dutsc
pushed a commit
to dutsc/sglang
that referenced
this pull request
Mar 30, 2026
…sgl-project#20976) Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
JustinTong0323
pushed a commit
to JustinTong0323/sglang
that referenced
this pull request
Apr 7, 2026
…sgl-project#20976) Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
yhyang201
pushed a commit
to yhyang201/sglang
that referenced
this pull request
Apr 22, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
resolve_future_token_ids.cuh) that resolves future token IDs in-place, templated on dtype (int32/int64)@cache_oncefor module caching@torch.compileimplementation with platform dispatch: CUDA JIT kernel on CUDA,torch.compileon HIP/AMD, plain torch on other platforms (NPU, CPU, etc.)Benchmark (NVIDIA H100)
Test plan
python -m pytest python/sglang/jit_kernel/tests/test_resolve_future_token_ids.py -v -s(64/64 passed)python -m pytest test/registered/core/test_srt_engine.py -x -v(8/8 passed)🤖 Generated with Claude Code