Opt jit qknorm_across_heads cuda kernel #21503

Merged
BBuf merged 2 commits into main from opt_qknorm_across_heads
Mar 27, 2026

BBuf (Collaborator) commented Mar 27, 2026

Motivation

Follow-up to #18073.

The old kernel handled both q and k inside one CTA, which kept too much
state live at the same time:

  • q
  • k
  • q_weight
  • k_weight
  • separate output vectors
  • dual reduction buffers in shared memory

The new kernel still runs as a single launch, but splits the work across
grid.y = 2:

  • blockIdx.y == 0: normalize q
  • blockIdx.y == 1: normalize k

This reduces per-thread live state and shrinks the shared reduction buffer from
two lanes to one lane.
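
The dispatch described above can be illustrated with a small host-side emulation in pure Python. This is only a sketch of the launch geometry, not the kernel from this PR: it assumes an RMSNorm-style normalization and one CTA per row (grid.x = batch_size), neither of which the description spells out, and the names `rmsnorm_row`, `qknorm_block`, and `launch` are hypothetical.

```python
import math


def rmsnorm_row(row, weight, eps):
    """Normalize one row in place: x * rsqrt(mean(x^2) + eps) * w (assumed formula)."""
    # A single accumulator stands in for the one-lane reduction buffer.
    sum_sq = 0.0
    for x in row:
        sum_sq += x * x
    inv_rms = 1.0 / math.sqrt(sum_sq / len(row) + eps)
    for i, x in enumerate(row):
        row[i] = x * inv_rms * weight[i]


def qknorm_block(block_idx_x, block_idx_y, q, k, q_weight, k_weight, eps):
    """Work of one emulated CTA: a single row of either q or k, never both."""
    if block_idx_y == 0:
        tensor, weight = q, q_weight  # blockIdx.y == 0: normalize q
    else:
        tensor, weight = k, k_weight  # blockIdx.y == 1: normalize k
    rmsnorm_row(tensor[block_idx_x], weight, eps)


def launch(q, k, q_weight, k_weight, eps=1e-6):
    """Emulate one launch with grid = (batch_size, 2); grid.x is an assumption."""
    grid_x, grid_y = len(q), 2
    for block_idx_y in range(grid_y):
        for block_idx_x in range(grid_x):
            qknorm_block(block_idx_x, block_idx_y, q, k, q_weight, k_weight, eps)


# Tiny usage example: batch_size = 2, hidden_dim = 2.
q = [[3.0, 4.0], [1.0, 1.0]]
k = [[0.0, 2.0], [6.0, 8.0]]
launch(q, k, q_weight=[1.0, 1.0], k_weight=[2.0, 2.0], eps=0.0)
```

Each emulated block touches one input tensor, one weight vector, and one accumulator, which mirrors why the per-thread live state drops when q and k no longer share a CTA.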

On H200 with shape (batch_size=2048, hidden_dim=8192):

  • registers/thread: 48 -> 26
  • static shared memory/block: 256 B -> 128 B
  • theoretical occupancy: 50% -> 100%
  • achieved occupancy: 45.25% -> 88.17%
  • achieved active warps/SM: 28.96 -> 56.43
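
The jump in theoretical occupancy can be reproduced with a back-of-the-envelope, register-only calculation. The sketch below assumes Hopper limits of 65,536 registers and 64 resident warps per SM, per-warp register allocation in units of 256, and a block size of 1024 threads. The block size is not stated in this PR; it is chosen here because it reproduces the reported 50% -> 100%. Shared memory (128 to 256 B) is far from limiting and is ignored. `theoretical_occupancy` is a hypothetical helper, not a profiler API.

```python
REGS_PER_SM = 65536        # 32-bit registers per SM (assumed Hopper limit)
MAX_WARPS_PER_SM = 64      # resident warps per SM (assumed Hopper limit)
WARP_SIZE = 32
REG_ALLOC_UNIT = 256       # registers are allocated per warp in this granularity
THREADS_PER_BLOCK = 1024   # assumption: not stated in the PR


def theoretical_occupancy(regs_per_thread, threads_per_block=THREADS_PER_BLOCK):
    """Register-limited occupancy as a fraction of the maximum resident warps."""
    warps_per_block = -(-threads_per_block // WARP_SIZE)  # ceiling division
    regs_per_warp = regs_per_thread * WARP_SIZE
    # Round the per-warp register footprint up to the allocation unit.
    regs_per_warp = -(-regs_per_warp // REG_ALLOC_UNIT) * REG_ALLOC_UNIT
    blocks_by_regs = REGS_PER_SM // (regs_per_warp * warps_per_block)
    blocks_by_warps = MAX_WARPS_PER_SM // warps_per_block
    blocks = min(blocks_by_regs, blocks_by_warps)
    return blocks * warps_per_block / MAX_WARPS_PER_SM


print(theoretical_occupancy(48))  # 0.5: only one 1024-thread block fits per SM
print(theoretical_occupancy(26))  # 1.0: two 1024-thread blocks fit per SM
```

The achieved figures are consistent with the same 64-warp ceiling: 28.96 / 64 is about 45.25% and 56.43 / 64 is about 88.17%.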

Microbench results

H200, bf16:

| Shape (batch_size, hidden_dim) | Baseline | Optimized | Speedup |
|---|---|---|---|
| (256, 1024) | 0.0207 ms | 0.0194 ms | 1.0711x |
| (1024, 4096) | 0.0688 ms | 0.0598 ms | 1.1506x |
| (2048, 8192) | 0.1786 ms | 0.1592 ms | 1.1218x |

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@BBuf
Collaborator Author

BBuf commented Mar 27, 2026

/tag-and-rerun-ci

Comment thread python/sglang/jit_kernel/csrc/elementwise/qknorm_across_heads.cuh Outdated
@BBuf BBuf merged commit e8d46f1 into main Mar 27, 2026
40 of 70 checks passed
@BBuf BBuf deleted the opt_qknorm_across_heads branch March 27, 2026 05:30
satyamk7054 pushed a commit to satyamk7054/sglang that referenced this pull request Apr 3, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026

3 participants