aten::sym_size is not using torch._ops.OpOverload in FX graph #97201

@titaiwangms

Description

🐛 Describe the bug

aten::sym_size has two overloads, but neither of them is used. In the FX graph, the OpOverloadPacket itself appears to be used as the node target. I am not sure whether this is intended or a bug, but it seems inconsistent with other ops.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.flatten(x, start_dim=2, end_dim=3)

x = torch.randn(3, 5, 4, 5)
# Symbolic tracing emits aten::sym_size calls to query the dynamic shapes.
m = make_fx(MyModule(), tracing_mode="symbolic")(x)

for node in m.graph.nodes:
    # The sym_size node's target is the OpOverloadPacket, not a resolved OpOverload.
    if isinstance(node.target, torch._ops.OpOverloadPacket):
        print(type(node.target))  # <class 'torch._ops.OpOverloadPacket'>
        print(node.target.overloads())  # ['default', 'int']
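To make the distinction concrete, here is a minimal pure-Python sketch of the relationship between an op's overload packet and its resolved overloads. The class and attribute names below are illustrative assumptions for exposition, not the real torch._ops internals:

```python
# Sketch only: NOT real PyTorch internals. Class names are hypothetical.

class FakeOpOverload:
    """A single resolved overload, e.g. aten.sym_size.int."""
    def __init__(self, packet, overload_name):
        self.packet = packet
        self.overload_name = overload_name

    def __repr__(self):
        return f"{self.packet.name}.{self.overload_name}"

class FakeOpOverloadPacket:
    """The unresolved bundle of overloads, e.g. aten.sym_size."""
    def __init__(self, name, overload_names):
        self.name = name
        self._overloads = {n: FakeOpOverload(self, n) for n in overload_names}

    def overloads(self):
        return list(self._overloads)

sym_size = FakeOpOverloadPacket("aten.sym_size", ["default", "int"])

# What the repro above observes: the FX node's target is the packet itself ...
packet_target = sym_size
print(packet_target.overloads())  # ['default', 'int']

# ... whereas other ops' FX nodes target one concrete, resolved overload:
overload_target = sym_size._overloads["int"]
print(repr(overload_target))  # aten.sym_size.int
```

The inconsistency reported here is that a consumer walking the graph (e.g. an exporter) must handle both target kinds for sym_size, while it can assume a resolved overload for other ops.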

Versions

Versions of relevant libraries:
[pip3] mypy==0.960
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.1
[pip3] pytorch==2.1.0a0+git3e6bde0
[pip3] torchaudio==0.13.0.dev20220912+cpu
[pip3] torchvision==0.15.0a0+511924c
[conda] blas 1.0 mkl
[conda] cpuonly 2.0 0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-include 2022.1.0 h06a4308_224
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.23.1 py39h6c91a56_0
[conda] numpy-base 1.23.1 py39ha15fc14_0
[conda] pytorch 2.1.0a0+git3e6bde0 dev_0
[conda] pytorch-mutex 1.0 cpu pytorch
[conda] torchaudio 0.13.0.dev20220912+cpu pypi_0 pypi
[conda] torchvision 0.15.0a0+511924c pypi_0 pypi

cc @ezyang @BowenBao

Metadata
Labels

module: dynamic shapes, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
