Skip to content

Get matrix multiply with unbacked SymInt working#95218

Closed
ezyang wants to merge 6 commits intogh/ezyang/1837/basefrom
gh/ezyang/1837/head
Closed

Get matrix multiply with unbacked SymInt working#95218
ezyang wants to merge 6 commits intogh/ezyang/1837/basefrom
gh/ezyang/1837/head

Conversation

@ezyang
Copy link
Copy Markdown
Contributor

@ezyang ezyang commented Feb 21, 2023

Stack from ghstack (oldest at bottom):

This PR gets reflect @ R @ reflect working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

  • torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to Rewrite torch.broadcast_shapes to be unbacked SymInt friendly #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test
  • I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc), so we should eventually actually get [POC] Don't decompose matmul #91081 going
  • I add a direct Python transcription of the reshape composite adapted from reshape in python #84584 . I cannot use the PrimTorch composite as it has problems when I register it pre-autograd. It has the same implementation as regular reshape, but at the beginning there is one more test for trivial reshapes, which is sufficient for the matmul example.
  • I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether or not the size is one (in which case we give the predicted stride) versus non-one (in which case we MUST give zero.) In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang ezyang@meta.com

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Feb 21, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95218

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0860fd3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

[ghstack-poisoned]
@ezyang ezyang added release notes: composability release notes category topic: not user facing topic category labels Feb 21, 2023
Signed-off-by: Edward Z. Yang <ezyangmeta.com>

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Feb 21, 2023
Signed-off-by: Edward Z. Yang <ezyangmeta.com>

ghstack-source-id: e230b4f
Pull Request resolved: #95218
@ezyang ezyang requested review from ngimel and suo February 21, 2023 15:39
# NOTE: shape is a vararg because Tensor.reshape can be called with as
# Tensor.reshape(a, b, c) or Tensor.reshape((a, b, c)) Function call
# torch.reshape doesn't support unpacked shapes
@aten.reshape.default.py_impl(DispatchKey.CompositeImplicitAutograd)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This appears to be deeply problematic. I'll probably figure out another way to do this.

This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

* torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test
* I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc), so we should eventually actually get #91081 going
* I switch `reshape` to use the Python implementation, which is easier to debug than the C++ one. Previously we couldn't easily do this as it was composite, but now we can with Python dispatcher.
* I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether or not the size is one (in which case we give the predicted stride) versus non-one (in which case we MUST give zero.) In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

[ghstack-poisoned]
@ezyang ezyang requested a review from mruberry as a code owner February 21, 2023 17:15
This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

* torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test
* I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc), so we should eventually actually get #91081 going
* I switch `reshape` to use the Python implementation, which is easier to debug than the C++ one. Previously we couldn't easily do this as it was composite, but now we can with Python dispatcher.
* I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether or not the size is one (in which case we give the predicted stride) versus non-one (in which case we MUST give zero.) In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

[ghstack-poisoned]
This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

* torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test
* I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc), so we should eventually actually get #91081 going
* I add a direct Python transcription of the reshape composite adapted from #84584 . I cannot use the PrimTorch composite as it has problems when I register it pre-autograd. It has the same implementation as regular reshape, but at the beginning there is one more test for trivial reshapes, which is sufficient for the matmul example.
* I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether or not the size is one (in which case we give the predicted stride) versus non-one (in which case we MUST give zero.) In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

[ghstack-poisoned]
@ezyang ezyang added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 22, 2023
@ezyang ezyang closed this Feb 23, 2023
@facebook-github-bot facebook-github-bot deleted the gh/ezyang/1837/head branch June 8, 2023 16:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request release notes: composability release notes category release notes: fx release notes category topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant