Add Batch‑Invariant RMSNorm#12144

Merged
Fridge003 merged 7 commits into sgl-project:main from zyzshishui:rmsnorm
Oct 29, 2025
Conversation

@zyzshishui
Contributor

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @zyzshishui, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new batch-invariant RMS normalization capability, designed to enhance the performance and determinism of the SGLang inference engine. By implementing a custom Triton kernel for RMSNorm and integrating it into the existing layer normalization framework, the changes ensure consistent normalization behavior, particularly when operating in a batch-invariant mode. This is a foundational step towards more robust and predictable model execution.

Highlights

  • New Triton Kernel for RMSNorm: A new _rms_norm_kernel function, implemented using Triton, has been added to batch_invariant_ops.py to perform efficient Root Mean Square (RMS) normalization on the GPU.
  • RMSNorm Functionality: A Python wrapper rms_norm is introduced, which leverages the new Triton kernel to apply RMS normalization to input tensors.
  • Batch-Invariant RMSNorm Wrapper: A dedicated rms_norm_batch_invariant function has been added to ensure deterministic and batch-invariant behavior of RMS normalization, crucial for consistent model inference.
  • Integration into Layer Normalization: The layernorm.py module now integrates the rms_norm_batch_invariant function. It will be used in forward_cuda when the batch-invariant mode is active and no residual connection is present.
  • Refactored Deterministic Inference Logic: The explicit check for the SGLANG_ENABLE_DETERMINISTIC_INFERENCE environment variable has been removed from layernorm.py, streamlining the logic for deterministic operations under the new batch-invariant mechanism.
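To make the highlights above concrete, here is a minimal CPU reference of what an RMSNorm computes: each row is divided by its root-mean-square and then scaled by a learned weight. This is an illustrative sketch in numpy, not the PR's actual Triton `_rms_norm_kernel`; the function name and `eps` default are assumptions.

```python
import numpy as np

def rms_norm_reference(x, weight, eps=1e-6):
    # y = x / sqrt(mean(x^2) + eps) * weight, computed per row.
    variance = np.mean(x ** 2, axis=-1, keepdims=True)
    return (x / np.sqrt(variance + eps)) * weight

x = np.array([[1.0, 2.0, 3.0, 4.0]])
w = np.ones(4)
y = rms_norm_reference(x, w)
# After normalization each row has root-mean-square ~= 1.
print(np.sqrt(np.mean(y ** 2)))
```

The batch-invariant property the PR targets means that the result for a given row must not depend on which other rows happen to share the batch; a fixed reduction order inside the kernel (rather than a batch-size-dependent one) is what provides that.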


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds a batch-invariant RMS Normalization using a Triton kernel, which is a valuable addition for deterministic inference. The implementation is sound, but I have two main suggestions. First, and most importantly, unit tests are missing for the new functionality. It is critical to add tests to verify correctness and the batch-invariant property, as also indicated by the unchecked item in the PR description's checklist. Second, the performance of the Triton kernel can be improved by using triton.autotune to select an optimal BLOCK_SIZE instead of the current hardcoded value. I've left a specific suggestion for this.
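A unit test along the lines the review asks for could check that each row of a batched call is bitwise identical to running that row alone. The sketch below uses a numpy stand-in for the kernel under test (the real test would call the PR's `rms_norm` on GPU tensors); all names here are illustrative.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Stand-in for the kernel under test (illustrative only).
    variance = np.mean(x ** 2, axis=-1, keepdims=True)
    return (x / np.sqrt(variance + eps)) * weight

def test_batch_invariance():
    rng = np.random.default_rng(0)
    weight = rng.standard_normal(64)
    x = rng.standard_normal((8, 64))
    batched = rms_norm(x, weight)
    for i in range(x.shape[0]):
        # Row i computed alone must match row i computed in a batch,
        # bit for bit -- allclose is not enough for determinism claims.
        single = rms_norm(x[i : i + 1], weight)
        assert np.array_equal(batched[i : i + 1], single)

test_batch_invariance()
print("batch-invariance check passed")
```

Note the use of exact equality rather than a tolerance: batch invariance is a determinism property, so any bit-level difference is a failure.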

Comment thread python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py
Comment thread python/sglang/srt/layers/layernorm.py Outdated

@fzyzcjy fzyzcjy left a comment


LGTM w/ nit above

@fzyzcjy fzyzcjy added the run-ci label Oct 26, 2025
@fzyzcjy
Collaborator

fzyzcjy commented Oct 26, 2025

btw ci fails

[screenshot of CI failure]

if is_batch_invariant_mode_enabled():
    if (
        residual is not None
        or get_global_server_args().rl_on_policy_target == "fsdp"
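Read in context, the quoted condition falls back to the native (non-fused) path whenever a residual is fused in or the FSDP on-policy target is set, and otherwise uses the batch-invariant fused kernel. A simplified, self-contained model of that dispatch (illustrative only, not the PR's exact code):

```python
def should_use_native_path(batch_invariant, residual, rl_on_policy_target):
    # Simplified model of the dispatch in layernorm.py's forward_cuda:
    # in batch-invariant mode, fall back to forward_native when a residual
    # is present or when matching FSDP on-policy numerics is required.
    return batch_invariant and (
        residual is not None or rl_on_policy_target == "fsdp"
    )

print(should_use_native_path(True, None, None))    # -> False: fused batch-invariant kernel is fine
print(should_use_native_path(True, None, "fsdp"))  # -> True: must match FSDP numerics
```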
Collaborator


Will this condition cause different behavior between fsdp and sglang?

Contributor Author


Collaborator


in true-on-policy-with-fsdp, we should (at least currently) use the forward_native b/c the fused triton kernel may give diff numeric values

Collaborator


for the fused triton kernel, have we tested whether it's batch invariant?

@Fridge003 Fridge003 merged commit 0380ca8 into sgl-project:main Oct 29, 2025
63 of 71 checks passed
@zyzshishui zyzshishui deleted the rmsnorm branch November 23, 2025 19:46
