🐛 Describe the bug
Repro is below:
- We have a wrapper module that calls an implementation submodule, and the wrapper's attribute holding that submodule is annotated with a `@torch.jit.interface` type.
- First, we script the wrapper module with `torch.jit.script`.
- Then we script the submodule on its own.
Because the first `torch.jit.script` call on the wrapper sees the submodule only through its interface type, it ignores the methods that are not part of the interface (here, `add_two`), and the resulting type is then cached. When we later script the submodule on its own, we see those other methods and fail, because the cached `jit_type` associated with this class does not have them.
```python
import torch


@torch.jit.interface
class MyInterface(torch.nn.Module):
    # The interface only declares forward().
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


class MyImplementation(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * x

    # Extra exported method that is not part of MyInterface.
    @torch.jit.export
    def add_two(self, x: torch.Tensor) -> torch.Tensor:
        return x + 2


class MyWrapper(torch.nn.Module):
    impl: MyInterface  # submodule is typed as the interface

    def __init__(self):
        super().__init__()
        self.impl = MyImplementation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.impl(x)


mod = MyWrapper()
mod_s = torch.jit.script(mod)  # wrapper sees impl only through MyInterface
mod.impl = torch.jit.script(mod.impl)  # fails: Method 'add_two' is not defined.
```
Error:
```
  File "/data/users/dberard/scripts/interface_extra.py", line 31, in <module>
    mod.impl = torch.jit.script(mod.impl)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/dberard/pytorch/torch/jit/_script.py", line 1429, in script
    ret = _script_impl(
          ^^^^^^^^^^^^^
  File "/data/users/dberard/pytorch/torch/jit/_script.py", line 1147, in _script_impl
    return torch.jit._recursive.create_script_module(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/dberard/pytorch/torch/jit/_recursive.py", line 557, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/dberard/pytorch/torch/jit/_recursive.py", line 679, in create_script_module_impl
    script_method = cpp_module._get_method(name)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Method 'add_two' is not defined.
```
Versions
main branch, CPU build
cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel