Rewrite torch.broadcast_shapes to be unbacked SymInt friendly #95217
Closed
ezyang wants to merge 2 commits into gh/ezyang/1836/base from
Conversation
This is similar to what I did in https://github.com/pytorch/pytorch/pull/94790/files#diff-39e82af71afdadbc56a4ecf552ec668ddcda794f8ea3ec41edaef23456fd56e9 but I have to do this everywhere we have a broadcasting implementation. If you want me to spend some BE time deduping these, please holler. Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
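A minimal sketch of the idea the PR applies to each broadcasting implementation: only ask "is this size 1?" when the two sizes actually differ, so that equal unbacked sizes never force a guard. Plain Python ints stand in for SymInts here, and `broadcast_shapes_sketch` is a hypothetical name, not the actual PyTorch code.

```python
# Hedged sketch: broadcast semantics that test "== 1" only when the sizes
# disagree. With real unbacked SymInts the point is that the equal-sizes
# branch decides nothing and so need not emit a guard.

def broadcast_shapes_sketch(*shapes):
    ndim = max((len(s) for s in shapes), default=0)
    result = [1] * ndim
    for shape in shapes:
        # Walk dimensions right-aligned, as broadcasting does.
        for i in range(-1, -len(shape) - 1, -1):
            size = shape[i]
            if size == result[i]:
                # Equal sizes: nothing to decide, so no "== 1" test.
                continue
            if size == 1:
                continue              # existing entry wins
            if result[i] == 1:
                result[i] = size      # new non-1 size wins
            else:
                raise ValueError(f"shapes are not broadcastable at dim {i}")
    return tuple(result)
```

With concrete ints this matches ordinary NumPy/PyTorch broadcasting, e.g. `(3, 1)` and `(1, 4)` broadcast to `(3, 4)`.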
ezyang
added a commit
that referenced
this pull request
Feb 21, 2023
This is similar to what I did in https://github.com/pytorch/pytorch/pull/94790/files#diff-39e82af71afdadbc56a4ecf552ec668ddcda794f8ea3ec41edaef23456fd56e9 but I have to do this everywhere we have a broadcasting implementation. If you want me to spend some BE time deduping these, please holler.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: 13eef87
Pull Request resolved: #95217
albanD
approved these changes
Feb 21, 2023
Collaborator
albanD
left a comment
That is quite a bit more intrusive than I would like it to be...
Also it has been quite a while without tests for these?
But the change is ok.
Contributor
Author
The E2E is in #95218 but I can make a unit test for this. Actually, I can't even land this PR yet because the PR it was based on got reverted.
Let's talk about it. There seem to be two problems:
WDYT?
ezyang
added a commit
that referenced
this pull request
Feb 21, 2023
This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

* torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test
* I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc), so we should eventually actually get #91081 going
* I switch `reshape` to use the Python implementation, which is easier to debug than the C++ one. Previously we couldn't easily do this as it was composite, but now we can with Python dispatcher.
* I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether or not the size is one (in which case we give the predicted stride) versus non-one (in which case we MUST give zero.) In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
[ghstack-poisoned]
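The expand stride question in the last bullet can be sketched in plain Python: when a size-1 dim is broadcast to size n, the result stride for that dim must be 0 (every index maps to the same element); otherwise the input stride is kept. With an unbacked target size we cannot ask "is it 1?", which is why the PR's meta function unconditionally picks stride 0 there. This helper and its name are illustrative, not the actual meta function.

```python
# Hedged sketch of expand() stride computation with plain ints standing in
# for (possibly unbacked) SymInts.

def expand_strides_sketch(in_sizes, in_strides, out_sizes):
    out_strides = []
    for size, stride, target in zip(in_sizes, in_strides, out_sizes):
        if size == 1 and target != 1:
            # Broadcast dim: stride 0 so all indices alias element 0.
            # For an unbacked `target` this "!= 1" test is the guard the
            # PR avoids by always choosing 0 for such dims.
            out_strides.append(0)
        else:
            out_strides.append(stride)
    return out_strides
```

For example, expanding a `(1, 3)` tensor with strides `(3, 1)` to `(4, 3)` gives strides `(0, 1)`.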
Update on "Rewrite torch.broadcast_shapes to be unbacked SymInt friendly"

This is similar to what I did in https://github.com/pytorch/pytorch/pull/94790/files#diff-39e82af71afdadbc56a4ecf552ec668ddcda794f8ea3ec41edaef23456fd56e9 but I have to do this everywhere we have a broadcasting implementation. If you want me to spend some BE time deduping these, please holler.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
[ghstack-poisoned]
ezyang
added a commit
that referenced
this pull request
Feb 22, 2023
This is similar to what I did in https://github.com/pytorch/pytorch/pull/94790/files#diff-39e82af71afdadbc56a4ecf552ec668ddcda794f8ea3ec41edaef23456fd56e9 but I have to do this everywhere we have a broadcasting implementation. If you want me to spend some BE time deduping these, please holler.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: 7a1119e
Pull Request resolved: #95217
ezyang
added a commit
that referenced
this pull request
Feb 22, 2023
This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

* torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test
* I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc), so we should eventually actually get #91081 going
* I add a direct Python transcription of the reshape composite adapted from #84584 . I cannot use the PrimTorch composite as it has problems when I register it pre-autograd. It has the same implementation as regular reshape, but at the beginning there is one more test for trivial reshapes, which is sufficient for the matmul example.
* I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether or not the size is one (in which case we give the predicted stride) versus non-one (in which case we MUST give zero.) In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
[ghstack-poisoned]
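The "one more test for trivial reshapes" in the third bullet can be sketched as a shape-equality check that prefers object identity: two references to the same symbolic size compare equal via `is` without any reasoning that could guard on an unbacked SymInt. Plain ints model the sizes below, and `is_trivial_reshape` is a hypothetical helper name.

```python
# Hedged sketch of a guard-free trivial-reshape fast path: if the requested
# shape is the same object-by-object (or value-by-value) as the current
# shape, reshape can return the input unchanged.

def is_trivial_reshape(old_shape, new_shape):
    # `a is b` first, so identical symbolic sizes short-circuit before the
    # `==` fallback that could otherwise require deciding an unbacked value.
    return len(old_shape) == len(new_shape) and all(
        a is b or a == b for a, b in zip(old_shape, new_shape)
    )
```

In the matmul example this fast path fires because the reshape target is built from the input's own sizes, so the general (guard-heavy) reshape logic is never reached.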
Stack from ghstack (oldest at bottom):
This is similar to what I did in https://github.com/pytorch/pytorch/pull/94790/files#diff-39e82af71afdadbc56a4ecf552ec668ddcda794f8ea3ec41edaef23456fd56e9
but I have to do this everywhere we have a broadcasting implementation.
If you want me to spend some BE time deduping these, please holler.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>