[JIT] Plumb type annotations through script compilation#9405
[JIT] Plumb type annotations through script compilation#9405jamesr66a wants to merge 1 commit intopytorch:masterfrom
Conversation
ce55be3 to
f88188c
Compare
f88188c to
8a01c6b
Compare
facebook-github-bot
left a comment
There was a problem hiding this comment.
@jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
zdevito
left a comment
There was a problem hiding this comment.
Cool -- I have a bunch of organization comments to keep the complexity of this change down. Right now it adds too many different pathways to the compiler and I have some suggestions about how to reduce that.
| DefAndTypes(Def def, std::vector<TypePtr> arg_types, TypePtr return_type) | ||
| : def(std::move(def)), arg_types(arg_types), return_type(return_type) {} | ||
| Def def; | ||
| std::vector<TypePtr> arg_types; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| bool allow_varargs; | ||
| }; | ||
|
|
||
| struct DefAndTypes { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| const std::vector<Resolver>& resolvers, /* determines how we handle free variables in each definition*/ | ||
| std::shared_ptr<SugaredValue> self /* if non-null, the first argument to each def, is bound to this value */ | ||
| std::shared_ptr<SugaredValue> self, /* if non-null, the first argument to each def, is bound to this value */ | ||
| bool pure_func=false /* If true, expect a single def which will become the 'forward' method on the module */ |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| mod = ScriptModule() | ||
| rcb = createResolutionCallback(_frames_up + 1) | ||
| ast = get_jit_ast(fn) | ||
| arg_types, ret_type = annotations.get_signature(fn, ast.num_params(), 0) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| graph = _script_graph(fn, _frames_up=_frames_up + 1) | ||
| mod = ScriptModule() | ||
| mod._create_method_from_graph('forward', graph) | ||
| _script_pure_function(mod, fn, _frames_up=_frames_up + 1) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| auto& name = (*it).ident().name(); | ||
| arguments.push_back({name, DynamicType::get()}); | ||
| TypePtr arg_type = DynamicType::get(); | ||
| if (def_and_types.arg_types.size()) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| ast = get_jit_ast(fn) | ||
| arg_types, return_type = annotations.get_signature(fn, ast.num_params() - 1, 0) | ||
| # Dumb handling for `self` | ||
| if len(arg_types) == ast.num_params(): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| torch._C.ScriptModule.__init__(self) | ||
| original_init(self, *args, **kwargs) | ||
| asts = [m.ast for m in methods] | ||
| defs = [torch._C.DefAndTypes(m.ast, m.arg_types, m.return_type) for m in methods] |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| for(Def def : definitions) { | ||
| const std::string& name = def.name().name(); | ||
| for(DefAndTypes def : definitions) { | ||
| const std::string& name = pure_func ? "forward" : def.def.name().name(); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| std::vector<TypePtr> flattened_return_types; | ||
| if (def_and_types.return_type) { | ||
| if (def_and_types.return_type->kind() == TypeKind::TupleType) { | ||
| const auto &tuple_type_elmts = def_and_types.return_type->cast<TupleType>()->elements(); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| wrap_list(r, std::move(body))); | ||
| })); | ||
| })) | ||
| .def("num_params", [](Def& self) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| FunctionTable table; | ||
| JIT_ASSERT(definitions.size() == resolvers.size()); | ||
| if (pure_func) { | ||
| JIT_ASSERT(definitions.size() == 1); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| bool allow_varargs; | ||
| }; | ||
|
|
||
| struct DefAndTypes { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
|
|
||
| def _script_graph(fn, _frames_up=0): | ||
| def _compile_fn(fn, _frames_up=0): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| _jit_script_compile_pure_fn(mod, torch._C.DefAndTypes(ast, arg_types, ret_type), rcb) | ||
|
|
||
|
|
||
| def _script_graph(fn, _frames_up=0): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| m.def("_jit_script_compile", [](Def def, ResolutionCallback rcb) { | ||
| return compileFunction(def, pythonResolver(rcb)); | ||
| m.def("_jit_script_compile_pure_fn", [](Module &m, DefAndTypes def, ResolutionCallback rcb) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
Superseded by #9547 |
Summary: Supersedes pytorch#9405 Pull Request resolved: pytorch#9547 Reviewed By: zdevito Differential Revision: D8900327 Pulled By: jamesr66a fbshipit-source-id: a00a94615af4fbaec98ee3ede0cb54bcfd9108dd
Summary: Supersedes pytorch#9405 Pull Request resolved: pytorch#9547 Reviewed By: zdevito Differential Revision: D8900327 Pulled By: jamesr66a fbshipit-source-id: a00a94615af4fbaec98ee3ede0cb54bcfd9108dd
No description provided.