[Auto Sync] Update batch_invariant_ops.py (20251109)#12916
Conversation
Co-authored-by: Stefan He <hebiaobuaa@gmail.com>
Summary of ChangesHello @merrymercy, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request synchronizes an update to Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request updates batch_invariant_ops.py to ensure correct dtype initialization for accumulators within the _log_softmax_kernel Triton kernel. The changes are functionally correct and improve numerical stability by initializing max_val and sum_exp with the appropriate tensor dtypes. My review includes suggestions to make these initializations more concise and readable by using tl.full, which simplifies the code by removing special handling for the first block and providing a more direct way to create scalar tensors.
| # Load first block to infer dtype and initialize max_val with correct type | ||
| col_idx_init = tl.arange(0, BLOCK_SIZE) | ||
| mask_init = col_idx_init < n_cols | ||
| vals_init = tl.load( | ||
| row_start_ptr + col_idx_init, mask=mask_init, other=-float("inf") | ||
| ) | ||
| max_val = tl.max(vals_init) |
There was a problem hiding this comment.
While loading the first block to infer the dtype is a valid approach, you can achieve the same result more concisely by directly initializing max_val as a scalar tensor with the correct dtype using tl.full. This avoids special-casing the first block and simplifies the loop structure. See my suggestion on line 327 for the corresponding loop change.
| # Load first block to infer dtype and initialize max_val with correct type | |
| col_idx_init = tl.arange(0, BLOCK_SIZE) | |
| mask_init = col_idx_init < n_cols | |
| vals_init = tl.load( | |
| row_start_ptr + col_idx_init, mask=mask_init, other=-float("inf") | |
| ) | |
| max_val = tl.max(vals_init) | |
| # Step 1: Find maximum value in the row for numerical stability | |
| max_val = tl.full((), -float("inf"), dtype=input_ptr.dtype.element_ty) |
| max_val = tl.max(vals_init) | ||
|
|
||
| # Continue with remaining blocks | ||
| for col_offset in range(BLOCK_SIZE, n_cols, BLOCK_SIZE): |
There was a problem hiding this comment.
| # Initialize sum_exp with correct dtype by using tl.sum on a zero vector | ||
| sum_exp = tl.sum(tl.zeros([1], dtype=max_val.dtype)) |
There was a problem hiding this comment.
A more direct and readable way to initialize a scalar tensor to zero with a specific dtype is to use tl.full. This avoids the intermediate step of creating a single-element array and summing it.
| # Initialize sum_exp with correct dtype by using tl.sum on a zero vector | |
| sum_exp = tl.sum(tl.zeros([1], dtype=max_val.dtype)) | |
| sum_exp = tl.full((), 0.0, dtype=max_val.dtype) |
ch-wan
left a comment
There was a problem hiding this comment.
Do you think we can follow gemini's suggestion?
Sync changes from commit
2605e162.Files Changed:
Author: Stefan He hebiaobuaa@gmail.com
Motivation
Support FP64 Log Softmax for Deterministic Mode
This is an automated PR created by scripts/copy_from_oss.py.