
Introduce apply_xla_patch_to_nn_linear and test that in a scan #8739

Merged
tengyifei merged 1 commit into master from yifeit/workaround-einsum
Feb 25, 2025
Conversation

@tengyifei
Collaborator

To propagate sharding annotations under 2D sharding, linear layers should be implemented with einsum instead of transposes/reshapes. Additionally, they need to continue to function inside scan/scan_layers.

For this to work we need three pieces:

  • I added an `apply_xla_patch_to_nn_linear` function to replace the implementation of `nn.Linear` with einsum (calling `XLAPatchedLinear`).
  • The `XLAPatchedLinear` implementation should be wrapped in a torch custom op. That's because AOTAutograd, used by scan, will decompose all einsums into transposes/reshapes unless we use `@custom_op` to mark a function as opaque to AOTAutograd.
  • Even after wrapping it with `@custom_op`, the einsum is still decomposed into transposes/reshapes due to #8713 ("torch.einsum is incorrectly decomposed when wrapped inside a custom op"). That's a PyTorch bug/limitation. To work around this, I added a `_xla_einsum` C++ function that directly builds an einsum given XLA tensors, skipping over any PyTorch dispatcher complexity.
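The decomposition problem the bullets describe can be illustrated with plain NumPy (an illustrative assumption; the PR itself operates on XLA tensors): the einsum form and its transpose/reshape decomposition compute identical values, but only the einsum form keeps the contraction as a single op whose non-contracting dims survive, which is what sharding-annotation propagation relies on.

```python
import numpy as np

# Hypothetical shapes: batch 2, sequence 3, in_features 4, out_features 5.
x = np.random.rand(2, 3, 4)
w = np.random.rand(5, 4)  # nn.Linear-style (out_features, in_features)

# Einsum form: a single op; the non-contracting dims (2, 3) pass through intact.
y_einsum = np.einsum('...n,mn->...m', x, w)

# What a decomposition into reshapes/matmul looks like: the leading (2, 3)
# dims get flattened to 6 before the matmul, then restored afterwards.
y_decomposed = (x.reshape(-1, 4) @ w.T).reshape(2, 3, 5)

assert np.allclose(y_einsum, y_decomposed)
print(y_einsum.shape)  # (2, 3, 5)
```

The values match either way; the difference is purely in what the compiler sees, which is why preventing AOTAutograd from rewriting the einsum matters.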

Added a test that demonstrates how `nn.Linear` layers by default flatten any non-contracting dims, and how we can avoid that with `apply_xla_patch_to_nn_linear`.
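The patching approach can be sketched in pure Python (a toy stand-in, not the real `torch_xla` API: the `Linear` class, `einsum_forward`, and `apply_patch` below are hypothetical illustrations of rebinding a linear layer's forward to an einsum, which is conceptually what `apply_xla_patch_to_nn_linear` does via `XLAPatchedLinear`):

```python
import numpy as np

class Linear:
    """Toy stand-in for nn.Linear (illustrative only)."""
    def __init__(self, in_f, out_f):
        self.weight = np.random.rand(out_f, in_f)

    def forward(self, x):
        # Default path: flatten non-contracting dims, matmul, restore shape.
        lead = x.shape[:-1]
        return (x.reshape(-1, x.shape[-1]) @ self.weight.T).reshape(*lead, -1)

def einsum_forward(self, x):
    # Patched path: a single einsum, analogous to what XLAPatchedLinear
    # does for XLA tensors.
    return np.einsum('...n,mn->...m', x, self.weight)

def apply_patch(module):
    """Sketch of the patch: rebind forward on every Linear in a module tree."""
    for child in getattr(module, 'children', lambda: [])():
        apply_patch(child)
    if isinstance(module, Linear):
        module.forward = einsum_forward.__get__(module, Linear)
    return module

layer = apply_patch(Linear(4, 5))
print(layer.forward(np.random.rand(2, 3, 4)).shape)  # (2, 3, 5)
```

In the real implementation the payoff is not the output shape, which is the same either way, but that the traced graph contains an einsum node instead of reshape/transpose/matmul, so SPMD sharding annotations on the batch and sequence dims can propagate through it.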

@tengyifei tengyifei force-pushed the yifeit/workaround-einsum branch from c53acea to 7d834f7 Compare February 24, 2025 22:56
@tengyifei tengyifei changed the title Support einsum layers in a scan Introduce apply_xla_patch_to_nn_linear and test that in a scan Feb 24, 2025
@tengyifei tengyifei marked this pull request as ready for review February 24, 2025 23:22
@tengyifei tengyifei force-pushed the yifeit/workaround-einsum branch from 7d834f7 to 2cf50c8 Compare February 24, 2025 23:48
Comment thread torch_xla/distributed/spmd/xla_sharding.py
Member

@zpcore zpcore left a comment


LGTM!

@tengyifei tengyifei merged commit 6f020aa into master Feb 25, 2025
Collaborator

@pgmoka pgmoka left a comment


LGTM
