Replace _resolve_future_token_ids with JIT kernel + platform dispatch #20976

Merged
merrymercy merged 4 commits into main from lianmin/sglang-cleanup on Mar 20, 2026

Conversation

@merrymercy (Contributor) commented on Mar 20, 2026

Summary

  • Add a CUDA JIT kernel (resolve_future_token_ids.cuh) that resolves future token IDs in-place, templated on dtype (int32/int64)
  • Add Python JIT wrapper with @cache_once for module caching
  • Replace the single @torch.compile implementation with platform dispatch: the CUDA JIT kernel on CUDA, torch.compile on HIP/AMD, and plain torch on other platforms (NPU, CPU, etc.); a minimal dispatch sketch follows this list
  • Add comprehensive unit tests covering multiple sizes, dtypes, and input patterns
  • Add benchmark comparing JIT kernel vs torch.compile vs PyTorch
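For concreteness, here is a minimal sketch of the dispatch layer described above. It assumes the existing torch formulation (negative IDs are placeholders that index into future_token_ids_map as -id); the function names, HIP check, and JIT import path are illustrative, not the PR's actual code.

```python
# A minimal sketch, not the PR's code: names, the HIP detection, and the JIT
# import below are assumptions made for illustration.
import torch


def _resolve_future_token_ids_torch(input_ids: torch.Tensor,
                                    future_token_ids_map: torch.Tensor) -> None:
    # Negative IDs are placeholders from the overlap scheduler;
    # -id indexes into future_token_ids_map, non-negative IDs pass through.
    input_ids[:] = torch.where(
        input_ids < 0,
        future_token_ids_map[torch.clamp(-input_ids, min=0)],
        input_ids,
    )


# torch.compile path kept for HIP/AMD devices.
_resolve_future_token_ids_compiled = torch.compile(
    _resolve_future_token_ids_torch, dynamic=True
)


def resolve_future_token_ids(input_ids: torch.Tensor,
                             future_token_ids_map: torch.Tensor) -> None:
    if input_ids.is_cuda and torch.version.hip is None:
        # CUDA: call the JIT kernel built from resolve_future_token_ids.cuh.
        # Hypothetical import path; the real wrapper caches the compiled
        # module with @cache_once.
        from sglang.jit_kernel import resolve_future_token_ids as jit_op
        jit_op(input_ids, future_token_ids_map)
    elif input_ids.is_cuda:
        # HIP/AMD: torch.compile implementation.
        _resolve_future_token_ids_compiled(input_ids, future_token_ids_map)
    else:
        # Other platforms (NPU, CPU, ...): plain torch.
        _resolve_future_token_ids_torch(input_ids, future_token_ids_map)
```

The real wrapper additionally relies on kernel-side TensorMatcher validation instead of Python-side is_cuda/dtype checks; that detail is omitted here.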

Benchmark (NVIDIA H100)

Size    SGL JIT Kernel (us)   torch.compile (us)   PyTorch (us)   JIT vs torch.compile
16      2.82                  3.25                 12.31          1.15x faster
128     3.23                  3.27                 12.30          1.01x faster
1024    3.31                  3.47                 14.16          1.05x faster
4096    3.37                  3.51                 14.79          1.04x faster
16384   3.57                  4.02                 15.22          1.13x faster
32768   3.68                  4.79                 15.57          1.30x faster

Test plan

  • Run unit tests: python -m pytest python/sglang/jit_kernel/tests/test_resolve_future_token_ids.py -v -s (64/64 passed)
  • Run e2e tests: python -m pytest test/registered/core/test_srt_engine.py -x -v (8/8 passed)

🤖 Generated with Claude Code

Use a CUDA JIT kernel for resolving future token IDs on CUDA devices,
torch.compile on HIP/AMD, and plain torch on other platforms (NPU, CPU).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@gemini-code-assist (Contributor) commented:

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@merrymercy (Contributor, Author) commented:

/tag-and-rerun-ci

merrymercy and others added 3 commits March 20, 2026 05:45
- Use AlignedVector for 128-bit vectorized loads/stores (4x int32, 2x int64); a rough sketch of this pattern follows this commit message
- Replace per-element branching with branchless ternary (compiles to SELP)
- Add scalar tail loop for non-aligned remainders
- Add proper include comments per skill guide conventions
- Add input validation in Python wrapper
- Add benchmark (bench_resolve_future_token_ids.py)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
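For illustration, a rough sketch of the vectorized pattern this commit describes (and which a later commit removes again). The AlignedVector definition, kernel name, and tail handling below are assumptions, not the repository's code; T is int32_t with VEC=4 or int64_t with VEC=2, so each vector load/store moves 128 bits.

```cuda
// Rough sketch only; the repo's AlignedVector and kernel signature may differ.
#include <cstdint>

template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
  T data[N];
};

template <typename T, int VEC>
__global__ void resolve_future_token_ids_vec(T* __restrict__ input_ids,
                                             const T* __restrict__ map,
                                             int64_t n) {
  using Vec = AlignedVector<T, VEC>;
  const int64_t vec_n = n / VEC;
  const int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;

  if (i < vec_n) {
    Vec v = reinterpret_cast<Vec*>(input_ids)[i];  // 128-bit load
#pragma unroll
    for (int k = 0; k < VEC; ++k)
      v.data[k] = (v.data[k] < 0) ? map[-v.data[k]] : v.data[k];  // branchless select
    reinterpret_cast<Vec*>(input_ids)[i] = v;      // 128-bit store
  }

  // Scalar tail loop for the few elements that do not fill a whole vector.
  const int64_t tail = vec_n * VEC + i;
  if (i < VEC && tail < n) {
    T id = input_ids[tail];
    input_ids[tail] = (id < 0) ? map[-id] : id;
  }
}
```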
- Remove redundant is_cuda/dtype checks (TensorMatcher validates on kernel side)
- Add torch.compile baseline to benchmark for comparison

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The scalar kernel already outperforms torch.compile (1.3x at 32K elements)
without the added complexity of vectorized loads/stores.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
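A minimal sketch of the scalar, branchless per-element kernel the PR ends up with; the kernel name, signature, and launch shape are illustrative rather than copied from resolve_future_token_ids.cuh.

```cuda
// Illustrative only; instantiated for int32_t and int64_t.
#include <cstdint>

template <typename T>
__global__ void resolve_future_token_ids_kernel(
    T* __restrict__ input_ids,
    const T* __restrict__ future_token_ids_map,
    int64_t num_tokens) {
  const int64_t idx = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (idx < num_tokens) {
    const T id = input_ids[idx];
    // A negative id is a placeholder: resolve it via the map, otherwise keep it.
    // The branchless select lets the compiler emit SELP rather than a divergent branch.
    input_ids[idx] = (id < 0) ? future_token_ids_map[-id] : id;
  }
}
```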
merrymercy merged commit 112b628 into main on Mar 20, 2026
45 of 75 checks passed
merrymercy deleted the lianmin/sglang-cleanup branch on March 20, 2026 at 08:47
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
…sgl-project#20976)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
…sgl-project#20976)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026
…sgl-project#20976)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
…sgl-project#20976)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>