[jit] Add weak script modules#12682
[jit] Add weak script modules#12682driazati wants to merge 21 commits intopytorch:masterfrom driazati:weak_mod
Conversation
zdevito
left a comment
There was a problem hiding this comment.
Good progress! I added a bunch of comments. Largest thing I see is the need for careful tests that will hit all the corner cases that can happen.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
zdevito
left a comment
There was a problem hiding this comment.
I am confused by the logic for maintaining information about weak_module classes and instances. I don't see why we need WeakModuleInstance to exist and WeakModule is awkward as a named tuple, considering that it is being mutated by reallocating it. It would help to organize this functionality into a clear API of classes and functions. As one example, consider: weak_modules.get(value.__class__). This line appears in setattr of the proxy class, which has multiple responsibilities. But what it is really saying is is_weak_script_annotated(value.class). Even though that function is a one liner, it is a lot easier to understand the code if it is refactored this way. Similarly for the rest of the logic here: it helps to decompose things into small functions with meaningful names and to keep the actual data structures (in this case, the weak dict tables), hidden behind this API. That way the implementation details do not leak into client code.
| return x + self.weak_submodule(x) + self.strong_submodule(x) | ||
|
|
||
| class Strong(torch.jit.ScriptModule): | ||
| __constants__ = ['constant'] |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| @torch.jit.weak_script_method | ||
| def forward(self, x): | ||
| return x * x |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| Symbol::fromQualString(py::str(builtin_name)), c10::nullopt); | ||
| } | ||
|
|
||
| auto compiled_mod = |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
|
|
||
| def _try_get_weak_module(mod): | ||
| weak_mods_of_class = weak_modules.get(mod.__class__) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| def _try_get_weak_module(mod): | ||
| weak_mods_of_class = weak_modules.get(mod.__class__) | ||
| if weak_mods_of_class is not None: |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| return cls | ||
|
|
||
|
|
||
| def _get_weak_stubs(mod): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| BatchTensor = torch._C._jit.BatchTensor | ||
| compiled_weak_fns = weakref.WeakKeyDictionary() | ||
| weak_script_methods = weakref.WeakKeyDictionary() | ||
| weak_modules = weakref.WeakKeyDictionary() |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| def _try_get_weak_module(mod): | ||
| weak_mods_of_class = weak_modules.get(mod.__class__) | ||
| if weak_mods_of_class is not None: | ||
| weak_mod = weak_mods_of_class.get(mod) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| if weak_module.instances is None: | ||
| # Generate stubs and add to entry | ||
| stubs = _get_weak_stubs(mod) | ||
| weak_modules[cls] = WeakModule(weakref.WeakKeyDictionary(), stubs) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
|
|
||
| def _try_get_weak_module(mod): | ||
| weak_mods_of_class = weak_modules.get(mod.__class__) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
zdevito
left a comment
There was a problem hiding this comment.
This looks much better! I had one minor thing about caching results in _make_strong. Other than that I think this is ready to go!
| """ | ||
| Converts a weak module into a subclass of ScriptModule | ||
| """ | ||
| stubs = _weak_types.get(type(mod))["method_stubs"] |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
facebook-github-bot
left a comment
There was a problem hiding this comment.
driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Depends on #12682 ([stacked diff](https://github.com/driazati/pytorch/compare/weak_mod...driazati:mod_conv1)) * Adds tests for weak module conversion that creates a `ScriptModule` that uses the weak module and checks its graph * Adds `torch._jit_internal.weak_module` tags to modules that already work * `Sigmoid` * `Tanh` * `Hardshrink` * `PReLU` * `Softsign` * `Tanhshrink` * `PairwiseDistance` Pull Request resolved: #12966 Differential Revision: D10559557 Pulled By: driazati fbshipit-source-id: dc4bea3aa744b3c44d4fa7dceefd97e951f824d0
Summary: Adds support for weak script modules created that get compiled to `ScriptModule`s once added as a submodule of a `ScriptModule`: ```python weak_module class Test(torch.nn.Module): ... weak_script_method def forward(self, x): ... ``` Pull Request resolved: pytorch#12682 Differential Revision: D10458626 Pulled By: driazati fbshipit-source-id: 10ae23cb83cdafc4646cee58f399e14b2e60acd4
Summary: Depends on pytorch#12682 ([stacked diff](https://github.com/driazati/pytorch/compare/weak_mod...driazati:mod_conv1)) * Adds tests for weak module conversion that creates a `ScriptModule` that uses the weak module and checks its graph * Adds `torch._jit_internal.weak_module` tags to modules that already work * `Sigmoid` * `Tanh` * `Hardshrink` * `PReLU` * `Softsign` * `Tanhshrink` * `PairwiseDistance` Pull Request resolved: pytorch#12966 Differential Revision: D10559557 Pulled By: driazati fbshipit-source-id: dc4bea3aa744b3c44d4fa7dceefd97e951f824d0
Adds support for weak script modules created that get compiled to
ScriptModules once added as a submodule of aScriptModule: