
Fix(jit): support rmsnorm for hidden_size in {64, 128, 256}#20661

Merged
BBuf merged 5 commits into sgl-project:main from Johnsonms:fix/jit-rmsnorm-unsupported-hidden-size
Mar 23, 2026
Conversation

Johnsonms (Contributor) commented Mar 16, 2026

Motivation

jit_rmsnorm silently failed for hidden_size ∈ {64, 128, 256} and hidden_size = 16384 during benchmarking on B200. The root cause:

  • RMSNormKernel only implemented the CTA norm path (hidden_size > 256), so instantiating it for small hidden sizes triggered a static_assert failure ("Hidden size invalid for RMSNorm") at JIT compile time
  • hidden_size = 16384 exceeded the CTA kernel's supported range (≤ 8192), triggering "Unsupported norm configuration"

Both cases caused nvcc to exit with status 2, flooding stderr with full compilation logs for every failing (hidden_size, batch_size, dtype) combination.

Modifications

csrc/elementwise/rmsnorm.cuh

  • Add rmsnorm_warp kernel: uses tile::Memory::warp() and apply_norm_warp() — one warp (32 threads) per token, no shared memory needed
  • Add RMSNormWarpKernel<kDim, kUsePDL, DType> struct for kDim ∈ {64, 128, 256}, symmetric to the existing RMSNormKernel
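
For reference, the computation the warp kernel performs per token is ordinary RMSNorm. A minimal NumPy sketch (illustrative only; `rmsnorm_ref` is a hypothetical name, and the real kernel does this with a 32-lane warp reduction and no shared memory):

```python
import numpy as np

def rmsnorm_ref(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Reference RMSNorm over the last axis: y = x / rms(x) * weight."""
    # Root-mean-square per token, accumulated in fp32 as the kernel would.
    rms = np.sqrt(np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight

# One warp per token: a (batch, hidden_size) input with hidden_size = 128.
x = np.random.randn(4, 128).astype(np.float32)
w = np.ones(128, dtype=np.float32)
y = rmsnorm_ref(x, w)
```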

python/sglang/jit_kernel/norm.py

  • Add _is_supported_rmsnorm_hidden_size(hidden_size): returns True for warp sizes {64, 128, 256} and CTA sizes (multiples of 256, in range (256, 8192])
  • Add _rmsnorm_kernel_class(hidden_size): returns "RMSNormWarpKernel" or "RMSNormKernel" based on hidden size
  • _jit_rmsnorm_module now routes to the correct kernel class
  • rmsnorm() raises a clean RuntimeError for unsupported hidden sizes (e.g. 16384) instead of falling through to noisy nvcc failures

python/sglang/jit_kernel/tests/test_norm_jit.py

  • Extend RMSNORM_HIDDEN_SIZES to include [64, 128, 256]
  • Add test_rmsnorm_hidden_size_support: unit tests for _is_supported_rmsnorm_hidden_size covering edge cases {0, 64, 128, 256, 512, 8192, 16384}
  • Add test_rmsnorm_kernel_dispatch: verifies correct kernel class selection for all size categories
  • Add test_rmsnorm_rejects_unsupported_hidden_size: verifies clean RuntimeError is raised for {0, 16384}

Accuracy Tests

test_rmsnorm_jit validates correctness against flashinfer reference for all supported hidden sizes including the newly added {64, 128, 256}, across bf16/fp16 and batch_size ∈ {1, 19, 99, 989}.

Benchmarking and Profiling

Measured on B200 (hidden_size=64, batch_size=1, bf16), jit_rmsnorm with the new warp norm kernel is the fastest among all providers:
[benchmark screenshot]

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

Comment thread: python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh (outdated)

HydraQYH (Collaborator):
Have you run unit tests yourself?

Comment thread: python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh

HydraQYH (Collaborator):

I don't think these unit tests are necessary; tests for these functionalities are already included in the kernel's unit tests.

Comment thread: python/sglang/jit_kernel/tests/test_norm_jit.py
    output_ptr = pointer::offset<Float>(output, i * output_stride);
    output_vec = norm::apply_norm_warp<kDim>(input_vec, weight_vec, eps);
    }
    gmem.store(output_ptr, output_vec);
HydraQYH (Collaborator):

This gmem.store(output_ptr, output_vec); should be inside a for loop, and the if statement inside the for loop is meaningless.

HydraQYH (Collaborator):

Perhaps this is correct, as the for loop writes the token processed in the previous for loop each time, and the token processed in the last for loop must be written back separately. However, this makes the code less readable, and I think it would be better to process a token within a single loop.

Johnsonms (Contributor, author):

[benchmark screenshot] Key takeaways:
  1. hidden_size=8192: Sequential is consistently and significantly faster — up to 19% at large batch. This is the most impactful regime (e.g. Llama/Qwen models with hidden_size=8192).
  2. hidden_size 3072/5120: Results are mixed and small (±3–5%), largely within noise margins.
  3. Small batch (≤128): Pipeline wins by ~1–3% — negligible in practice since absolute latency is already ~4µs.
  4. The pipeline pattern adds register pressure (keeping output_ptr/output_vec live across the loop + a conditional branch), which clearly hurts occupancy at hidden_size=8192.

So I replaced the pipeline kernels with the sequential variants. The sequential pattern is simpler, more readable, and strictly better for the most common large-model configurations.
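
The difference between the two loop shapes can be modeled in Python (a hypothetical sketch; the real kernels operate on vectorized loads and global-memory stores, with `process` and `store` standing in for the normalize step and gmem.store):

```python
def process(tok):
    # Stand-in for load + apply_norm_warp on one token.
    return tok * 2

def run_sequential(tokens, store):
    # Sequential: normalize and store each token inside a single iteration.
    for tok in tokens:
        store(process(tok))

def run_pipelined(tokens, store):
    # Pipelined: each iteration stores the *previous* iteration's result, so
    # the pending result stays live across the loop (extra register pressure),
    # and the last token must be stored separately after the loop.
    pending = None
    for tok in tokens:
        if pending is not None:
            store(pending)
        pending = process(tok)
    if pending is not None:
        store(pending)

out_seq, out_pipe = [], []
run_sequential([1, 2, 3], out_seq.append)
run_pipelined([1, 2, 3], out_pipe.append)
```

Both variants produce identical output; the pipelined form only changes when each store is issued relative to the next token's computation.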

Comment thread python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh
HydraQYH (Collaborator):
Please provide complete unit test results (screenshots or logs) once the above comments are resolved.

@Johnsonms Johnsonms marked this pull request as draft March 16, 2026 13:33
Johnsonms (Contributor, author) commented Mar 22, 2026

Please provide complete unit test results (screenshots or logs) once the above comments are resolved.

/sgl-workspace/sglang/python/sglang/jit_kernel# python tests/test_norm_jit.py

[test output screenshot]

root@gpu-dp-nwrpk-b25k7:/sgl-workspace/sglang/python/sglang/jit_kernel# python benchmark/bench_rmsnorm.py

[benchmark output screenshot]

Summary:
[summary screenshot]

Overall conclusion
SGL JIT is faster than SGL AOT across nearly all configurations, with gains ranging from ~6% at small batch to ~24% at large batch. The advantage grows with batch size and is consistent across all hidden sizes. Only at hidden_size=1536 with batch=1 does AOT have a slight edge (~4%), which is within noise margin. The JIT sequential kernel is the clear winner.

Johnsonms (Contributor, author):

Have you run unit tests yourself?

Yes, the unit tests were run as described in https://github.com/sgl-project/sglang/pull/20661#issue-4080337611, but no screenshot was captured at the time.

@Johnsonms Johnsonms force-pushed the fix/jit-rmsnorm-unsupported-hidden-size branch from b62da97 to 672df76 Compare March 22, 2026 05:42
@Johnsonms Johnsonms marked this pull request as ready for review March 22, 2026 05:43
@Johnsonms Johnsonms requested a review from HydraQYH March 22, 2026 05:44
HydraQYH (Collaborator) left a comment:
The kernel implementation seems OK. I don't think some of these unit tests are necessary, but adding them won't hurt; if you consider them necessary, you can keep them.

    torch.testing.assert_close(r_jit, r_ref, rtol=1e-2, atol=1e-2)


    @pytest.mark.parametrize(
HydraQYH (Collaborator):

There's only one question: are these new unit tests really necessary?

Johnsonms (Contributor, author):

I will remove that

BBuf (Collaborator) commented Mar 23, 2026

/tag-and-rerun-ci

@BBuf BBuf merged commit 777edb6 into sgl-project:main Mar 23, 2026
92 of 121 checks passed
adityavaid pushed a commit to adityavaid/sglang that referenced this pull request Mar 24, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
