
Rewrite torch.broadcast_shapes to be unbacked SymInt friendly#95217

Closed
ezyang wants to merge 2 commits into gh/ezyang/1836/base from gh/ezyang/1836/head

Conversation


@ezyang ezyang commented Feb 21, 2023

Stack from ghstack (oldest at bottom):

This is similar to what I did in https://github.com/pytorch/pytorch/pull/94790/files#diff-39e82af71afdadbc56a4ecf552ec668ddcda794f8ea3ec41edaef23456fd56e9
but I have to do this everywhere we have a broadcasting implementation.

If you want me to spend some BE time deduping these, please holler.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

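To make the guard problem concrete: the broadcasting rule compares each dimension against 1, and every such comparison on an unbacked SymInt forces a guard. The structure in question can be sketched in plain Python (a hypothetical illustration of the rule, not the PR's actual diff):

```python
def broadcast_shapes(*shapes):
    # Plain-int sketch of the right-aligned broadcasting rule.
    # In the symbolic-shapes setting, each `== 1` test below would
    # guard on an unbacked SymInt; the PR restructures the real
    # implementation so those guards are avoided.
    ndim = max(len(s) for s in shapes)
    result = [1] * ndim
    for shape in shapes:
        for i, dim in enumerate(shape):
            pos = ndim - len(shape) + i  # right-align shorter shapes
            if dim == 1:
                continue  # size 1 broadcasts against anything
            if result[pos] == 1:
                result[pos] = dim
            elif result[pos] != dim:
                raise RuntimeError(
                    f"shape mismatch at dim {pos}: {result[pos]} vs {dim}"
                )
    return tuple(result)
```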
@pytorch-bot

pytorch-bot bot commented Feb 21, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95217

Note: Links to docs will display an error until the docs builds have been completed.

❌ 67 Failures

As of commit 4cb9b24:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added a commit that referenced this pull request Feb 21, 2023
This is similar to what I did in https://github.com/pytorch/pytorch/pull/94790/files#diff-39e82af71afdadbc56a4ecf552ec668ddcda794f8ea3ec41edaef23456fd56e9
but I have to do this everywhere we have a broadcasting implementation.

If you want me to spend some BE time deduping these, please holler.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 13eef87
Pull Request resolved: #95217
@ezyang ezyang added the `ciflow/trunk` (trigger trunk jobs on your pull request), `release notes: composability`, and `topic: not user facing` labels Feb 21, 2023

@albanD albanD left a comment


That is quite a bit more intrusive than I would like it to be...
Also it has been quite a while without tests for these?

But the change is ok.

@ezyang
Copy link
Copy Markdown
Contributor Author

ezyang commented Feb 21, 2023

The E2E test is in #95218, but I can make a unit test for this. Actually, I can't even land this PR yet, because the PR it was based on got reverted.

> That is quite a bit more intrusive than I would like it to be...

Let's talk about it. There seem to be two problems:

  1. I have to keep doing this change in lots of places. I could marginally make this better by trying to dedupe broadcast implementations.
  2. The change itself is quite invasive. I'm not sure of a better way to write out this change. I could factor out the "broadcast two single sizes" into its own helper function, but it'd still have roughly the same shape internally.

WDYT?

ezyang added a commit that referenced this pull request Feb 21, 2023
This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

* torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test
* I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc), so we should eventually actually get #91081 going
* I switch `reshape` to use the Python implementation, which is easier to debug than the C++ one. Previously we couldn't easily do this as it was composite, but now we can with Python dispatcher.
* I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether or not the size is one (in which case we give the predicted stride) versus non-one (in which case we MUST give zero.) In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
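The stride dilemma for expand described in the last bullet can be illustrated with plain ints (hypothetical function names; the real meta function differs):

```python
def expand_strides(size, stride, target):
    # Sketch of the stride rule for expand: a dim of size 1 expanded to
    # n > 1 must get stride 0, while a dim that keeps its size keeps its
    # stride. With an unbacked target size we cannot decide `t != 1`
    # without guarding, which is why the PR's meta function gives such
    # dims stride 0 unconditionally.
    out = []
    for s, st, t in zip(size, stride, target):
        out.append(0 if (s == 1 and t != 1) else st)
    return out
```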
…dly"

This is similar to what I did in https://github.com/pytorch/pytorch/pull/94790/files#diff-39e82af71afdadbc56a4ecf552ec668ddcda794f8ea3ec41edaef23456fd56e9
but I have to do this everywhere we have a broadcasting implementation.

If you want me to spend some BE time deduping these, please holler.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Feb 22, 2023
This is similar to what I did in https://github.com/pytorch/pytorch/pull/94790/files#diff-39e82af71afdadbc56a4ecf552ec668ddcda794f8ea3ec41edaef23456fd56e9
but I have to do this everywhere we have a broadcasting implementation.

If you want me to spend some BE time deduping these, please holler.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 7a1119e
Pull Request resolved: #95217
ezyang added a commit that referenced this pull request Feb 22, 2023
This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

* torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test
* I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc), so we should eventually actually get #91081 going
* I add a direct Python transcription of the reshape composite adapted from #84584 . I cannot use the PrimTorch composite as it has problems when I register it pre-autograd. It has the same implementation as regular reshape, but at the beginning there is one more test for trivial reshapes, which is sufficient for the matmul example.
* I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether or not the size is one (in which case we give the predicted stride) versus non-one (in which case we MUST give zero.) In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
@ezyang ezyang closed this Feb 23, 2023
@facebook-github-bot facebook-github-bot deleted the gh/ezyang/1836/head branch June 8, 2023 16:51