[jit] Add @script decorator for Modules#22328

Closed
driazati wants to merge 5 commits into master from driazati/decclass

Conversation

@driazati
Contributor

This lets you put `@torch.jit.script` on an `nn.Module`, which then will
wrap any instantiations of the module with a `script` call. This lets
users make classes that will always be used as `ScriptModule`s without
having to call `script()` every time.
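Mechanically, the idea can be sketched in plain Python. This is an illustrative stand-in only, not the PR's actual implementation: `ScriptedModule`, `script`, and `script_class` here are invented names that mimic the wrapping behavior the description talks about.

```python
class ScriptedModule:
    """Stand-in for the wrapper that torch.jit.script would produce."""
    def __init__(self, module):
        self.module = module

def script(module):
    """Stand-in for torch.jit.script."""
    return ScriptedModule(module)

def script_class(cls):
    """Class decorator: every instantiation yields a scripted wrapper."""
    def __new__(new_cls, *args, **kwargs):
        instance = object.__new__(new_cls)
        # Because we return a ScriptedModule rather than an instance of
        # `cls`, Python skips the automatic __init__ call, so it has to
        # be invoked explicitly here.
        instance.__init__(*args, **kwargs)
        return script(instance)
    cls.__new__ = __new__
    return cls

@script_class
class MyModule:
    def __init__(self, n):
        self.n = n

m = MyModule(3)
print(type(m).__name__)  # ScriptedModule
```

The explicit `__init__` call inside `__new__` is the detail the review discussion below centers on.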
@pytorchbot pytorchbot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jun 28, 2019
@driazati driazati requested review from suo and zdevito and removed request for zdevito June 28, 2019 18:03
Contributor

@zdevito zdevito left a comment

Not sure this approach is legit. @suo has argued that we should not, in fact, have a decorator and instead rename script to compile.


```python
# Since we don't return the module itself from this wrapper, its `__init__`
# is never automatically called, so we have to do it explicitly here
nn_module.__init__(*args, **kwargs)
```
Contributor

I do not think it is correct to make `__new__` call `__init__`, and it will break assumptions in other parts of Python (e.g. things that try to copy the object by calling `__new__`).
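A small self-contained illustration of this concern (not code from the PR): `copy.copy` reconstructs objects via `cls.__new__(cls)` with no constructor arguments, so a `__new__` that forwards to `__init__` breaks copying.

```python
import copy

class Weird:
    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)
        instance.__init__(*args, **kwargs)  # the pattern in question
        return instance

    def __init__(self, value):
        self.value = value

w = Weird(1)      # normal construction still works (__init__ runs twice)
try:
    copy.copy(w)  # copy calls Weird.__new__(Weird) with no arguments...
except TypeError as exc:
    print(exc)    # ...so the explicit __init__() call fails
```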

Contributor Author

The point of the decorator is that it compiles your module automatically, which it can't do if the module hasn't been initialized. Does it make sense to break those cases, since they wouldn't work with `ScriptModule`s anyway?

Contributor

Maybe? What do people typically do in class decorators? It seems like there is at least some precedent, since Python doesn't call `__init__` at all if `__new__` returns an object of another type.

Contributor Author

Class decorators seem pretty rare, and the associated PEP is also light on details. Most usages seem to return some kind of wrapper class type from the decorator, but returning anything other than the class type itself breaks `super(T, self)` calls (they have to be changed to `super(type(self), self)`), which makes that a non-starter for us.
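For illustration (an assumed minimal example, not from the PR): if a decorator rebinds the class name to a non-subclass wrapper, explicit `super(T, self)` calls inside the original class stop working, because `T` now names the wrapper.

```python
def wrapping_decorator(cls):
    class Wrapper:
        """Delegating wrapper; deliberately NOT a subclass of cls."""
        def __init__(self, *args, **kwargs):
            self.inner = cls(*args, **kwargs)
    return Wrapper

class Base:
    def __init__(self):
        self.tag = "base"

@wrapping_decorator
class Child(Base):
    def __init__(self):
        # `Child` now refers to Wrapper, and `self` is an instance of the
        # original class, so this raises:
        # TypeError: super(type, obj): obj must be an instance or subtype of type
        super(Child, self).__init__()

try:
    Child()
except TypeError as exc:
    print(exc)
```

Note that Python 3's zero-argument `super()` form avoids this particular failure, since it binds `__class__` to the class being defined rather than looking up the rebound name.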

Collaborator

Why not wrap the `__init__` instead?
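One constraint worth noting here (an illustration, not from the thread): a wrapped `__init__` can run extra code after construction, but it cannot substitute a different object for the instance, since Python requires `__init__` to return `None`.

```python
class Mod:
    def __init__(self):
        # Trying to hand back a compiled replacement from __init__ fails;
        # Python raises TypeError because __init__ must return None.
        return "compiled stand-in"

try:
    Mod()
except TypeError as exc:
    print(exc)
```

So the object swap has to happen in `__new__` or in a metaclass `__call__`, not in `__init__` itself.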

@driazati
Contributor Author

I think it's good since it makes `@torch.jit.script` more consistent: right now you can add it as a decorator on types that inherit from `object` but not on `nn.Module`s, and for a function you can either add `@torch.jit.script` or call `torch.jit.script()` yourself, but this is different for modules. I agree that it's becoming a pretty overloaded name in that it does too many things through one entry point, and it's not exactly clear or easily describable how it can be used. Maybe we should rename the decorator used for user-defined classes to something like `@torch.jit.class`.

@suo
Member

suo commented Jul 11, 2019

If the problem you're trying to solve is to avoid having to call the recursive scripting func all the time, why can't we just recommend something like

```python
class _MyMod(torch.nn.Module):
    ...

MyMod = torch.compile(_MyMod)
```

I don't think it's a big deal either way, but we should minimize the number of ways to do the same thing I think.

@suo
Member

suo commented Jul 11, 2019

> I think it's good since it makes `@torch.jit.script` more consistent, namely that you can add it as a decorator on types that inherit from `object` but not `nn.Module`, and for a function you can add `@torch.jit.script` or call `torch.jit.script()` yourself, but this is different for modules.

That's a pretty good argument, I'm fine with it either way I think.

@driazati
Contributor Author

> If the problem you're trying to solve is to avoid having to call the recursive scripting func all the time, why can't we just recommend something like
>
> ```python
> class _MyMod(torch.nn.Module):
>     ...
>
> MyMod = torch.compile(_MyMod)
> ```
>
> I don't think it's a big deal either way, but we should minimize the number of ways to do the same thing I think.

Having a function like that is the same as having a decorator (it just calls the function with the decorated type under the hood), similar to how you can already use `@script` on functions or pass a Python function to `script()`.
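A generic illustration of that equivalence (using a toy `script` stand-in, not the torch one): decorator syntax is just sugar for calling the function on the definition.

```python
def script(fn):
    fn.compiled = True  # stand-in for the real compile step
    return fn

@script
def f(x):
    return x + 1

def g(x):
    return x + 1

g = script(g)  # identical effect to the decorator form

assert f.compiled and g.compiled
assert f(1) == g(1) == 2
```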

@ssnl
Collaborator

ssnl commented Sep 8, 2019

I would really really hope that this can be updated and get through some day!

@ssnl
Collaborator

ssnl commented Sep 8, 2019

For me, the following simple patch works:

```python
import torch
import torch.nn as nn


def compile(kind='script', *args, **kwargs):
    if isinstance(kind, type) and issubclass(kind, nn.Module):
        raise RuntimeError('Use @torch.jit.compile(kind, ...) as decorator')

    if callable(kind):
        compiler = kind
    elif kind == 'script':
        compiler = torch.jit.script
    else:
        assert kind == 'trace'
        compiler = torch.jit.trace

    class CompiledModuleMeta(type(nn.Module)):
        def __call__(cls, *ctor_args, **ctor_kwargs):
            # Compile the freshly constructed module with the selected compiler
            return compiler(super(CompiledModuleMeta, cls).__call__(*ctor_args, **ctor_kwargs), *args, **kwargs)

    def decorator(cls):
        return CompiledModuleMeta(cls.__name__, (cls,), {})

    return decorator


assert not hasattr(torch.jit, 'compile')
torch.jit.compile = compile

@torch.jit.compile()
class MLP(nn.Module):
    r"""my mlp"""
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(True), nn.Linear(4, 3))

    def forward(self, x):
        return self.model(x)
```

@driazati
Contributor Author

@ssnl We’re going with the contract that everything under `script` is compiled (or at least looked at by the compiler), but that doesn’t hold for modules, so this API isn’t going to be landed.

@driazati driazati closed this Sep 11, 2019
@ssnl
Collaborator

ssnl commented Sep 11, 2019

@driazati That is reasonable, but it results in a really ugly API. Could another API with a different name be added?

@facebook-github-bot facebook-github-bot deleted the driazati/decclass branch July 13, 2020 17:54