
[jit] Fuse tensor-scalar ops when scalar is constant #10511

Closed

zou3519 wants to merge 8 commits into pytorch:master from zou3519:pytorch-fusescalar

Conversation

@zou3519 (Contributor) commented Aug 14, 2018

This is on the way to resolving #9940.

Fixes #10501

This PR modifies the graph fuser to fuse operations that have constant scalar arguments. These constant scalar arguments are inlined directly into the kernel body.

The context for this is that LSTM backward (in particular, sigmoid backward) contains many `add(x, 1.)` operations. This PR should be sufficient for LSTM backward to get fused by the graph fuser.
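The mechanism can be sketched as follows. This is a toy illustration of constant-scalar inlining, not the actual fusion compiler; `emit_fused_kernel` and the pseudo-kernel syntax are hypothetical:

```python
# Toy sketch of constant-scalar inlining during fusion (hypothetical names,
# not the real PyTorch fusion_compiler): when an op's scalar argument is a
# compile-time constant, emit it as a literal in the kernel body instead of
# passing it in as a kernel parameter.

def emit_fused_kernel(ops):
    """ops: list of (op_name, scalar_constant_or_None) applied pointwise."""
    lines = ["float v = load(x, i);"]
    for name, scalar in ops:
        if scalar is not None:
            # Constant scalar: inline the literal directly into the kernel.
            lines.append(f"v = {name}(v, {float(scalar)}f);")
        else:
            # A non-constant scalar would have to stay a kernel parameter.
            lines.append(f"v = {name}(v, s);")
    lines.append("store(out, i, v);")
    return "\n".join(lines)

# Sigmoid backward contains an add(x, 1.) step, which can now fuse:
kernel = emit_fused_kernel([("mul", None), ("add", 1.0)])
print(kernel)
```

With the constant baked into the kernel source, the fused group needs no extra runtime argument for the scalar.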

cc @apaszke @zdevito

@zou3519 added the `oncall: jit` (Add this issue/PR to JIT oncall triage queue) label on Aug 14, 2018

@zou3519 force-pushed the pytorch-fusescalar branch from 2267979 to a235f9d on August 14, 2018 19:01
@zdevito (Contributor) left a comment


This looks good. I have some small comments. I am surprised that we didn't need to modify the `fusion_compiler`; did we already have code to emit constants?


@zou3519 (Contributor, Author) commented Aug 15, 2018

@zdevito Yes, there was already code to emit constants inlined into the body of the `FusionGroup`. In particular, the `graph_fuser` inlines the `alpha` argument of `add(Tensor, Tensor, Scalar alpha)` into the `FusionGroup` if it is constant. I took this logic and expanded it to work with all constant number arguments.
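The pre-existing `alpha` special case can be sketched like this (a hypothetical toy, not the real `graph_fuser` code; the PR generalizes this pattern from `alpha` to all constant number arguments):

```python
# Toy sketch (hypothetical names, not the actual graph_fuser): for
# add(Tensor, Tensor, Scalar alpha), i.e. out = a + alpha * b, the fuser
# already inlined alpha into the fused expression when it was a constant.

def inline_add_alpha(alpha_constant):
    """Return the fused-kernel expression for add(a, b, alpha)."""
    if alpha_constant is not None:
        # alpha is a known constant: bake the literal into the expression.
        return f"a + {float(alpha_constant)}f * b"
    # Otherwise alpha must remain a runtime kernel argument.
    return "a + alpha * b"

print(inline_add_alpha(2.0))   # a + 2.0f * b
print(inline_add_alpha(None))  # a + alpha * b
```

The PR takes this same inline-if-constant decision and applies it uniformly to every scalar argument of a fusable op.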

@zou3519 force-pushed the pytorch-fusescalar branch from 389d846 to 0616b4a on August 15, 2018 16:37


@apaszke (Contributor) left a comment


Mostly LGTM, but I'd like to remove the hacky scalar checks that don't even try to see which overloads we handle. Please use the matching syntax, or we'll end up with a lot of bugs again.


@facebook-github-bot left a comment


zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@apaszke (Contributor) left a comment


Looks great now! Some minor comments.


@facebook-github-bot left a comment


zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zou3519 force-pushed the pytorch-fusescalar branch 2 times, most recently from 221c95f to c79b24a on August 17, 2018 17:07
- Use `WithInsertPoint` instead of `insertAfter`
- Make the compatible-devices logic more explicit by adding a `Device`
  struct and a `DeviceType` enum. The possible devices are
  `Unknown | AnyDevice | CPU | CUDA i`
- Use `->matches` instead of hackily allowing numbers in type checks
- Rewrite `bool compatibleDevices(Node* consumer, Value* producer)` to
  be more readable
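The device-compatibility idea from the commit message can be sketched as follows. This is a hypothetical reconstruction in Python of the `Device`/`DeviceType` design (`Unknown | AnyDevice | CPU | CUDA i`); the actual C++ semantics in the PR may differ, and the assumed rules are noted in the comments:

```python
# Hypothetical reconstruction of the Device/DeviceType design from the
# commit message, not the actual C++ code. Assumed rules: AnyDevice values
# (e.g. constant scalars) are compatible with any consumer; two CUDA
# devices must share the same index; Unknown never fuses.

from dataclasses import dataclass
from enum import Enum

class DeviceType(Enum):
    UNKNOWN = 0
    ANY = 1      # e.g. a constant scalar, usable on any device
    CPU = 2
    CUDA = 3     # carries a device index

@dataclass(frozen=True)
class Device:
    type: DeviceType
    index: int = -1  # only meaningful for CUDA

def compatible_devices(consumer: Device, producer: Device) -> bool:
    if DeviceType.UNKNOWN in (consumer.type, producer.type):
        return False  # assumed: refuse to fuse without device info
    if DeviceType.ANY in (consumer.type, producer.type):
        return True   # constants can be inlined anywhere
    return consumer == producer  # same type and, for CUDA, same index

print(compatible_devices(Device(DeviceType.CUDA, 0), Device(DeviceType.ANY)))      # True
print(compatible_devices(Device(DeviceType.CUDA, 0), Device(DeviceType.CUDA, 1)))  # False
```

Making the device a small value type with an explicit enum, rather than ad-hoc checks, is what lets `compatibleDevices` read as a short decision table.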
@zou3519 force-pushed the pytorch-fusescalar branch from c79b24a to 9d943a6 on August 17, 2018 17:39
@facebook-github-bot left a comment


zou3519 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


Labels

`oncall: jit` (Add this issue/PR to JIT oncall triage queue)


Development

Successfully merging this pull request may close these issues.

[jit] Fuse sub(Tensor, Tensor, Scalar) nodes

5 participants