Skip to content

Replace clamp_position with JIT kernel + platform dispatch#20999

Merged
merrymercy merged 5 commits intomainfrom
lianmin/clamp-position-jit
Mar 22, 2026
Merged

Replace clamp_position with JIT kernel + platform dispatch#20999
merrymercy merged 5 commits intomainfrom
lianmin/clamp-position-jit

Conversation

@merrymercy
Copy link
Copy Markdown
Contributor

Summary

  • Add a CUDA JIT kernel (clamp_position.cuh) for computing decode positions: clamp(seq_lens - 1, min=0), templated on dtype (int32/int64)
  • Add Python JIT wrapper with @cache_once for module caching
  • Replace the single @torch.compile implementation in forward_batch_info.py with platform dispatch: CUDA JIT kernel on CUDA, torch.compile on HIP/AMD, plain torch on other platforms (NPU, CPU, etc.)
  • Add comprehensive unit tests (64 cases) and benchmark

Benchmark (NVIDIA H100)

Size SGL JIT Kernel (us) torch.compile (us) PyTorch (us) JIT vs torch.compile
16 1.35 1.56 3.36 1.15x faster
128 1.55 1.61 3.40 1.04x faster
1024 1.59 1.71 3.42 1.08x faster
4096 1.63 2.20 3.46 1.35x faster
16384 1.66 4.27 3.52 2.57x faster
32768 1.66 6.94 3.57 4.18x faster

Test plan

  • Run unit tests: python -m pytest python/sglang/jit_kernel/tests/test_clamp_position.py -v -s (64/64 passed)
  • Run e2e tests: python -m pytest test/registered/core/test_srt_engine.py -x -v (8/8 passed)

🤖 Generated with Claude Code

@merrymercy
Copy link
Copy Markdown
Contributor Author

/tag-and-rerun-ci

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of the clamp_position operation by introducing a specialized CUDA JIT kernel. It refactors the existing implementation to dynamically select the most optimized approach based on the execution platform, ensuring efficiency across various hardware configurations. The changes are supported by new tests and benchmarks demonstrating the performance benefits.

Highlights

  • Performance Optimization: A new CUDA JIT kernel has been introduced for the clamp_position operation, which computes clamp(seq_lens - 1, min=0), offering significant performance improvements, especially for larger input sizes.
  • Platform Dispatch: The clamp_position function now uses platform-specific implementations: the new CUDA JIT kernel on CUDA, torch.compile on HIP/AMD, and a plain PyTorch implementation on other platforms like NPU or CPU.
  • JIT Kernel Integration: A Python JIT wrapper with @cache_once has been added to efficiently load and execute the CUDA JIT kernel for clamp_position.
  • Testing and Benchmarking: Comprehensive unit tests (64 cases) and a dedicated benchmark script have been added to validate the correctness and measure the performance gains of the new clamp_position implementation.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a custom JIT CUDA kernel for clamp_position to improve performance, along with a platform dispatch mechanism. The changes are well-structured, including the kernel, Python wrappers, comprehensive unit tests, and benchmarks demonstrating significant performance gains. My main feedback addresses a data type inconsistency between the new CUDA implementation and the implementations for other platforms, which could lead to subtle bugs. I've suggested changes to ensure the CUDA path consistently returns int64 tensors, aligning it with other platforms and the original behavior.

Comment thread python/sglang/jit_kernel/clamp_position.py
Comment thread python/sglang/jit_kernel/tests/test_clamp_position.py
@github-actions github-actions Bot added the documentation Improvements or additions to documentation label Mar 20, 2026
merrymercy and others added 2 commits March 20, 2026 10:07
- Add a CUDA JIT kernel for clamp_position (clamp(seq_lens - 1, min=0))
- Use JIT kernel on CUDA, torch.compile on HIP/AMD, plain torch elsewhere
- Add unit tests (64 cases) and benchmark

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Documents the full workflow: launch server, validate accuracy,
capture Chrome-compatible trace, kill server, report profile path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@merrymercy merrymercy force-pushed the lianmin/clamp-position-jit branch from dca3ffc to 666b601 Compare March 20, 2026 10:07
@merrymercy merrymercy merged commit 76e4a86 into main Mar 22, 2026
100 of 117 checks passed
@merrymercy merrymercy deleted the lianmin/clamp-position-jit branch March 22, 2026 04:26
OrangeRedeng pushed a commit to OrangeRedeng/sglang that referenced this pull request Mar 22, 2026
…ct#20999)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
…ct#20999)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026
…ct#20999)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
…ct#20999)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation jit-kernel run-ci

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant