
[Auto Sync] Update batch_invariant_ops.py (20251109) #12916

Merged
merrymercy merged 1 commit into main from sync-2605e162-20251109 on Nov 10, 2025

Conversation

@merrymercy (Contributor) commented Nov 9, 2025

Sync changes from commit 2605e162.

Files Changed:

  • python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py

Author: Stefan He <hebiaobuaa@gmail.com>


Motivation

Support FP64 Log Softmax for Deterministic Mode
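
For background, the kernel computes log softmax with the standard max-subtraction identity, log_softmax(x) = x - max(x) - log(sum(exp(x - max(x)))), streamed over blocks; supporting FP64 means the intermediates (the running max and the exp-sum) must carry the input's dtype. The snippet below is a plain PyTorch illustration of that identity in FP64, not the SGLang kernel itself:

```python
import torch

# Illustration only: the max-subtraction identity the Triton kernel evaluates
# block-by-block, here computed directly in FP64.
x = torch.randn(2, 8, dtype=torch.float64)
max_val = x.max(dim=-1, keepdim=True).values
log_sum_exp = torch.log(torch.exp(x - max_val).sum(dim=-1, keepdim=True))
out = x - max_val - log_sum_exp
assert torch.allclose(out, torch.log_softmax(x, dim=-1))
```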

This is an automated PR created by scripts/copy_from_oss.py.

Co-authored-by: Stefan He <hebiaobuaa@gmail.com>
@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @merrymercy, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request synchronizes an update to batch_invariant_ops.py, specifically refining the _log_softmax_kernel function. The core objective of these changes is to enhance the numerical stability and type consistency of key variables within the Triton kernel by ensuring they are initialized with appropriate data types derived from the input data.

Highlights

  • Type-safe Initialization for max_val: The max_val variable in the _log_softmax_kernel is now initialized by loading the first block of data to infer its data type, enhancing numerical stability and type consistency within the Triton kernel.
  • Correct sum_exp Data Type Initialization: The sum_exp variable in the _log_softmax_kernel is now initialized using tl.sum(tl.zeros([1], dtype=max_val.dtype)) to ensure it has the correct data type, aligning with Triton's type-handling best practices. Both initializations are sketched in context below.
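
For readers without the file open, here is a minimal sketch of where the two initializations sit in a streaming log-softmax kernel. This is not the SGLang source: the signature, stride handling, and the second and third loops are assumptions; only the first-block load and the two initializations follow the diff quoted in the review below.

```python
import triton
import triton.language as tl


@triton.jit
def _log_softmax_kernel(  # hypothetical signature; the real one may differ
    input_ptr, output_ptr, n_cols, row_stride, BLOCK_SIZE: tl.constexpr
):
    row = tl.program_id(0)
    row_start_ptr = input_ptr + row * row_stride

    # Load first block to infer dtype and initialize max_val with correct type
    col_idx_init = tl.arange(0, BLOCK_SIZE)
    mask_init = col_idx_init < n_cols
    vals_init = tl.load(
        row_start_ptr + col_idx_init, mask=mask_init, other=-float("inf")
    )
    max_val = tl.max(vals_init)  # scalar; inherits the input dtype (e.g. fp64)

    # Continue with remaining blocks
    for col_offset in range(BLOCK_SIZE, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf"))
        max_val = tl.maximum(max_val, tl.max(vals))

    # Initialize sum_exp with correct dtype by using tl.sum on a zero vector
    sum_exp = tl.sum(tl.zeros([1], dtype=max_val.dtype))
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        sum_exp += tl.sum(tl.where(mask, tl.exp(vals - max_val), 0.0))

    # Write out x - max - log(sum_exp), block by block
    log_sum_exp = tl.log(sum_exp)
    out_row_ptr = output_ptr + row * row_stride
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        tl.store(out_row_ptr + col_idx, vals - max_val - log_sum_exp, mask=mask)
```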
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder at the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.

@gemini-code-assist (Bot) left a comment:

Code Review

This pull request updates batch_invariant_ops.py to ensure correct dtype initialization for accumulators within the _log_softmax_kernel Triton kernel. The changes are functionally correct and improve numerical stability by initializing max_val and sum_exp with the appropriate tensor dtypes. My review includes suggestions to make these initializations more concise and readable by using tl.full, which simplifies the code by removing special handling for the first block and providing a more direct way to create scalar tensors.

Comment on lines +318 to +324:

```python
# Load first block to infer dtype and initialize max_val with correct type
col_idx_init = tl.arange(0, BLOCK_SIZE)
mask_init = col_idx_init < n_cols
vals_init = tl.load(
    row_start_ptr + col_idx_init, mask=mask_init, other=-float("inf")
)
max_val = tl.max(vals_init)
```
Severity: medium

While loading the first block to infer the dtype is a valid approach, you can achieve the same result more concisely by directly initializing max_val as a scalar tensor with the correct dtype using tl.full. This avoids special-casing the first block and simplifies the loop structure. See my suggestion on line 327 for the corresponding loop change.

Suggested change:

```diff
-# Load first block to infer dtype and initialize max_val with correct type
-col_idx_init = tl.arange(0, BLOCK_SIZE)
-mask_init = col_idx_init < n_cols
-vals_init = tl.load(
-    row_start_ptr + col_idx_init, mask=mask_init, other=-float("inf")
-)
-max_val = tl.max(vals_init)
+# Step 1: Find maximum value in the row for numerical stability
+max_val = tl.full((), -float("inf"), dtype=input_ptr.dtype.element_ty)
```

Comment on line +327:

```python
max_val = tl.max(vals_init)

# Continue with remaining blocks
for col_offset in range(BLOCK_SIZE, n_cols, BLOCK_SIZE):
```
Severity: medium

Following the simplification of max_val initialization in the preceding comment, the loop should now iterate over all blocks starting from offset 0.

Suggested change:

```diff
-for col_offset in range(BLOCK_SIZE, n_cols, BLOCK_SIZE):
+for col_offset in range(0, n_cols, BLOCK_SIZE):
```
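
The two formulations are equivalent: seeding the running max with the first block and folding in the rest yields the same result as seeding with -inf and folding in every block. A pure-Python analogue (hypothetical values, block size 2) makes the equivalence concrete:

```python
import math

# Hypothetical row and block size, for illustration only.
row = [0.5, 2.0, -1.0, 3.5]
BLOCK = 2

# Variant in the current diff: seed with the first block, fold in the rest.
max_a = max(row[:BLOCK])
for i in range(BLOCK, len(row), BLOCK):
    max_a = max(max_a, max(row[i:i + BLOCK]))

# Variant in the suggestion: seed with -inf, fold in every block.
max_b = -math.inf
for i in range(0, len(row), BLOCK):
    max_b = max(max_b, max(row[i:i + BLOCK]))

assert max_a == max_b == 3.5
```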

Comment on lines +338 to +339:

```python
# Initialize sum_exp with correct dtype by using tl.sum on a zero vector
sum_exp = tl.sum(tl.zeros([1], dtype=max_val.dtype))
```
Severity: medium

A more direct and readable way to initialize a scalar tensor to zero with a specific dtype is to use tl.full. This avoids the intermediate step of creating a single-element array and summing it.

Suggested change:

```diff
-# Initialize sum_exp with correct dtype by using tl.sum on a zero vector
-sum_exp = tl.sum(tl.zeros([1], dtype=max_val.dtype))
+sum_exp = tl.full((), 0.0, dtype=max_val.dtype)
```
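
Both spellings yield a scalar zero that carries max_val's dtype; tl.full just does it in one step. A PyTorch analogue of the two initializations (Triton's tl.zeros, tl.sum, and tl.full behave analogously with respect to dtype):

```python
import torch

max_val = torch.tensor(1.0, dtype=torch.float64)

# Spelling in the diff: sum a one-element zero vector to get a scalar zero.
sum_exp_a = torch.zeros(1, dtype=max_val.dtype).sum()

# Suggested spelling: create the 0-d scalar zero directly, dtype explicit.
sum_exp_b = torch.full((), 0.0, dtype=max_val.dtype)

assert sum_exp_a.dtype == sum_exp_b.dtype == torch.float64
```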

@ch-wan (Collaborator) left a comment:

Do you think we can follow Gemini's suggestion?

@ch-wan ch-wan self-assigned this Nov 9, 2025
@merrymercy merrymercy merged commit 9ea2c68 into main Nov 10, 2025
74 of 79 checks passed
@merrymercy merrymercy deleted the sync-2605e162-20251109 branch on November 10, 2025, 09:51
