
[Feature] JIT Fused QK norm + qk norm clean up#15835

Merged
BBuf merged 5 commits into sgl-project:main from DarkSharpness:jit_qknorm on Dec 28, 2025
Conversation


@DarkSharpness (Collaborator) commented Dec 25, 2025

Motivation

  1. The QK norm kernel is not efficient enough. Around 2 months ago, FlashInfer released a newer version of its QK-norm kernel, which has not yet been integrated into sgl-kernel. This may lead to poor performance in the prefill stage.
  2. The QK norm kernel cannot fully utilize GPU bandwidth when the batch size is relatively small.
  3. There are too many redundant copies of the _apply_qk_norm logic across model implementations.

Modifications

  1. Implement a JIT fused QK norm kernel. Similar to FlashInfer, we use a persistent kernel.
  2. Use the JIT fused QK norm whenever possible.
  3. Clean up the redundant _apply_qk_norm logic and move it into sglang/srt/models/utils.py.

Note: since our JIT kernel only supports head_dim in [64, 128, 256], it can also be moved into sgl-kernel in the future without significantly increasing binary size.
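For reference, the intended semantics of the fused kernel can be sketched in plain PyTorch. This is an illustrative reference only, not the CUDA implementation; the eps value, tensor layout, and function name are assumptions:

```python
import torch

def qk_norm_reference(q: torch.Tensor, k: torch.Tensor,
                      q_weight: torch.Tensor, k_weight: torch.Tensor,
                      eps: float = 1e-6):
    """Apply RMSNorm to every q and k head vector in one logical step.

    Assumed shapes: q is (num_tokens, num_q_heads, head_dim),
    k is (num_tokens, num_kv_heads, head_dim), weights are (head_dim,).
    """
    def rmsnorm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        xf = x.float()  # accumulate in fp32 for numerical stability
        xf = xf * torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (xf * w.float()).to(x.dtype)

    return rmsnorm(q, q_weight), rmsnorm(k, k_weight)
```

The fused CUDA kernel performs both normalizations in a single persistent-kernel launch, which is where the latency win at small batch sizes comes from.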

Accuracy Tests

WIP.

Benchmarking and Profiling

End-to-end speedup is around 1–2% on Qwen3 models.

Kernel Benchmark Latency (us) on H200 (head_dim=128)
GQA num_kv_heads batch_size SGL AOT Kernel SGL JIT Kernel FlashInfer PyTorch
4 1 1.0 1.602216 0.967299 4.147745 2.210164
4 1 2.0 1.971132 1.129724 4.082101 2.216914
4 1 4.0 2.046463 1.161237 4.260650 2.446676
4 1 8.0 2.056453 1.177705 4.127348 2.469167
4 1 16.0 2.070655 1.210284 4.242177 2.491159
4 1 32.0 2.146091 1.235676 4.423212 2.606742
4 1 64.0 2.218982 1.257530 4.510034 2.710421
4 1 128.0 2.497092 1.374964 4.580845 2.817401
4 1 256.0 3.116057 1.546483 4.579036 3.071840
4 1 512.0 4.905310 2.016863 4.698664 3.780974
4 1 1024.0 8.032317 2.829033 5.164029 5.242347
4 1 2048.0 14.227522 3.949354 5.849227 8.328087
4 1 4096.0 26.553799 4.718069 7.889430 14.549551
4 1 8192.0 51.193776 5.929556 12.844778 26.873150
4 2 1.0 1.973113 1.135202 4.056827 2.218648
4 2 2.0 2.049635 1.174577 4.268414 2.448797
4 2 4.0 1.838493 1.186040 4.131957 2.472057
4 2 8.0 2.075420 1.223204 4.358434 2.501710
4 2 16.0 2.146743 1.237090 4.430364 2.612028
4 2 32.0 2.224000 1.275996 4.664317 2.707676
4 2 64.0 2.498599 1.381083 4.668938 2.810551
4 2 128.0 3.304765 1.554070 4.638308 3.072464
4 2 256.0 4.910514 2.033047 4.675055 3.783757
4 2 512.0 8.034118 2.827229 5.185844 5.237069
4 2 1024.0 14.051559 3.956178 5.871687 8.326787
4 2 2048.0 26.555363 4.723587 7.890730 14.554491
4 2 4096.0 51.204209 5.967548 12.810800 26.865454
4 2 8192.0 100.506665 10.681090 22.567152 51.510706
4 4 1.0 2.049338 1.179208 4.258082 2.444852
4 4 2.0 2.060563 1.178384 4.132504 2.470129
4 4 4.0 2.074426 1.224661 4.419047 2.486625
4 4 8.0 2.146800 1.243366 4.431141 2.613374
4 4 16.0 1.977443 1.275281 4.548108 2.720860
4 4 32.0 2.498298 1.382079 4.589099 2.809886
4 4 64.0 3.304652 1.558480 4.638347 3.072272
4 4 128.0 4.912402 2.039143 4.670729 3.785166
4 4 256.0 8.031794 2.863567 5.185206 5.237654
4 4 512.0 14.227032 3.968998 5.845436 8.324923
4 4 1024.0 26.555411 4.708028 7.893088 14.546558
4 4 2048.0 51.191171 5.952161 12.851992 26.866253
4 4 4096.0 100.500408 10.683963 22.568350 51.440001
4 4 8192.0 198.997016 22.859876 46.764163 101.185289
4 8 1.0 2.057232 1.187613 4.095947 2.456053
4 8 2.0 2.075319 1.224966 4.326579 2.495948
4 8 4.0 2.147293 1.235857 4.503845 2.598718
4 8 8.0 2.223401 1.274508 4.551287 2.720107
4 8 16.0 2.499761 1.382484 4.589297 2.807916
4 8 32.0 3.304863 1.581218 4.638084 3.073978
4 8 64.0 4.722954 2.025154 4.672783 3.789915
4 8 128.0 8.030740 2.837629 5.152085 5.237443
4 8 256.0 14.224144 3.968735 5.843177 8.326346
4 8 512.0 26.548503 4.736224 7.877057 14.548110
4 8 1024.0 51.189130 5.930764 12.812288 26.821492
4 8 2048.0 100.369806 10.720507 22.623180 51.433407
4 8 4096.0 198.991034 22.912703 46.883177 101.183354
4 8 8192.0 396.153278 48.456803 90.865334 200.319835
8 1 1.0 1.660214 1.134311 4.053002 2.204968
8 1 2.0 2.063375 1.152339 4.104204 2.438979
8 1 4.0 2.048348 1.184119 4.094758 2.457241
8 1 8.0 2.091486 1.199920 4.262926 2.395754
8 1 16.0 2.101410 1.259972 4.303576 2.648056
8 1 32.0 2.211657 1.256035 4.426617 2.618278
8 1 64.0 2.425739 1.355766 4.569506 2.726263
8 1 128.0 3.152905 1.532012 4.588152 3.035327
8 1 256.0 4.425606 1.902914 4.630433 3.648836
8 1 512.0 7.447579 2.679543 5.177979 5.030939
8 1 1024.0 12.978922 3.924003 5.903701 7.690500
8 1 2048.0 24.080260 4.509180 7.425713 13.256224
8 1 4096.0 46.263833 5.730878 12.149372 24.367141
8 1 8192.0 90.507927 9.801896 21.008092 46.509898
8 2 1.0 2.064506 1.150009 4.104689 2.439759
8 2 2.0 2.048290 1.183201 4.095368 2.456052
8 2 4.0 1.872091 1.213490 4.290765 2.397189
8 2 8.0 2.103361 1.260300 4.303000 2.647652
8 2 16.0 2.211406 1.266252 4.485983 2.654824
8 2 32.0 2.426277 1.353754 4.508568 2.729922
8 2 64.0 3.151445 1.525999 4.588223 3.037183
8 2 128.0 4.600527 1.904346 4.592388 3.643603
8 2 256.0 7.438590 2.676125 5.177108 5.031077
8 2 512.0 12.984543 3.914288 5.871477 7.690347
8 2 1024.0 23.866795 4.510011 7.394907 13.228737
8 2 2048.0 46.198835 5.736124 12.134384 24.368680
8 2 4096.0 90.510744 9.821377 20.995621 46.512056
8 2 8192.0 179.314180 20.053768 41.638584 91.337495
8 4 1.0 2.046420 1.175872 4.097494 2.455421
8 4 2.0 2.090495 1.210281 4.204116 2.390348
8 4 4.0 2.103140 1.260213 4.303335 2.646751
8 4 8.0 2.211287 1.273012 4.482001 2.611962
8 4 16.0 2.220915 1.362798 4.509473 2.731405
8 4 32.0 3.151757 1.533456 4.657780 3.030493
8 4 64.0 4.602847 1.910799 4.596051 3.650759
8 4 128.0 7.438587 2.716375 5.158408 5.047253
8 4 256.0 12.956142 3.926019 5.870062 7.686410
8 4 512.0 24.049866 4.509933 7.407691 13.227090
8 4 1024.0 46.203743 5.739054 12.169481 24.367745
8 4 2048.0 90.506814 9.839303 21.047479 46.510143
8 4 4096.0 179.320431 20.028844 41.554052 91.335621
8 4 8192.0 356.758100 44.002492 82.021558 180.347889
8 8 1.0 2.091602 1.217257 4.242791 2.395733
8 8 2.0 2.101548 1.260544 4.303451 2.647919
8 8 4.0 2.213135 1.268818 4.503230 2.649711
8 8 8.0 2.426232 1.357569 4.570823 2.724038
8 8 16.0 3.151602 1.516329 4.576696 3.031661
8 8 32.0 4.602788 1.915288 4.657099 3.651862
8 8 64.0 7.258208 2.713892 5.160116 5.028964
8 8 128.0 12.985743 3.924281 5.868523 7.670379
8 8 256.0 24.046876 4.501554 7.398339 13.227310
8 8 512.0 46.202775 5.739965 12.145323 24.373176
8 8 1024.0 90.510368 9.904544 21.005994 46.513264
8 8 2048.0 179.318836 20.015311 41.564332 91.198415
8 8 4096.0 356.774989 43.970684 82.019536 180.342816
8 8 8192.0 711.237124 84.883776 160.825233 357.91069


@gemini-code-assist
Contributor

Summary of Changes

Hello @DarkSharpness, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of Query-Key (QK) normalization within the system by introducing a JIT fused kernel. This optimization addresses inefficiencies in the existing QK norm kernel, particularly for smaller batch sizes, and streamlines the codebase by refactoring common normalization logic into a reusable utility. The changes are expected to lead to faster prefill stages and improved GPU utilization.

Highlights

  • Performance Optimization: Introduced a JIT (Just-In-Time) fused QK norm kernel to significantly improve the efficiency of QK normalization, especially for smaller batch sizes where the previous kernel was underperforming. Benchmarks on an H200 GPU show substantial latency reductions across various configurations.
  • Code Refactoring: Centralized and cleaned up redundant _apply_qk_norm logic by moving it into a new shared utility function apply_qk_norm in sglang/srt/models/utils.py. This new utility function intelligently dispatches to the JIT kernel when applicable.
  • JIT Kernel Implementation: Implemented the fused QK norm kernel in C++/CUDA (norm.cuh), utilizing persistent kernels and Programmatic Dependent Launch (PDL) for enhanced performance. The JIT kernel currently supports head dimensions of 64, 128, and 256.
  • Benchmarking Infrastructure: Added a new benchmark script (bench_qknorm.py) to compare the performance of the new SGL JIT Kernel against the existing SGL Kernel and FlashInfer's implementation.
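The dispatch behavior described in the refactoring highlight can be sketched as follows. Note that fused_qk_norm_jit and the use_jit flag are hypothetical placeholders for illustration, not the actual sglang API:

```python
import torch

SUPPORTED_HEAD_DIMS = (64, 128, 256)  # head dims the JIT kernel supports, per this PR

def apply_qk_norm_sketch(q, k, q_weight, k_weight, head_dim: int,
                         eps: float = 1e-6, use_jit: bool = True):
    """Illustrative dispatch: take the fused JIT path when the head_dim is
    supported, otherwise fall back to eager per-tensor RMSNorm."""
    if use_jit and head_dim in SUPPORTED_HEAD_DIMS and torch.cuda.is_available():
        # hypothetical fused entry point; the real one lives in sglang's JIT kernels
        return fused_qk_norm_jit(q, k, q_weight, k_weight, eps)  # noqa: F821

    def rmsnorm(x, w):
        xf = x.float()
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
        return (xf * w.float()).to(x.dtype)

    # fallback: flatten heads and normalize q and k independently
    return rmsnorm(q.view(-1, head_dim), q_weight), rmsnorm(k.view(-1, head_dim), k_weight)
```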




@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces a JIT fused QK norm kernel, which significantly improves performance, especially for smaller batch sizes, as demonstrated by the provided benchmarks. The refactoring of the _apply_qk_norm logic into a shared utility function (sglang.srt.models.utils.apply_qk_norm) is a good step towards reducing redundant code and improving maintainability across different model implementations. The new CUDA C++ kernels and their Python wrappers are well-integrated. However, there is a critical issue identified in the CUDA kernel's pointer arithmetic for the k tensor, which could lead to memory access violations.

Comment threads on python/sglang/jit_kernel/csrc/norm.cuh (one outdated):

@BBuf (Collaborator) left a comment


LGTM.

@DarkSharpness
Collaborator Author

/rerun-failed-ci

@DarkSharpness
Collaborator Author

local precision test passed cc @BBuf

@DarkSharpness
Collaborator Author

/tag-and-rerun-ci


@BBuf (Collaborator) left a comment


It seems that the kUsePDL template parameter in the JIT kernel doesn't automatically enable or disable itself based on the GPU architecture?

// NOTE: we offset the k here to reduce computation cost in the kernel
const auto params = QKNormParams{
    .q = q.data_ptr(),
    .k = pointer::offset(k.data_ptr(), -2 * static_cast<int64_t>(num_qo_heads) * kHeadDim),


Can you add a comment for this line? I can't easily understand it right now.

N_K = 2
N_Q = 16
DEVICE = "cuda"
DTYPE = torch.bfloat16


Adding a torch.float16 case would be better.


@BBuf (Collaborator) left a comment


Great job. Approved!

@BBuf
Collaborator

BBuf commented Dec 26, 2025

/tag-and-rerun-ci

@BBuf
Collaborator

BBuf commented Dec 26, 2025

Add an end-to-end model accuracy test?

@BBuf
Collaborator

BBuf commented Dec 28, 2025

@BBuf merged commit 8e43980 into sgl-project:main on Dec 28, 2025
137 of 144 checks passed
@DarkSharpness deleted the jit_qknorm branch on December 28, 2025 at 03:55
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
) -> None:
    from sgl_kernel import rmsnorm


Put the import at the top?

    q = q.view(-1, head_dim)
    k = k.view(-1, head_dim)

    current_stream = torch.cuda.current_stream()


Sometimes I've found this function has a large CPU overhead. I use:

def get_current_device_stream_fast():
    global cached_device_index
    if cached_device_index == -1:
        cached_device_index = torch.get_device_module().current_device()
    return torch.get_device_module().current_stream(cached_device_index)


@cache_once
def can_use_fused_inplace_qknorm(head_dim: int) -> bool:
    logger = logging.getLogger(__name__)


qq: could we move this line to the top of this file?

Collaborator Author


No problem, it's just my personal preference to lazy-init.

YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026
