Fix(jit): support rmsnorm for hidden_size in {64, 128, 256} #20661
BBuf merged 5 commits into sgl-project:main
Conversation
output_ptr = pointer::offset<Float>(output, i * output_stride);
output_vec = norm::apply_norm_warp<kDim>(input_vec, weight_vec, eps);
}
gmem.store(output_ptr, output_vec);
This gmem.store(output_ptr, output_vec); should be inside a for loop, and the if statement inside the for loop is meaningless.
Perhaps this is actually correct: each loop iteration writes back the token processed in the previous iteration, so the token from the last iteration must be written back separately after the loop. However, this makes the code less readable, and I think it would be better to process a whole token within a single iteration.
Key takeaways
- hidden_size=8192: Sequential is consistently and significantly faster — up to 19% at large batch. This is the most impactful regime (e.g. Llama/Qwen models with hidden_size=8192).
- hidden_size 3072/5120: Results are mixed and small (±3–5%), largely within noise margins.
- Small batch (≤128): Pipeline wins by ~1–3% — negligible in practice since absolute latency is already ~4µs.
- The pipeline pattern adds register pressure (keeping output_ptr/output_vec live across the loop + a conditional branch), which clearly hurts occupancy at hidden_size=8192.
So replace the pipeline kernels with the sequential variants. The sequential pattern is simpler, more readable, and strictly better for the most common large-model configurations.
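To make the comparison concrete, here is a minimal, illustrative sketch of the two loop shapes being discussed. This is not the PR's kernel (the real code lives in csrc/elementwise/rmsnorm.cuh and uses vectorized gmem/pointer/norm helpers); it assumes plain float data and one thread per token purely so the difference in store placement is easy to see.

```cuda
#include <cuda_runtime.h>
#include <math.h>

// Hypothetical helper: RMS-normalize one token of kDim elements into out_regs.
template <int kDim>
__device__ void rms_norm_token(const float* in, const float* weight, float eps,
                               float* out_regs) {
  float ss = 0.f;
  for (int i = 0; i < kDim; ++i) ss += in[i] * in[i];
  float scale = rsqrtf(ss / kDim + eps);
  for (int i = 0; i < kDim; ++i) out_regs[i] = in[i] * scale * weight[i];
}

// (a) Sequential: load, normalize, and store a token within one iteration.
template <int kDim>
__global__ void rmsnorm_sequential(const float* in, const float* w, float* out,
                                   int num_tokens, float eps) {
  float regs[kDim];
  for (int t = blockIdx.x * blockDim.x + threadIdx.x; t < num_tokens;
       t += gridDim.x * blockDim.x) {
    rms_norm_token<kDim>(in + (long)t * kDim, w, eps, regs);
    for (int i = 0; i < kDim; ++i) out[(long)t * kDim + i] = regs[i];
  }
}

// (b) Pipelined: iteration i writes back the token normalized in iteration
// i-1, so the final token must be flushed after the loop -- the store that
// the review comments above are about. The pointer kept live across
// iterations plus the extra branch is where the added register pressure
// comes from.
template <int kDim>
__global__ void rmsnorm_pipelined(const float* in, const float* w, float* out,
                                  int num_tokens, float eps) {
  float regs[kDim];
  float* pending = nullptr;
  for (int t = blockIdx.x * blockDim.x + threadIdx.x; t < num_tokens;
       t += gridDim.x * blockDim.x) {
    if (pending != nullptr) {
      for (int i = 0; i < kDim; ++i) pending[i] = regs[i];
    }
    rms_norm_token<kDim>(in + (long)t * kDim, w, eps, regs);
    pending = out + (long)t * kDim;
  }
  if (pending != nullptr) {  // flush the last token processed by this thread
    for (int i = 0; i < kDim; ++i) pending[i] = regs[i];
  }
}
```

Both variants produce identical output; the pipelining only changes when the store is issued relative to the next load.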
Please provide complete unit test results (screenshots or logs) once the above comments are resolved.
Yes, it did as
Force-pushed from b62da97 to 672df76.
HydraQYH left a comment
The kernel implementation looks OK. I don't think some of these unit tests are strictly necessary, but adding them won't hurt; if you feel they are needed, you can keep them.
torch.testing.assert_close(r_jit, r_ref, rtol=1e-2, atol=1e-2)
...
@pytest.mark.parametrize(
There's only one question: are these new unit tests really necessary?
I will remove that
/tag-and-rerun-ci



Motivation
jit_rmsnorm silently failed for hidden_size ∈ {64, 128, 256} and for hidden_size = 16384 during benchmarking on B200. Both cases caused nvcc to exit with status 2, flooding stderr with full compilation logs for every failing (hidden_size, batch_size, dtype) combination.
Modifications
- csrc/elementwise/rmsnorm.cuh
- python/sglang/jit_kernel/norm.py
- python/sglang/jit_kernel/tests/test_norm_jit.py
Accuracy Tests
test_rmsnorm_jit validates correctness against the flashinfer reference for all supported hidden sizes, including the newly added {64, 128, 256}, across bf16/fp16 and batch_size ∈ {1, 19, 99, 989}.
Benchmarking and Profiling
Measured on B200 (hidden_size=64, batch_size=1, bf16), jit_rmsnorm with the new warp norm kernel is the fastest among all providers:

Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci