
[model transform] tuple to arglist jit pass#36093

Closed
ajyu wants to merge 1 commit into pytorch:master from ajyu:export-D20313673

Conversation

@ajyu (Contributor) commented Apr 6, 2020

Summary:
Unwrap any tuples (including NamedTuples) in the module forward
function's input list into an argument list.

  1. Supports multiple tuple inputs, and traces their use through CallMethod and
     TupleIndex nodes
  2. Does not unwrap inner uses of other tuples that did not appear in the
     original top-level graph inputs
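The effect of the pass can be illustrated in plain Python, without torch. This is a hedged analogy, not the pass itself: the type names and functions below (`SparseInputs`, `forward_tuple`, `forward_arglist`) are hypothetical stand-ins for the NamedTuple parameter and the before/after forward signatures.

```python
from typing import Dict, NamedTuple

class SparseInputs(NamedTuple):
    dense: float
    sparse: Dict[int, float]

def forward_tuple(inputs: SparseInputs) -> float:
    # Original signature: a single NamedTuple argument; each field read
    # corresponds to a TupleIndex node in the scripted graph.
    return inputs.dense + sum(inputs.sparse.values())

def forward_arglist(inputs_0: float, inputs_1: Dict[int, float]) -> float:
    # Translated signature: each top-level tuple field becomes its own
    # positional argument, mirroring %inputs.1_0 / %inputs.1_1 below.
    return inputs_0 + sum(inputs_1.values())

inputs = SparseInputs(dense=1.0, sparse={0: 2.0, 1: 3.0})
# Both forms compute the same result; callers unpack the tuple themselves.
assert forward_tuple(inputs) == forward_arglist(*inputs) == 6.0
```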

We work at the ScriptModule level instead of the Graph level because:

  1. If the ScriptModule was previously called with the original set of inputs, the GraphExecutor caches the ExecutionPlan (specifically, the ArgumentSpecCreator is derived from the Graph and type-checks the inputs passed in).
  2. Since we are changing the graph's inputs, we clone the module and clear the GraphExecutor.

Since we work at the ScriptModule level, we cannot take advantage of JIT-level syntactic sugar like run_pass(), so I exposed this as a cpp extension. Let me know if there are other ideas about this.
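Why the cached ExecutionPlan forces a clone can be sketched with a toy executor (a deliberately simplified analogy, not the real GraphExecutor): it derives a "plan" from the first call's inputs, so a signature change invalidates the cache.

```python
class TinyExecutor:
    """Toy stand-in for a GraphExecutor: caches a plan keyed on the arity
    of the first call, the way an ExecutionPlan/ArgumentSpecCreator is
    derived from the Graph's inputs."""
    def __init__(self, fn):
        self.fn = fn
        self.plan_arity = None  # cached "ExecutionPlan"

    def __call__(self, *args):
        if self.plan_arity is None:
            self.plan_arity = len(args)  # plan built from first call's inputs
        if self.plan_arity != len(args):
            raise TypeError("cached plan expects a different input signature")
        return self.fn(*args)

add = TinyExecutor(lambda *xs: sum(xs))
assert add(1, 2) == 3  # first call caches a 2-argument plan

stale_plan_detected = False
try:
    add(1, 2, 3)       # new (unwrapped) signature hits the stale plan
except TypeError:
    stale_plan_detected = True
assert stale_plan_detected

# A fresh executor, analogous to cloning the module and clearing the
# GraphExecutor, accepts the new signature cleanly.
assert TinyExecutor(lambda *xs: sum(xs))(1, 2, 3) == 6
```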

Test Plan:
buck test caffe2/torch/fb/model_transform:signature_translation_test
TODO: verify use in bento

Untranslated graph:

```
graph(%self : __torch__.test_jit.SparseNNWrapper,
      %inputs.1 : NamedTuple(dense : Tensor, sparse : Dict(int, Tensor))):
  %2 : __torch__.test_jit.SparseNN = prim::GetAttr[name="main_module"](%self)
  %4 : Tensor = prim::CallMethod[name="forward"](%2, %inputs.1) # /data/users/ansha/fbsource/fbcode/buck-out/dev/gen/caffe2/test/jit#binary,link-tree/test_jit.py:12141:23
  return (%4)
```

Translated graph:

```
graph(%self : __torch__.test_jit.___torch_mangle_1.SparseNNWrapper,
      %inputs.1_0 : Tensor,
      %inputs.1_1 : Dict(int, Tensor)):
  %2 : __torch__.test_jit.___torch_mangle_2.SparseNN = prim::GetAttr[name="main_module"](%self)
  %3 : Tensor = prim::CallMethod[name="forward"](%2, %inputs.1_0, %inputs.1_1) # /data/users/ansha/fbsource/fbcode/buck-out/dev/gen/caffe2/test/jit#binary,link-tree/test_jit.py:12141:23
  return (%3)
```
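The renaming visible in the translated graph (`%inputs.1` becoming `%inputs.1_0`, `%inputs.1_1`) follows a simple scheme that can be sketched in Python. This is a hypothetical illustration of the naming convention only, operating on (name, type) pairs rather than real graph values:

```python
def flatten_inputs(params):
    """Sketch of the pass's input renaming: a top-level parameter whose
    type is a tuple of field types expands to name_0, name_1, ...;
    non-tuple parameters pass through unchanged. Nested tuples inside
    fields are deliberately left alone, matching point 2 of the summary."""
    out = []
    for name, typ in params:
        if isinstance(typ, tuple):
            out.extend((f"{name}_{i}", field) for i, field in enumerate(typ))
        else:
            out.append((name, typ))
    return out

# The NamedTuple input from the untranslated graph, as a (name, fields) pair:
params = [("inputs.1", ("Tensor", "Dict(int, Tensor)"))]
assert flatten_inputs(params) == [
    ("inputs.1_0", "Tensor"),
    ("inputs.1_1", "Dict(int, Tensor)"),
]
```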

Differential Revision: D20313673

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D20313673

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Apr 6, 2020
@dr-ci dr-ci Bot commented Apr 6, 2020

💊 CircleCI build failures summary and remediations

As of commit a5b6d65 (more details on the Dr. CI page):


None of the build failures appear to be your fault 💚


  • 1/1 broken upstream at merge base 14ce500 on Apr 09 from 2:45pm to 5:52pm PDT (10 commits; 477f1c0 - c5662dd)

    Please rebase on the viable/strict branch:

    If your commit is newer than viable/strict, you can try basing on an older, stable commit:

    ```
    git fetch https://github.com/pytorch/pytorch viable/strict
    git rebase --onto FETCH_HEAD $(git merge-base origin/master HEAD)
    ```

    If your commit is older than viable/strict:

    ```
    git fetch https://github.com/pytorch/pytorch viable/strict
    git rebase FETCH_HEAD
    ```

    Check out the recency history of this "viable master" tracking branch.


🚧 1 upstream failure:

These were probably caused by upstream breakages:



Summary:
Pull Request resolved: pytorch#36093

Reviewed By: houseroad

Differential Revision: D20313673

fbshipit-source-id: cd99a7bf9181095a07d97a574993a91ca21da099
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D20313673

@ajyu ajyu force-pushed the export-D20313673 branch from a48cd5f to a5b6d65 Compare April 9, 2020 22:40
@houseroad (Member) left a comment

Failed test is unrelated

@facebook-github-bot (Contributor)

This pull request has been merged in aac36a8.

ashishfarmer pushed a commit to ashishfarmer/pytorch that referenced this pull request Apr 13, 2020
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026

Labels

fb-exported Merged oncall: jit Add this issue/PR to JIT oncall triage queue


4 participants