Get matrix multiply with unbacked SymInt working #95218
Closed
ezyang wants to merge 6 commits into gh/ezyang/1837/base
Conversation
Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95218
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 0860fd3. This comment was automatically generated by Dr. CI and updates every 15 minutes.
This was referenced Feb 21, 2023
Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
ezyang commented on Feb 21, 2023
torch/_refs/__init__.py (Outdated)
# NOTE: shape is a vararg because Tensor.reshape can be called as either
# Tensor.reshape(a, b, c) or Tensor.reshape((a, b, c)).
# torch.reshape doesn't support unpacked shapes.
@aten.reshape.default.py_impl(DispatchKey.CompositeImplicitAutograd)
This appears to be deeply problematic. I'll probably figure out another way to do this.
This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

* torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test.
* I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc.), so we should eventually actually get #91081 going.
* I switch `reshape` to use the Python implementation, which is easier to debug than the C++ one. Previously we couldn't easily do this as it was composite, but now we can with the Python dispatcher.
* I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether the size is one (in which case we give the predicted stride) or non-one (in which case we MUST give zero). In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
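The broadcasting rule in the first bullet can be illustrated with a minimal pure-Python sketch (not PyTorch's actual implementation): shapes are right-aligned, and a size-1 dim stretches to match the other side. The `size == 1` tests are exactly the guards an unbacked SymInt cannot answer, which is why the PR's version of `torch.broadcast_shapes` avoids them.

```python
def broadcast_shapes(*shapes):
    """Sketch of NumPy/PyTorch-style shape broadcasting."""
    ndim = max(len(s) for s in shapes)
    result = [1] * ndim
    for shape in shapes:
        # Right-align this shape against the accumulated result.
        for i, size in enumerate(shape, start=ndim - len(shape)):
            if size == 1:
                continue  # size-1 dims broadcast; nothing to check
            if result[i] != 1 and result[i] != size:
                raise ValueError(f"shapes are not broadcastable: {shapes}")
            result[i] = size
    return tuple(result)
```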
This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

* torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test.
* I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc.), so we should eventually actually get #91081 going.
* I add a direct Python transcription of the reshape composite, adapted from #84584. I cannot use the PrimTorch composite as it has problems when I register it pre-autograd. It has the same implementation as regular reshape, but at the beginning there is one more test for trivial reshapes, which is sufficient for the matmul example.
* I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether the size is one (in which case we give the predicted stride) or non-one (in which case we MUST give zero). In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
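The stride divergence in the expand bullet can be sketched in plain Python (a hedged illustration of expand's stride rule, not the actual meta function): a dim being expanded from size 1 must be read with stride 0 so every index hits the same element, while other dims keep their stride. Deciding between the two branches requires knowing whether the size is 1, which an unbacked SymInt cannot tell us, so the PR unconditionally picks 0 in that case.

```python
def expand_strides(old_sizes, old_strides, new_sizes):
    """Sketch of the output strides for Tensor.expand-style broadcasting."""
    # expand may add new leading dims; they always get stride 0.
    pad = len(new_sizes) - len(old_sizes)
    strides = [0] * pad
    for size, stride, new in zip(old_sizes, old_strides, new_sizes[pad:]):
        if size == 1 and new != 1:
            strides.append(0)       # expanded dim: repeat the same element
        else:
            strides.append(stride)  # unexpanded dim: keep the old stride
    return strides
```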
Stack from ghstack (oldest at bottom):

This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
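As a side note on the folding that this PR disables for unbacked inputs: the optimization rewrites a batched matmul against a 2-D matrix as a single 2-D mm by reshaping the batch dims away. A hedged NumPy sketch of that rewrite (the function name is made up; this is not PyTorch's actual folding code, which also has to reason about strides):

```python
import numpy as np

def folded_matmul(a, b):
    """Fold a (B, M, K) @ (K, N) batched matmul into one 2-D matmul."""
    B, M, K = a.shape
    K2, N = b.shape
    assert K == K2, "inner dims must match"
    # Merge batch and row dims, do a single 2-D mm, then split them back.
    out2d = a.reshape(B * M, K) @ b
    return out2d.reshape(B, M, N)
```

Because the fold is purely a performance optimization, skipping it when any size is an unbacked SymInt is always semantically safe.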