
[TorchScript] Scripting an interface-implementing submodule fails after scripting its wrapper module #140468

@davidberard98

Description


🐛 Describe the bug

Minimal repro below:

  • The wrapper module calls an implementation submodule, and the wrapper annotates that submodule with an interface type (impl: MyInterface).
  • First, we script the wrapper module with torch.jit.script.
  • Then, we script the submodule on its own.

Because scripting the wrapper module sees the submodule through its interface type, it compiles only the interface's methods (forward) and ignores methods that are not part of the interface, such as the exported add_two. The resulting type is then cached. When we later script the submodule on its own, add_two is discovered, but the cached jit_type associated with MyImplementation does not have that method, so scripting fails.
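The ordering dependence described above can be illustrated with a small, self-contained Python model. This is a toy sketch under stated assumptions, not PyTorch's implementation: ToyTypeStore, ToyCompiledType, and toy_script are hypothetical names, and the real type store's reuse check is more involved than a per-class dictionary. It only shows why "first compilation decides the method set, later requests reuse it" produces the same error message.

```python
# Toy model (hypothetical, NOT PyTorch code) of the failure described above:
# a store caches one compiled "type" per Python class, and whichever request
# compiles that class first decides which methods the cached type contains.


class ToyCompiledType:
    """Stand-in for a jit_type: only records which methods were compiled."""

    def __init__(self, methods):
        self.methods = set(methods)

    def get_method(self, name):
        if name not in self.methods:
            raise RuntimeError(f"Method '{name}' is not defined.")
        return name


class ToyTypeStore:
    """Caches one compiled type per class (a simplified, assumed model)."""

    def __init__(self):
        self._cache = {}

    def get_or_compile(self, cls, methods_to_compile):
        # Only the FIRST request for a class compiles anything; later
        # requests reuse the cached type even if they need more methods.
        if cls not in self._cache:
            self._cache[cls] = ToyCompiledType(methods_to_compile)
        return self._cache[cls]


def toy_script(store, cls, methods_to_compile):
    compiled = store.get_or_compile(cls, methods_to_compile)
    # Analogous to the final step in the traceback: look up every method
    # this request expects the compiled type to have.
    for name in methods_to_compile:
        compiled.get_method(name)
    return compiled


class MyImplementation:  # placeholder class, used only as a cache key
    pass


store = ToyTypeStore()
# 1) Scripting the wrapper: the submodule is seen through the interface,
#    so only the interface method is compiled and cached.
toy_script(store, MyImplementation, ["forward"])
# 2) Scripting the submodule directly: the exported method is requested too,
#    but the cached type only contains "forward".
try:
    toy_script(store, MyImplementation, ["forward", "add_two"])
except RuntimeError as e:
    print(e)  # prints: Method 'add_two' is not defined.
```

In this toy model, a fresh ToyTypeStore that receives the ["forward", "add_two"] request first compiles both methods and succeeds, which mirrors how the failure depends on the order in which the two modules are scripted.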

import torch

@torch.jit.interface
class MyInterface(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


class MyImplementation(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * x

    @torch.jit.export
    def add_two(self, x: torch.Tensor) -> torch.Tensor:
        return x + 2


class MyWrapper(torch.nn.Module):
    impl: MyInterface

    def __init__(self):
        super().__init__()
        self.impl = MyImplementation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.impl(x)


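mod = MyWrapper()
# Scripting the wrapper compiles impl only through MyInterface (forward only).
mod_s = torch.jit.script(mod)
# Scripting the submodule afterwards fails: Method 'add_two' is not defined.
mod.impl = torch.jit.script(mod.impl)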

Error:

  File "/data/users/dberard/scripts/interface_extra.py", line 31, in <module>
    mod.impl = torch.jit.script(mod.impl)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/dberard/pytorch/torch/jit/_script.py", line 1429, in script
    ret = _script_impl(
          ^^^^^^^^^^^^^
  File "/data/users/dberard/pytorch/torch/jit/_script.py", line 1147, in _script_impl
    return torch.jit._recursive.create_script_module(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/dberard/pytorch/torch/jit/_recursive.py", line 557, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/dberard/pytorch/torch/jit/_recursive.py", line 679, in create_script_module_impl
    script_method = cpp_module._get_method(name)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Method 'add_two' is not defined.

Versions

main branch, CPU build

cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel

Labels: oncall: jit