
torch.einsum is incorrectly decomposed when wrapped inside a custom op #8713

@tengyifei

Description

🐛 Bug

This code

import torch
from torch import Tensor
from torch.library import custom_op

@custom_op("xla::custom_linear_forward123", schema="(Tensor input, Tensor weight) -> Tensor", mutates_args=())
def custom_linear_forward123(input: Tensor, weight: Tensor):
    return torch.einsum('...n,mn->...m', input, weight)

is a no-op wrapper around einsum. However, wrapping the call in a custom op changes its behavior significantly under PyTorch/XLA: the einsum is decomposed into a number of expensive transpose ops.

To Reproduce

Run the notebook https://github.com/tengyifei/playground/blob/master/aot-einsum-3.ipynb locally.
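
For reference, a minimal local reproduction along the same lines is sketched below (assuming a working torch_xla install and the custom op definition from above; _get_xla_tensors_text is an internal PyTorch/XLA debug helper for dumping the lazy IR and may change between releases):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Uses custom_linear_forward123 as defined above.
device = xm.xla_device()
x = torch.randn(3, 3, device=device)
w = torch.randn(3, 3, device=device)

out = custom_linear_forward123(x, w)

# Dump the lazy IR recorded for `out` before it is compiled and executed.
print(torch_xla._XLAC._get_xla_tensors_text([out]))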

Expected behavior

The lowering of custom_linear_forward123 contains an einsum op.

Actual behavior

The lowering of custom_linear_forward123 becomes

IR {
  %0 = f32[] prim::Constant(), xla_shape=f32[]
  %1 = f32[3,3]{1,0} aten::expand(%0), xla_shape=f32[3,3]{1,0}
  %2 = f32[3,3,1]{2,1,0} aten::as_strided(%1), xla_shape=f32[3,3,1]{2,1,0}
  %3 = f32[3,3,1]{2,1,0} aten::as_strided(%2), xla_shape=f32[3,3,1]{2,1,0}
  %4 = f32[1,3,3]{2,1,0} aten::view(%3), xla_shape=f32[1,3,3]{2,1,0}
  %5 = f32[] prim::Constant(), xla_shape=f32[]
  %6 = f32[3,3]{1,0} aten::expand(%5), xla_shape=f32[3,3]{1,0}
  %7 = f32[3,3,1]{2,1,0} aten::as_strided(%6), xla_shape=f32[3,3,1]{2,1,0}
  %8 = f32[3,3,1]{2,1,0} aten::as_strided(%7), xla_shape=f32[3,3,1]{2,1,0}
  %9 = f32[1,3,3]{2,1,0} aten::view(%8), xla_shape=f32[1,3,3]{2,1,0}
  %10 = f32[1,3,3]{2,1,0} aten::matmul(%9, %4), xla_shape=f32[1,3,3]{2,1,0}
  %11 = f32[3,1,3]{2,1,0} aten::view(%10), xla_shape=f32[3,1,3]{2,1,0}
  %12 = f32[3,3,1]{2,1,0} aten::as_strided(%11), xla_shape=f32[3,3,1]{2,1,0}
  %13 = f32[3,3]{1,0} aten::view(%12), xla_shape=f32[3,3]{1,0}, ROOT=0
}

Environment

Reproducible on PyTorch/XLA 2.6 stable.

Additional context

I've updated https://github.com/tengyifei/playground/blob/master/aot-einsum-3.ipynb with dispatcher traces.

When torch.einsum is called directly, the PyTorch dispatcher prints

 [call] op=[aten::einsum], key=[AutogradXLA]

whereas when torch.einsum is wrapped in a custom op, the dispatcher prints

 [call] op=[aten::einsum], key=[XLA]

followed by a long sequence of decomposed aten operations.

This suggests that when aten::einsum is dispatched with the XLA key rather than AutogradXLA, our registered lowerings are bypassed; some other code in PyTorch handles the op instead and turns the einsum into a series of permutes.
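
One way to dig further is to inspect which dispatch keys have kernels registered for aten::einsum. Below is a sketch using PyTorch's private dispatcher introspection helpers (internal APIs that may change between releases):

import torch

# Kernels explicitly registered for aten::einsum, grouped by dispatch key.
print(torch._C._dispatch_dump("aten::einsum"))

# The computed dispatch table: which kernel actually runs for each key.
print(torch._C._dispatch_dump_table("aten::einsum"))

If an XLA lowering only appears under AutogradXLA and nothing is registered for the plain XLA key, a call issued below autograd (as happens inside the custom op body) would fall through to the default decomposition, which would be consistent with the traces above.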

Labels

bug (Something isn't working), lowering (ATen Operation lowering)
