Fix XLA searchsorted side='right' for NaN values #117850
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up-to-date status, view the checks section at the bottom of the pull request.
@googlebot rescan
/gemini review
Code Review
This pull request introduces support for handling NaN search values in the searchsorted operation by adding a nan_values_compare_greater parameter to BuildLowerUpperBoundOp and implementing the corresponding logic for UpperBoundOp. It also adds a unit test testNanValue to verify this behavior. The review feedback suggests optimizing the implementation by restricting the NaN comparison logic to floating-point input types using DataTypeIsFloating, thereby avoiding redundant HLO instructions for integer types.
if (nan_values_compare_greater) {
  // Match eager searchsorted side='right' behavior for NaN search values.
  // A NaN search value is placed after all entries in the sorted input row.
  auto value_is_nan = xla::Compare(values_reshaped, values_reshaped, {},
                                   xla::ComparisonDirection::kNe);
  comparison = xla::Or(comparison, value_is_nan);
}
Since tf.searchsorted supports both floating-point and integer types, we should avoid generating unnecessary HLO instructions (Compare and Or) for integer types, as integers cannot be NaN. We can restrict this logic to only run when the input type is a floating-point type using DataTypeIsFloating.
Suggested change:
- if (nan_values_compare_greater) {
+ if (nan_values_compare_greater && DataTypeIsFloating(ctx->input_type(0))) {
    // Match eager searchsorted side='right' behavior for NaN search values.
    // A NaN search value is placed after all entries in the sorted input row.
    auto value_is_nan = xla::Compare(values_reshaped, values_reshaped, {},
                                     xla::ComparisonDirection::kNe);
    comparison = xla::Or(comparison, value_is_nan);
  }
References
- Identify unnecessary memory usage or redundant computations. (link)
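The `Compare(values, values, kNe)` pattern in the snippet relies on NaN being the only value that is not equal to itself. The following is a minimal plain-Python sketch of that property, not XLA code; the helper name `is_nan_by_self_compare` is hypothetical. It illustrates why the check can never fire for integer inputs, which is the reasoning behind guarding it with a floating-point type test.

```python
import math


def is_nan_by_self_compare(x):
    """Self-inequality test: only NaN compares unequal to itself."""
    return x != x


# Floating-point inputs: only the NaN entry is flagged.
float_values = [0.0, -1.5, math.inf, -math.inf, math.nan]
float_flags = [is_nan_by_self_compare(v) for v in float_values]

# Integer inputs: the test is always False, so computing it is wasted work.
int_values = [0, -7, 2**31 - 1]
int_flags = [is_nan_by_self_compare(v) for v in int_values]

print(float_flags)
print(int_flags)
```

Because the integer result is constant, skipping the extra comparison for non-floating types changes no output; it only avoids emitting instructions whose result is known in advance.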
dmiltr3 left a comment
Thanks for the contribution! The changes look great and correctly align XLA's behavior with eager TensorFlow for NaN search values in both side='left' and side='right'.
I appreciate that you included the unit tests covering left and right sides across the floating-point types (float16, float32, float64, and bfloat16).
- Old logic: Scans physical `BufferAllocation`s and only counts those where `IsPreallocatedTempBuffer()` is true. This misses scratch buffers that the compiler overlays onto live-out (output) allocations to save memory.
- New logic: Walks the executed `Thunk` sequence and sums the sizes of all logical `BufferUse::Scratch` slices, accurately capturing scratch usage regardless of physical overlay optimizations.
Example (32MB matmul scratch overlaid on a 157MB live-out output buffer):
- Old logic: Returns 0 bytes (skips the live-out allocation).
- New logic: Returns 32MB (correctly extracts the scratch thunk use).
Also had to add a const way to walk Thunks.
#### Why this is needed
The current logic relies on the `BufferAssignment` being alive, since the `buffer` variable, which holds an `HloValue`, is owned by the `BufferAssignment` class. I'm in the process of removing the `BufferAssignment` from the executable, since it cannot be re-created when loading an AOT binary and it's not really needed. So we need to get rid of this implicit dependency before doing so.
FUTURE_COPYBARA_INTEGRATE_REVIEW=#117850 from praneethhere:fix-xla-searchsorted-nan-upper-bound 7f73af1
PiperOrigin-RevId: 937384420
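The difference between the two accounting strategies in this commit message can be modelled with a toy Python sketch. Everything here is hypothetical: the `Allocation`, `BufferUse`, and `Thunk` dataclasses and both functions are simplified stand-ins invented for illustration, not the XLA classes or their APIs. The sketch only reproduces the 32MB-on-157MB example from the message.

```python
from dataclasses import dataclass

MB = 1024 * 1024


@dataclass(frozen=True)
class Allocation:
    """Hypothetical stand-in for a physical buffer allocation."""
    index: int
    size: int
    is_preallocated_temp: bool


@dataclass(frozen=True)
class BufferUse:
    """Hypothetical stand-in for one logical slice used by a thunk."""
    allocation_index: int
    size: int
    is_scratch: bool


@dataclass(frozen=True)
class Thunk:
    """Hypothetical stand-in for one executed thunk."""
    name: str
    buffer_uses: tuple


def scratch_bytes_old(allocations):
    # Old approach: count only physical allocations flagged as temp buffers.
    return sum(a.size for a in allocations if a.is_preallocated_temp)


def scratch_bytes_new(thunks):
    # New approach: sum every logical scratch slice the thunks actually use.
    return sum(u.size for t in thunks for u in t.buffer_uses if u.is_scratch)


# A 157MB live-out output allocation with a 32MB matmul scratch overlaid on it.
allocations = [Allocation(index=0, size=157 * MB, is_preallocated_temp=False)]
thunks = [
    Thunk(
        name="matmul",
        buffer_uses=(
            BufferUse(allocation_index=0, size=157 * MB, is_scratch=False),
            BufferUse(allocation_index=0, size=32 * MB, is_scratch=True),
        ),
    )
]

old_result = scratch_bytes_old(allocations)
new_result = scratch_bytes_new(thunks)
print(old_result, new_result // MB)
```

In this model the old function sees no temp allocation at all, because the scratch slice lives inside the output allocation, while the new function finds the slice through the thunk that uses it.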
Fixes #117804
This fixes an inconsistency between eager TensorFlow and XLA for `tf.searchsorted(..., side="right")` when the search value is `NaN`.
The XLA lowering for `UpperBound` was using `value >= sorted_input` to count the matching positions. For `NaN`, that comparison is false for every sorted input value, so XLA returned `0`.
Eager TensorFlow places a `NaN` search value at the end of the row for `side="right"`. This change keeps the existing comparison logic, but treats `NaN` search values as greater than all sorted input values for `UpperBound`.
I also added a regression test covering both `side="left"` and `side="right"`.
Tested with:
bazel test //tensorflow/compiler/tests:searchsorted_op_test_cpu --repo_env=HERMETIC_PYTHON_VERSION=3.12 --macos_sdk_version=15.5 --jobs=4 --local_ram_resources=4096
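For readers who want to see the counting behavior without building TensorFlow, here is a minimal scalar model in plain Python. It is a sketch under assumptions, not the HLO lowering: the function `upper_bound_count` is hypothetical, and it only mirrors the description above (count positions where `value >= sorted_input`, optionally treating a NaN search value as greater than every entry). The flag name `nan_values_compare_greater` is borrowed from the review summary.

```python
import bisect
import math


def upper_bound_count(sorted_row, value, nan_values_compare_greater):
    """Count entries satisfying value >= entry, as a model of UpperBound."""
    value_is_nan = value != value  # NaN is the only value unequal to itself
    count = 0
    for entry in sorted_row:
        matches = value >= entry  # always False when value is NaN
        if nan_values_compare_greater and value_is_nan:
            matches = True  # treat NaN as greater than every entry
        count += int(matches)
    return count


row = [1.0, 2.0, 2.0, 5.0]

# Without the NaN handling, every comparison is False and the count is 0.
nan_before = upper_bound_count(row, math.nan, nan_values_compare_greater=False)

# With the NaN handling, the NaN search value lands at the end of the row.
nan_after = upper_bound_count(row, math.nan, nan_values_compare_greater=True)

# Ordinary search values are unaffected by the flag.
regular = upper_bound_count(row, 2.0, nan_values_compare_greater=True)

print(nan_before, nan_after, regular)
```

For non-NaN values the count agrees with `bisect.bisect_right` on the same row, which is a quick way to sanity-check the model.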