Introduce apply_xla_patch_to_nn_linear and test that in a scan #8739
Merged
In order to propagate sharding annotations in 2D sharding, linear layers should be implemented with einsum instead of transposes/reshapes. Additionally, they need to continue to function inside scan/scan_layers.

For this to work we need three pieces:

- I added an `apply_xla_patch_to_nn_linear` function to replace the implementation of `nn.Linear` with einsum (calling XLAPatchedLinear).
- The XLAPatchedLinear implementation should be wrapped in torch custom ops. That's because AOTAutograd, which scan uses, will decompose all einsums into transposes/reshapes unless we use `@custom_op` to mark a function as opaque to AOTAutograd.
- Even after wrapping them with `@custom_op`, the einsum is still decomposed into transposes/reshapes due to #8713 (torch.einsum is incorrectly decomposed when wrapped inside a custom op). That's a bug/PyTorch limitation. To work around this, I added a `_xla_einsum` C++ function that directly builds an einsum given XLA tensors, skipping over any PyTorch dispatcher complexity.

Added a test that demonstrates how `nn.Linear` layers by default flatten any non-contracting dims, and how we could avoid that with `apply_xla_patch_to_nn_linear`.
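To make the flattening point concrete, here is a minimal sketch in plain PyTorch on CPU. It does not use torch_xla, `XLAPatchedLinear`, or `_xla_einsum`; the helper names `flattened_linear` and `einsum_linear` are hypothetical and exist only for this illustration. It assumes a 3D input of shape `(batch, seq, in_features)` and compares a flatten/matmul/unflatten formulation with an einsum formulation that leaves the leading dimensions untouched.

```python
import torch
import torch.nn as nn


def flattened_linear(x, weight, bias=None):
    """Hypothetical helper: linear layer via flatten -> matmul -> unflatten."""
    lead_shape = x.shape[:-1]
    # The non-contracting (leading) dims are merged into one dimension here,
    # which is the step that loses per-dimension structure.
    x2d = x.reshape(-1, x.shape[-1])
    y2d = x2d @ weight.t()
    if bias is not None:
        y2d = y2d + bias
    return y2d.reshape(*lead_shape, weight.shape[0])


def einsum_linear(x, weight, bias=None):
    """Hypothetical helper: linear layer via einsum, leading dims kept as-is."""
    # "n" is the contracted input-feature dim, "m" the output-feature dim.
    y = torch.einsum("...n,mn->...m", x, weight)
    if bias is not None:
        y = y + bias
    return y


torch.manual_seed(0)
layer = nn.Linear(16, 32)
x = torch.randn(4, 8, 16)  # (batch, seq, in_features)

reference = layer(x)
flat_out = flattened_linear(x, layer.weight, layer.bias)
einsum_out = einsum_linear(x, layer.weight, layer.bias)
```

Both formulations should agree numerically with `nn.Linear`; the difference that matters for this PR is only in which intermediate operations appear in the traced graph, which this CPU sketch does not inspect.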
zpcore
reviewed
Feb 25, 2025