🐛 Bug
This code

import torch
from torch import Tensor
from torch.library import custom_op

@custom_op("xla::custom_linear_forward123", schema="(Tensor input, Tensor weight) -> Tensor", mutates_args=())
def custom_linear_forward123(input: Tensor, weight: Tensor):
    return torch.einsum('...n,mn->...m', input, weight)

is a no-op wrapper around einsum. However, it changes the behavior significantly in PyTorch/XLA: the einsum is decomposed into a number of expensive transpose ops.
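For context, the einsum string '...n,mn->...m' is just a bias-free linear (x @ w.t()), so the wrapper should be purely pass-through. A quick CPU sanity check (not part of the original notebook) confirming the numerics:

import torch

# einsum('...n,mn->...m', x, w) is the same contraction as a bias-free linear,
# i.e. x @ w.t(), so wrapping it in a custom op should not change semantics.
x = torch.randn(2, 4)
w = torch.randn(5, 4)
out = torch.einsum('...n,mn->...m', x, w)
ref = torch.nn.functional.linear(x, w)  # == x @ w.t()
torch.testing.assert_close(out, ref)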
To Reproduce

Run the notebook https://github.com/tengyifei/playground/blob/master/aot-einsum-3.ipynb locally.
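A condensed sketch of what the notebook checks, assuming torch_xla is installed; torch_xla._XLAC._get_xla_tensors_text is the internal helper that produces dumps in the IR { ... } format quoted below:

import torch
from torch import Tensor
from torch.library import custom_op
import torch_xla
import torch_xla.core.xla_model as xm

@custom_op("xla::custom_linear_forward123", schema="(Tensor input, Tensor weight) -> Tensor", mutates_args=())
def custom_linear_forward123(input: Tensor, weight: Tensor):
    return torch.einsum('...n,mn->...m', input, weight)

device = xm.xla_device()
x = torch.randn(3, 3, device=device)
w = torch.randn(3, 3, device=device)

# Direct einsum call: this should go through the einsum lowering registered
# by PyTorch/XLA, so the lazy IR should show an einsum op.
direct = torch.einsum('...n,mn->...m', x, w)
print(torch_xla._XLAC._get_xla_tensors_text([direct]))

# Same computation through the custom op: the lazy IR instead shows the
# decomposed expand/as_strided/view/matmul sequence under "Actual behavior".
wrapped = custom_linear_forward123(x, w)
print(torch_xla._XLAC._get_xla_tensors_text([wrapped]))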
Expected behavior

The lowering of custom_linear_forward123 contains an einsum op.

Actual behavior

The lowering of custom_linear_forward123 becomes

IR {
  %0 = f32[] prim::Constant(), xla_shape=f32[]
  %1 = f32[3,3]{1,0} aten::expand(%0), xla_shape=f32[3,3]{1,0}
  %2 = f32[3,3,1]{2,1,0} aten::as_strided(%1), xla_shape=f32[3,3,1]{2,1,0}
  %3 = f32[3,3,1]{2,1,0} aten::as_strided(%2), xla_shape=f32[3,3,1]{2,1,0}
  %4 = f32[1,3,3]{2,1,0} aten::view(%3), xla_shape=f32[1,3,3]{2,1,0}
  %5 = f32[] prim::Constant(), xla_shape=f32[]
  %6 = f32[3,3]{1,0} aten::expand(%5), xla_shape=f32[3,3]{1,0}
  %7 = f32[3,3,1]{2,1,0} aten::as_strided(%6), xla_shape=f32[3,3,1]{2,1,0}
  %8 = f32[3,3,1]{2,1,0} aten::as_strided(%7), xla_shape=f32[3,3,1]{2,1,0}
  %9 = f32[1,3,3]{2,1,0} aten::view(%8), xla_shape=f32[1,3,3]{2,1,0}
  %10 = f32[1,3,3]{2,1,0} aten::matmul(%9, %4), xla_shape=f32[1,3,3]{2,1,0}
  %11 = f32[3,1,3]{2,1,0} aten::view(%10), xla_shape=f32[3,1,3]{2,1,0}
  %12 = f32[3,3,1]{2,1,0} aten::as_strided(%11), xla_shape=f32[3,3,1]{2,1,0}
  %13 = f32[3,3]{1,0} aten::view(%12), xla_shape=f32[3,3]{1,0}, ROOT=0
}

Environment

Reproducible on PyTorch/XLA 2.6 stable.
Additional context
I've updated https://github.com/tengyifei/playground/blob/master/aot-einsum-3.ipynb with dispatcher traces.
When calling torch.einsum regularly, the PyTorch dispatcher prints

[call] op=[aten::einsum], key=[AutogradXLA]
[call] op=[aten::einsum], key=[XLA]

whereas when torch.einsum is wrapped in a custom op, the dispatcher prints the einsum dispatch followed by a whole bunch of decomposed aten operations.
This suggests that when calling torch.einsum with the XLA dispatch key, our registered lowerings are bypassed. Instead, some other code in PyTorch handles it and turns the einsum into a bunch of permutes.
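A small sketch for checking the premise that an einsum kernel is registered for the XLA dispatch key at all; torch._C._dispatch_has_kernel_for_dispatch_key is an internal API, so treat this as an assumption-laden diagnostic rather than part of the repro:

import torch
import torch_xla  # importing torch_xla registers its aten kernels for the XLA key

# True would mean PyTorch/XLA has an aten::einsum kernel registered for the
# XLA dispatch key; the report above suggests the custom-op path never reaches
# it and falls back to the composite decomposition instead.
print(torch._C._dispatch_has_kernel_for_dispatch_key("aten::einsum", "XLA"))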