
[jit] Add weak script modules #12682

Closed

driazati wants to merge 21 commits into pytorch:master from driazati:weak_mod
Conversation

@driazati
Contributor

Adds support for weak script modules, which get compiled to `ScriptModule`s once they are added as submodules of a `ScriptModule`:

```python
@weak_module
class Test(torch.nn.Module):
	...
	@weak_script_method
	def forward(self, x):
		...
```

Contributor

@zdevito zdevito left a comment


Good progress! I added a bunch of comments. The largest thing I see is the need for careful tests that hit all the corner cases that can happen.


Contributor

@zdevito zdevito left a comment


I am confused by the logic for maintaining information about weak_module classes and instances. I don't see why `WeakModuleInstance` needs to exist, and `WeakModule` is awkward as a named tuple, considering that it is mutated by reallocating it.

It would help to organize this functionality into a clear API of classes and functions. As one example, consider `weak_modules.get(value.__class__)`. This line appears in the `__setattr__` of the proxy class, which has multiple responsibilities, but what it is really saying is `is_weak_script_annotated(value.__class__)`. Even though that function is a one-liner, the code is much easier to understand if it is refactored this way. Similarly for the rest of the logic here: it helps to decompose things into small functions with meaningful names and to keep the actual data structures (in this case, the weak dict tables) hidden behind this API. That way the implementation details do not leak into client code.
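The suggested refactor can be sketched as follows; the function name follows the review's example, while the surrounding helpers are hypothetical:

```python
import weakref

# The weak-dict table stays private to this module; call sites go through
# small, intent-revealing functions instead of raw dictionary lookups.

_weak_modules = weakref.WeakKeyDictionary()  # class -> per-class entry


def register_weak_class(cls):
    _weak_modules[cls] = {}


def is_weak_script_annotated(cls):
    """One-liner, but call sites now read as a question, not a lookup."""
    return cls in _weak_modules


# A proxy __setattr__ can then state its intent directly:
def setattr_action(value):
    return "convert" if is_weak_script_annotated(value.__class__) else "store"
```

The point is not the saved keystrokes but that the table itself never appears in client code, so its representation can change without touching call sites.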


Contributor

@zdevito zdevito left a comment


This looks much better! I had one minor thing about caching results in `_make_strong`. Other than that, I think this is ready to go!
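The caching point can be illustrated with a minimal, torch-free sketch (all names hypothetical): memoize the converted result per weak instance so converting the same module twice returns the same object:

```python
import weakref

# Cache the strong (converted) result per weak instance so that repeated
# conversion returns the same object instead of redoing the work. To avoid a
# key -> value -> key cycle in the WeakKeyDictionary, the cached wrapper does
# not hold a strong reference back to the weak instance.

_strong_cache = weakref.WeakKeyDictionary()  # weak instance -> strong result


class StrongWrapper:
    def __init__(self, mod):
        self.source_name = type(mod).__name__  # no strong ref back to mod


def make_strong(mod):
    cached = _strong_cache.get(mod)
    if cached is None:
        cached = StrongWrapper(mod)  # the expensive conversion, done once
        _strong_cache[mod] = cached
    return cached
```

Because the cache is keyed weakly, an entry disappears automatically once the original module instance is garbage-collected.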


Contributor

@facebook-github-bot facebook-github-bot left a comment


driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@driazati driazati closed this Oct 22, 2018
@driazati driazati reopened this Oct 22, 2018
@driazati driazati closed this Oct 22, 2018
@driazati driazati reopened this Oct 22, 2018

facebook-github-bot pushed a commit that referenced this pull request Oct 25, 2018
Summary:
Depends on #12682 ([stacked diff](https://github.com/driazati/pytorch/compare/weak_mod...driazati:mod_conv1))

* Adds tests for weak module conversion that creates a `ScriptModule` that uses the weak module and checks its graph
* Adds `torch._jit_internal.weak_module` tags to modules that already work
  * `Sigmoid`
  * `Tanh`
  * `Hardshrink`
  * `PReLU`
  * `Softsign`
  * `Tanhshrink`
  * `PairwiseDistance`
Pull Request resolved: #12966

Differential Revision: D10559557

Pulled By: driazati

fbshipit-source-id: dc4bea3aa744b3c44d4fa7dceefd97e951f824d0
@ezyang ezyang added the merged label Jun 25, 2019
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Adds support for weak script modules, which get compiled to `ScriptModule`s once they are added as submodules of a `ScriptModule`:

```python
@weak_module
class Test(torch.nn.Module):
	...
	@weak_script_method
	def forward(self, x):
		...
```
Pull Request resolved: pytorch#12682

Differential Revision: D10458626

Pulled By: driazati

fbshipit-source-id: 10ae23cb83cdafc4646cee58f399e14b2e60acd4
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Depends on pytorch#12682 ([stacked diff](https://github.com/driazati/pytorch/compare/weak_mod...driazati:mod_conv1))

* Adds tests for weak module conversion that creates a `ScriptModule` that uses the weak module and checks its graph
* Adds `torch._jit_internal.weak_module` tags to modules that already work
  * `Sigmoid`
  * `Tanh`
  * `Hardshrink`
  * `PReLU`
  * `Softsign`
  * `Tanhshrink`
  * `PairwiseDistance`
Pull Request resolved: pytorch#12966

Differential Revision: D10559557

Pulled By: driazati

fbshipit-source-id: dc4bea3aa744b3c44d4fa7dceefd97e951f824d0
4 participants