[model transform] tuple to arglist jit pass #36093
ajyu wants to merge 1 commit into pytorch:master
This pull request was exported from Phabricator. Differential Revision: D20313673
Summary:
Pull Request resolved: pytorch#36093

Unwrap any tuples (including NamedTuples) in the module forward function input list to be arglist.
1. Supports multiple tuple inputs, and traces their use through CallMethods and TupleIndex
2. Does not unwrap inner use of other tuples that did not show up in the original toplevel graph inputs

We work from the ScriptModule level instead of the Graph level because:
1. If the ScriptModule was previously called with the original set of inputs, the GraphExecutor caches the ExecutionPlan (specifically, ArgumentSpecCreator is derived from the Graph and type checks the inputs passed in)
2. Since we are changing this graph's inputs, we clone the module and clear the GraphExecutor.

Since we work from ScriptModule level, we cannot take advantage of jit level syntactic sugar like run_pass(), so I jit exposed this as a cpp extension. Let me know if there are other ideas about this.

Test Plan:
buck test caffe2/torch/fb/model_transform:signature_translation_test

Todo: Verify use in bento

Untranslated graph:
```
graph(%self : __torch__.test_jit.SparseNNWrapper,
      %inputs.1 : NamedTuple(dense : Tensor, sparse : Dict(int, Tensor))):
  %2 : __torch__.test_jit.SparseNN = prim::GetAttr[name="main_module"](%self)
  %4 : Tensor = prim::CallMethod[name="forward"](%2, %inputs.1) # /data/users/ansha/fbsource/fbcode/buck-out/dev/gen/caffe2/test/jit#binary,link-tree/test_jit.py:12141:23
  return (%4)
```

Translated graph:
```
graph(%self : __torch__.test_jit.___torch_mangle_1.SparseNNWrapper,
      %inputs.1_0 : Tensor,
      %inputs.1_1 : Dict(int, Tensor)):
  %2 : __torch__.test_jit.___torch_mangle_2.SparseNN = prim::GetAttr[name="main_module"](%self)
  %3 : Tensor = prim::CallMethod[name="forward"](%2, %inputs.1_0, %inputs.1_1) # /data/users/ansha/fbsource/fbcode/buck-out/dev/gen/caffe2/test/jit#binary,link-tree/test_jit.py:12141:23
  return (%3)
```

Reviewed By: houseroad

Differential Revision: D20313673

fbshipit-source-id: cd99a7bf9181095a07d97a574993a91ca21da099
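For readers who want to see the signature rewrite in isolation, here is a minimal pure-Python sketch of the input flattening shown in the two graphs above. It is not the pass itself and does not touch TorchScript: `unwrap_tuple_inputs` and the toy type model (strings for leaf types, Python tuples for tuple types) are hypothetical names invented for illustration. The `name_i` naming simply mirrors `%inputs.1_0` / `%inputs.1_1` in the translated graph. The sketch unwraps one level only; the summary does not say how the real pass treats tuples nested inside an input tuple, so that choice is an assumption.

```python
from typing import List, Tuple, Union

# Toy type model (hypothetical, not TorchScript's own): a leaf type is a
# string, a tuple type is a Python tuple of element types.
ToyType = Union[str, tuple]
Param = Tuple[str, ToyType]


def unwrap_tuple_inputs(params: List[Param]) -> List[Param]:
    """Replace each top-level tuple-typed parameter with one parameter per element.

    Assumption of this sketch: only one level is unwrapped, so an element
    that is itself a tuple is kept as a single parameter.
    """
    flat: List[Param] = []
    for name, typ in params:
        if isinstance(typ, tuple):
            # Element i of `name` becomes a new input called `name_i`.
            flat.extend((f"{name}_{i}", elem) for i, elem in enumerate(typ))
        else:
            flat.append((name, typ))
    return flat


# Inputs of the untranslated graph, with the NamedTuple modeled as a tuple.
before = [
    ("self", "SparseNNWrapper"),
    ("inputs.1", ("Tensor", "Dict(int, Tensor)")),
]
after = unwrap_tuple_inputs(before)
for name, typ in after:
    print(f"%{name} : {typ}")
```

The real pass additionally has to rewrite every use of the original tuple value (the CallMethod and TupleIndex nodes mentioned in the summary), which this sketch leaves out.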
This pull request has been merged in aac36a8.