
torch.fx.symbolic_trace() loses module class information #66335

@datumbox

Description


🐛 Bug

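Tracing a model with torch.fx.symbolic_trace() loses the class information of its submodules: in the resulting GraphModule, custom submodules such as SomeBlock below come back as plain torch.nn.Module instances. This is a problem for FX-based tools that rely on class information. See pytorch/vision#4549 for a prototype example.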

To Reproduce

from torch import fx, nn


class SomeBlock(nn.Sequential):
    def __init__(self, in_chan, out_chan):
        # BatchNorm2d must match the conv's output channels, otherwise the model fails at runtime
        super().__init__(nn.Conv2d(in_chan, out_chan, 3), nn.BatchNorm2d(out_chan), nn.ReLU())
        self.myname = self.__class__


class SomeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([SomeBlock(3, 32), SomeBlock(32, 64), SomeBlock(64, 128)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = SomeModel()
gm = fx.symbolic_trace(model)

previous_class = model.layers[0].__class__  # <class '__main__.SomeBlock'>
new_class = gm._modules["layers"]._modules['0'].__class__ # <class 'torch.nn.modules.module.Module'>
assert model.layers[0].__class__ == gm._modules["layers"]._modules['0'].__class__, f"{previous_class} != {new_class}"

Output:

Traceback (most recent call last):
  File "./deleteme.py", line 26, in <module>
    assert model.layers[0].__class__ == gm._modules["layers"]._modules['0'].__class__, f"{previous_class} != {new_class}"
AssertionError: <class '__main__.SomeBlock'> != <class 'torch.nn.modules.module.Module'>
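The assertion fails because symbolic_trace() traces through non-leaf modules such as SomeBlock: only the leaf modules (here Conv2d, BatchNorm2d and ReLU) become call_module nodes, and the GraphModule recreates the intermediate path (layers, layers.0) with generic containers. The qualified names of those nodes are preserved, though, so an FX-based tool can look the original classes up in the original model. A minimal sketch, assuming the original model object is still available; enclosing_classes is a hypothetical helper, not part of torch.fx:

```python
from torch import fx, nn


class SomeBlock(nn.Sequential):
    def __init__(self, in_chan, out_chan):
        super().__init__(nn.Conv2d(in_chan, out_chan, 3), nn.BatchNorm2d(out_chan), nn.ReLU())


class SomeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([SomeBlock(3, 32), SomeBlock(32, 64), SomeBlock(64, 128)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = SomeModel()
gm = fx.symbolic_trace(model)

# Map every qualified name in the ORIGINAL model to its class; the traced
# GraphModule no longer carries this information for non-leaf modules.
original_classes = {name: type(m) for name, m in model.named_modules()}


def enclosing_classes(target: str) -> list:
    """Hypothetical helper: classes of all ancestors of a call_module target, outermost first."""
    atoms = target.split(".")
    return [original_classes[".".join(atoms[:i])] for i in range(1, len(atoms))]


for node in gm.graph.nodes:
    if node.op == "call_module":
        # e.g. layers.0.0 ['ModuleList', 'SomeBlock']
        print(node.target, [cls.__name__ for cls in enclosing_classes(node.target)])
```

This only recovers the information on the tool side; the GraphModule itself is unchanged.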

Expected behavior

The two classes should match: the submodule should keep its original class (SomeBlock) after tracing.
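Until that is the case, one possible workaround (a sketch, not an official fix) is to subclass torch.fx.Tracer and override is_leaf_module so that SomeBlock is treated as a leaf. FX then records a single call_module node per block, and the GraphModule holds the original SomeBlock objects. BlockLeafTracer is a hypothetical name. The trade-off is that the layers inside each block no longer appear in the graph, and the enclosing layers container is still rebuilt by FX, so it may not keep its nn.ModuleList class.

```python
import torch
from torch import fx, nn


class SomeBlock(nn.Sequential):
    def __init__(self, in_chan, out_chan):
        super().__init__(nn.Conv2d(in_chan, out_chan, 3), nn.BatchNorm2d(out_chan), nn.ReLU())


class SomeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([SomeBlock(3, 32), SomeBlock(32, 64), SomeBlock(64, 128)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class BlockLeafTracer(fx.Tracer):
    """Hypothetical tracer that keeps SomeBlock instances opaque."""

    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        # Stop tracing at SomeBlock so it is recorded as one call_module node.
        if isinstance(m, SomeBlock):
            return True
        return super().is_leaf_module(m, module_qualified_name)


model = SomeModel()
tracer = BlockLeafTracer()
graph = tracer.trace(model)
# GraphModule copies each call_module target ("layers.0", ...) from tracer.root,
# so the leaf blocks are the original SomeBlock objects.
gm = fx.GraphModule(tracer.root, graph)

print(type(gm.get_submodule("layers.0")).__name__)  # SomeBlock

out = gm(torch.randn(1, 3, 16, 16))
print(tuple(out.shape))  # (1, 128, 10, 10)
```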

Environment

  • PyTorch Version (e.g., 1.0): pytorch-1.11.0.dev20211005
  • OS (e.g., Linux): macOS
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): n/a
  • Python version: 3.8
  • CUDA/cuDNN version: n/a
  • GPU models and configuration: n/a
  • Any other relevant information: n/a

cc @ezyang @SherlockNoMad

Metadata

Assignees: No one assigned
Labels: module: fx, triaged