Deprecate explicit shape promotion in PyTorch/XLA#6121
Merged
Proposal
Currently, both PyTorch/XLA and XLA handle implicit broadcasting for static and bounded-dynamic shapes. The handling of implicit broadcasting in PyTorch/XLA's codebase is not uniform: operations sometimes rely on PyTorch/XLA's own implementation and sometimes on XLA's. The goal of this document is to investigate removing implicit-broadcasting support from PyTorch/XLA so that XLA becomes the single source of truth.
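For context, implicit broadcasting is the behavior where an elementwise operation silently expands a smaller operand to match the shape of the larger one. A minimal PyTorch sketch (shapes chosen for illustration):

```python
import torch

# Implicit broadcasting: the rank-1 operand is virtually expanded to
# match the rank-2 operand before the elementwise add runs.
a = torch.ones(2, 3)             # shape (2, 3)
b = torch.tensor([1., 2., 3.])   # shape (3,)

c = a + b                        # b is broadcast to shape (2, 3)
print(c.shape)                   # torch.Size([2, 3])
```

On the XLA backend, it is this kind of expansion that either PyTorch/XLA's shape promotion or XLA's own implicit broadcasting must perform.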
Current state of Broadcasting in PyTorch/XLA vs XLA
Implicit broadcasting, for static and bounded-dynamic dimensions, is supported in both the PyTorch/XLA and XLA codebases.
PyTorch/XLA
XlaHelpers::Promote performs the implicit broadcasting using the following utility functions.
However, the handling of implicit broadcasting in the PyTorch/XLA codebase is not uniform: some operations promote explicitly (e.g. max and arithmetic ops via XlaHelpers::Promote), while others rely on XlaBuilder's implicit broadcasting (e.g. Relu calls xla::Max without any promotion).
XLA
HLO currently handles implicit broadcasting for binary operations, ternary operations, and map, for both static shapes and bounded dynamic shapes, in three steps:
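The shape rule underlying this broadcasting can be sketched in plain Python. This is an illustrative reimplementation of the NumPy/PyTorch-style rule, not XLA's actual code: shapes are aligned from the trailing dimension, missing leading dimensions are treated as size 1, and a size-1 dimension stretches to match its counterpart.

```python
def broadcast_shape(lhs, rhs):
    """Compute the broadcast result shape of two shapes (tuples of ints).

    Illustrative sketch of the standard broadcasting rule; raises
    ValueError when the shapes are incompatible.
    """
    out = []
    # Walk from the trailing dimension; absent leading dims act as 1.
    for i in range(max(len(lhs), len(rhs))):
        l = lhs[-1 - i] if i < len(lhs) else 1
        r = rhs[-1 - i] if i < len(rhs) else 1
        if l != r and 1 not in (l, r):
            raise ValueError(f"incompatible dimensions {l} and {r}")
        out.append(max(l, r))
    return tuple(reversed(out))

print(broadcast_shape((2, 3), (3,)))    # (2, 3)
print(broadcast_shape((4, 1), (1, 5)))  # (4, 5)
```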
Here are some important points to note:
- PyTorch/XLA calls XlaHelpers::Promote before calling the XlaBuilder APIs for broadcastable operations, ensuring operand shapes are already promoted and thereby leaving XLA's implicit-broadcasting support unexercised.
- XLA's implicit broadcasting requires a broadcast_dimensions specification when operand ranks differ (refer to xla-broadcasting-principles).

Design Proposal
Given that the HLO/StableHLO emitted via implicit broadcasting in PyTorch/XLA and XLA is semantically identical (refer to Appendix), we propose to handle broadcasting only in XLA with the following changes:
- Remove the use of XlaHelpers::PromoteShapes to do implicit broadcasting.
- Update XlaHelpers::Promote such that the following XLA client API calls (like xla::Min, xla::Atan2, etc.) accept broadcasting dimensions.

Appendix
Examples demonstrating the current state
Next, we demonstrate the current state of implicit broadcasting in the PyTorch/XLA and XLA codebases. To extract XLA's behavior, we bypass the implicit broadcasting in PyTorch/XLA code using the patch.
Same rank broadcasting
PyTorch code
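A minimal sketch of same-rank broadcasting, using hypothetical shapes (2, 1) and (2, 3); the shapes are illustrative, not taken from the original snippet:

```python
import torch

# Same-rank broadcasting: both operands are rank 2; a's size-1
# second dimension broadcasts to match b's size-3 dimension.
a = torch.randn(2, 1)
b = torch.randn(2, 3)
c = a + b
print(c.shape)  # torch.Size([2, 3])
```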
Broadcasting via PyTorch/XLA
Broadcasting via XLA
Scalar Broadcasting
PyTorch code
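A minimal sketch of scalar broadcasting, with an illustrative shape (the original snippet's operands are assumptions here):

```python
import torch

# Scalar broadcasting: the Python scalar is broadcast across
# every element of the rank-2 tensor.
a = torch.randn(2, 2)
c = a + 1.0
print(c.shape)  # torch.Size([2, 2])
```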
Broadcasting via PyTorch/XLA
Broadcasting via XLA
Different rank broadcasting
PyTorch code
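A minimal sketch of different-rank broadcasting, again with hypothetical shapes: the lower-rank operand is treated as if size-1 dimensions were prepended until the ranks match.

```python
import torch

# Different-rank broadcasting: b (rank 2) is treated as shape
# (1, 3, 4) and then broadcast along the leading dimension of a.
a = torch.randn(2, 3, 4)
b = torch.randn(3, 4)
c = a + b
print(c.shape)  # torch.Size([2, 3, 4])
```

It is this rank-mismatched case where XLA's API requires an explicit broadcast_dimensions argument, as noted below.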
Broadcasting via PyTorch/XLA
Broadcasting via XLA
XLA fails, as broadcast_dimensions has to be specified when the ranks are different.