Skip to content

einops 0.6.1 x torch.compile broken in pytorch nightlies #157417

@zou3519

Description

@zou3519
import torch
import torch.nn as nn
from einops import einsum, pack, rearrange, reduce, repeat, unpack

class TorchModuleWithOperations(nn.Module):
    """Parameter-free module that chains several einops operations.

    Used as a reproducer: ``forward`` calls einops ``repeat``, ``reduce``,
    ``pack``, ``unpack``, ``rearrange`` and ``einsum`` so that each of them is
    traced when the module is wrapped in ``torch.compile``.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x_abc, suffix=""):
        """Round-trip ``x_abc`` through a series of einops calls.

        Args:
            x_abc: Tensor whose shape is unpacked as ``a, b, c``, so it must
                be 3-D.
            suffix: String appended to the axis names ending in ``a``, ``c``
                or ``d`` in each pattern, giving textually different but
                equivalent einops patterns. ``len(suffix)`` also controls how
                many copies are packed (``2 + len(suffix)``).

        Returns:
            Tensor of shape ``(a, b, c)``. By my reading of the einops
            semantics this equals
            ``x_abc + (2 + len(suffix)) + (x_abc * x_abc).sum()``.
            NOTE(review): derived from the ops below, not asserted anywhere.
        """
        a, b, c = x_abc.shape

        def suf(pattern):
            # Append `suffix` to every whitespace-separated token whose last
            # character is 'a', 'c' or 'd' (e.g. "a" -> "a<suffix>",
            # "(a" -> "(a<suffix>"). Other tokens ("b", "->", ",", "*", ")",
            # "1", "4") are left untouched.
            parts = pattern.split()
            return " ".join(
                [p if p[-1] not in "acd" else p + suffix for p in parts]
            )

        # patterns look a bit strange because names a, c, d will be modified on every run
        # by suf function
        # (Precisely: they are renamed per distinct `suffix` value, so each
        # suffix yields a different pattern string for the same computation.)

        # Add a trailing axis of size 4: (a, b, c) -> (a, b, c, 4).
        # Per the trace in this issue, this first einops call is where
        # torch.compile raises TypeError('unhashable type: non-nested SymInt').
        x_abcd = repeat(x_abc, suf("a b c -> a b c 4"))
        # Min over the 4 identical copies -> back to (a, b, c).
        x_abc = reduce(x_abcd, suf("a b c d -> a b c"), "min")
        # Stack 2 + len(suffix) copies along a new axis between b and c;
        # `ps` records each input's shape for the later `unpack`.
        x_abdc, ps = pack([x_abc] * (2 + len(suffix)), suf("a b * c"))
        # Move the packed axis last, then split it back into a list of
        # (a*b, 1, c) tensors using the shapes recorded in `ps`.
        x_array = unpack(
            rearrange(x_abdc, suf("a b d c -> (a b ) 1 c d")), ps, "ab one1 c *"
        )
        x1 = x_array[0] + len(x_array)
        # Undo the (a b) merge; `b` is passed so einops can split the axis.
        x1 = rearrange(x1, suf("(a b ) 1 c -> a b c"), b=b)
        # Contract over a, b, c leaving a length-4 vector over d; take one entry.
        addition = einsum(x_abc, x_abcd, suf("a b c , a b c d -> d"))[0]
        return x1 + addition

original = TorchModuleWithOperations()
# Einops only interacts with Dynamo, so the lightweight "eager" backend is
# enough to exercise it (no need for backend="inductor"). fullgraph=True makes
# any graph break an error instead of a silent fallback.
compiled = torch.compile(original, backend="eager", fullgraph=True)
for size in [10, 20, 40]:
    # Several input sizes: the FakeTensor in the reported trace has symbolic
    # sizes (s48, s86, s93), i.e. the failure shows up under dynamic shapes.
    x = torch.rand([size, size + 1, size + 2])
    for suffix in ["", "suf1", "other_suffix"]:
        result1 = compiled(x, suffix)
        # Eager reference computed in float64, cast back to float32 to compare.
        result2 = original(x.double(), suffix).float()
        # Without this check a miscompile would go unnoticed: both results
        # were previously computed and then discarded.
        if not torch.allclose(result1, result2):
            raise AssertionError(
                f"compiled and eager results differ for size={size}, "
                f"suffix={suffix!r}"
            )

This fails with:

TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <function repeat at 0x7f1e654a2ca0>(*(FakeTensor(..., size=(s48, s86, s93)), 'a b c -> a b c 4'), **{}): got TypeError('unhashable type: non-nested SymInt')

from user code:
   File "<ipython-input-1-c92c344a3624>", line 20, in forward
    x_abcd = repeat(x_abc, suf("a b c -> a b c 4"))

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

cc @ezyang @gchanan @kadeng @msaroufim @chauhang @penguinwu @bobrenjc93 @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @amjames @Lucaskabela @jataylo @chenyang78

Metadata

Metadata

Labels

dynamo-triage-dec2025 (small/mid-sized dynamo tasks that we would like to be completed in the near future), high priority, module: dynamic shapes, module: dynamo, oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions