[jit] Fuse tensor-scalar ops when scalar is constant #10511
zou3519 wants to merge 8 commits into pytorch:master
Conversation
Force-pushed 2267979 to a235f9d
zdevito
left a comment
This looks good. I have some small comments. I am surprised that we didn't need to modify the fusion_compiler; did we already have code to emit constants?
@zdevito Yes, there was already code to emit constants inlined into the body of the FusionGroup. In particular, the graph_fuser inlines the
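For readers unfamiliar with the mechanism being discussed: emitting a constant means the scalar appears as a literal in the generated kernel source, rather than being passed in as a kernel argument. A minimal standalone sketch of that idea (the function and naming here are illustrative, not the actual fusion_compiler API):

```cpp
#include <sstream>
#include <string>

// Illustrative sketch: build the body of a fused elementwise kernel where a
// constant scalar operand is inlined as a literal in the source text instead
// of being threaded through as a kernel argument. Hypothetical names, not
// the real PyTorch fuser code.
std::string emitFusedAddConstant(const std::string& input, double scalar) {
    std::ostringstream body;
    // The constant is baked directly into the kernel source.
    body << "out[i] = " << input << "[i] + " << scalar << ";";
    return body.str();
}
```

Because the constant is part of the kernel text, no extra runtime plumbing is needed to pass it in, which is why the fusion_compiler did not need changes.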
Force-pushed 389d846 to 0616b4a
apaszke
left a comment
Mostly LGTM, but I'd like to remove the hacky scalar checks that don't even try to see which overloads we handle. Please use the matching syntax, or we'll end up with a lot of bugs again.
facebook-github-bot
left a comment
zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
apaszke
left a comment
Looks great now! Some minor comments.
facebook-github-bot
left a comment
zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed 221c95f to c79b24a
- Use WithInsertPoint instead of insertAfter
- Make the compatible devices logic more explicit by adding a Device struct and a DeviceType enum. The possible Devices are `Unknown | AnyDevice | CPU | CUDA i`
- Use ->matches instead of hacky allowing-numbers-in-type-checks
- Rewrite bool compatibleDevices(Node * consumer, Value * producer) to be more readable
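The device-compatibility change described above can be pictured roughly like this. This is a sketch reconstructed from the changelog text (`Unknown | AnyDevice | CPU | CUDA i`), not the actual PyTorch source; the member names and the compatibility rule are assumptions:

```cpp
// Sketch of the Device struct / DeviceType enum mentioned in the changelog.
// "AnyDevice" models values (e.g. scalar constants) that can live anywhere;
// "CUDA i" carries a device index. Illustrative only.
enum class DeviceType { Unknown, AnyDevice, CPU, CUDA };

struct Device {
    DeviceType type;
    int index;  // only meaningful for CUDA devices ("CUDA i")

    Device(DeviceType t = DeviceType::Unknown, int i = -1)
        : type(t), index(i) {}

    // Two devices are plausibly fusion-compatible if either side is
    // AnyDevice, or both refer to the same concrete device.
    bool compatibleWith(const Device& other) const {
        if (type == DeviceType::AnyDevice || other.type == DeviceType::AnyDevice)
            return true;
        return type == other.type && index == other.index;
    }
};
```

Making the device a small value type like this lets compatibleDevices read as a single explicit rule instead of scattered special cases.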
Force-pushed c79b24a to 9d943a6
facebook-github-bot
left a comment
zou3519 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This is on the way to resolving #9940.
Fixes #10501
This PR modifies graph fuser to fuse operations that have constant
scalar arguments. These constant scalar arguments are directly inlined
into the kernel body.
The context for this is that LSTM backward (in particular, sigmoid
backward) has many add(x, 1.) operations. This PR should be sufficient for
LSTM backward to get fused by the graph fuser.
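To see where those constant-scalar adds come from: the derivative of sigmoid is y * (1 - y) for y = sigmoid(x), so every sigmoid in the LSTM backward pass produces a tensor-scalar op with the literal constant 1. A scalar sketch of the math (plain C++, not JIT-generated code):

```cpp
#include <cmath>

// Forward: y = 1 / (1 + exp(-x))
double sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }

// Backward: grad_input = grad_output * y * (1 - y).
// The (1 - y) term is exactly the kind of tensor-scalar op with a constant
// operand (the literal 1.) that this PR lets the graph fuser inline into
// the fused kernel body.
double sigmoid_backward(double grad_output, double y) {
    return grad_output * y * (1.0 - y);
}
```

Before this change, each such constant kept the surrounding op out of the FusionGroup; with constants inlined, the whole elementwise chain can fuse.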
cc @apaszke @zdevito