Skip to content

[codex] Optimize hidden-size 512 RMSNorm dispatch#24710

Merged
BBuf merged 2 commits into
sgl-project:mainfrom
BBuf:codex/h100-rmsnorm-half-kernel
May 19, 2026
Merged

[codex] Optimize hidden-size 512 RMSNorm dispatch#24710
BBuf merged 2 commits into
sgl-project:mainfrom
BBuf:codex/h100-rmsnorm-half-kernel

Conversation

@BBuf

@BBuf BBuf commented May 8, 2026

Copy link
Copy Markdown
Collaborator

Summary

  • Dispatch hidden size 512 RMSNorm to the existing RMSNormHalfKernel; 1024/1536 stay on the generic CTA path.
  • Add a single-warp fast path in rmsnorm_cta_double / rmsnorm_cta_wide that bypasses the CTA shared-memory reduction when kNumWarps == 1.
  • Update the RMSNorm dispatch unit test.

H200 Correctness

Environment:

  • Host/container: ion8-h200, sglang_bbuf
  • Device: NVIDIA H200, CUDA_VISIBLE_DEVICES=4 for unit tests and CUDA_VISIBLE_DEVICES=2/3 for same-shape main/PR benchmarking
  • Code under test: local merge-preview with current origin/main (3b3af13d2)

Validation passed:

CUDA_VISIBLE_DEVICES=4 PYTHONPATH=python SGLANG_IS_IN_CI=1 \
TVM_FFI_CACHE_DIR=/tmp/tvm_cache_pr24710_h200 \
pytest -q python/sglang/jit_kernel/tests/test_rmsnorm.py -q

Result: 63 tests passed. Additional BF16 benchmark correctness checks against a PyTorch reference produced matching errors on main and PR: max_abs_error was 0.0, 0.00390625, or 0.0078125 depending on shape.

H200 Benchmark

CUDA event median, current origin/main -> this PR:

hidden_size batch_size main this PR change main class PR class
512 16 9.824 us 9.568 us +2.61% RMSNormKernel RMSNormHalfKernel
512 32 9.856 us 9.664 us +1.95% RMSNormKernel RMSNormHalfKernel
1024 16 9.824 us 9.632 us +1.95% RMSNormKernel RMSNormKernel
1024 32 9.856 us 9.696 us +1.62% RMSNormKernel RMSNormKernel
1536 16 9.888 us 9.600 us +2.91% RMSNormKernel RMSNormKernel
1536 32 9.984 us 9.664 us +3.21% RMSNormKernel RMSNormKernel
2048 16 9.888 us 9.600 us +2.91% RMSNormHalfKernel RMSNormHalfKernel
2048 32 9.856 us 9.696 us +1.62% RMSNormHalfKernel RMSNormHalfKernel
16384 16 10.176 us 9.984 us +1.89% RMSNormHalfKernel RMSNormHalfKernel
16384 32 10.208 us 10.080 us +1.25% RMSNormHalfKernel RMSNormHalfKernel

The target 512 shape dispatches to RMSNormHalfKernel and preserves correctness. Other rows are included as guard shapes.

Validation

  • python -m compileall -q python/sglang/jit_kernel/norm.py python/sglang/jit_kernel/tests/test_rmsnorm.py python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh
  • Custom H200 RMSNorm BF16 benchmark/correctness script.
  • git diff --check origin/main...HEAD

CI States

Latest PR Test (Base): ❌ Run #25919276112
Latest PR Test (Extra): ⚠️ Not enabled -- add run-ci-extra label to opt in.

@gemini-code-assist

Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@BBuf BBuf marked this pull request as ready for review May 9, 2026 00:57
@gemini-code-assist

Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@yuan-luo

yuan-luo commented May 9, 2026

Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label May 9, 2026

@yuan-luo yuan-luo left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

== 512 is a single-warp sweet spot. LGTM.

@BBuf

BBuf commented May 15, 2026

Copy link
Copy Markdown
Collaborator Author

/tag-and-rerun-ci

@BBuf

BBuf commented May 19, 2026

Copy link
Copy Markdown
Collaborator Author

@BBuf BBuf merged commit 2424303 into sgl-project:main May 19, 2026
269 of 333 checks passed
Shunkangz pushed a commit to Shunkangz/sglang that referenced this pull request May 27, 2026
@BBuf BBuf deleted the codex/h100-rmsnorm-half-kernel branch June 2, 2026 12:11
alphabetc1 pushed a commit to alphabetc1/sglang that referenced this pull request Jun 4, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants