Skip to content

[POC] Don't decompose matmul#91081

Closed
ezyang wants to merge 1 commit intogh/ezyang/1677/basefrom
gh/ezyang/1677/head
Closed

[POC] Don't decompose matmul#91081
ezyang wants to merge 1 commit intogh/ezyang/1677/basefrom
gh/ezyang/1677/head

Conversation

@ezyang
Copy link
Copy Markdown
Contributor

@ezyang ezyang commented Dec 19, 2022

Stack from ghstack (oldest at bottom):

Signed-off-by: Edward Z. Yang ezyang@fb.com

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Dec 19, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91081

Note: Links to docs will display an error until the docs builds have been completed.

❌ 19 Failures

As of commit 71d5202:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added a commit that referenced this pull request Dec 19, 2022
Signed-off-by: Edward Z. Yang <ezyangfb.com>

ghstack-source-id: cedfe9c
Pull Request resolved: #91081
@github-actions
Copy link
Copy Markdown
Contributor

This PR needs a label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

For more information, see https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

from torch._dispatch.python import enable_python_dispatcher, patch_py_impls

with enable_python_dispatcher(), patch_py_impls({
torch.ops.aten.matmul.default: {torch._C.DispatchKey.AutogradCPU: torch._C.DispatchKey.Autograd}
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the annoying thing where I can't just say Autograd because the C++ sided CompositeImplicitAutograd overrides me

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm why?
Do we define somewhere how alias keys work with python dispatch?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's not defined, it's an emergent property.

Essentially, take the dispatch key registrations from C++, and mash them up with the Python registrations (Python overwriting C++ if they had exactly the same dispatch key), and then compute the dispatch table from the mashup. So if you override Autograd from Python, that will only work if there wasn't a higher priority key overridden from C++.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok makes sense!
The current behavior sounds ok then? The user should be well aware if the implementation they provide is meant to override the alias key or the specific key?
I guess what you want here is a KeySet to be able to override all the autograd keys at once?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, so one way to hack this is introduce another Autograd alias key that has higher precedence than CompositeImplicitAutograd lol. It's just kind of a pain haha

return self._schema.name.split("::")[0]

def decompose(self, *args, **kwargs):
return NotImplemented
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need an api for this...

@SherlockNoMad
Copy link
Copy Markdown
Contributor

lol... Matmul now has three execution path

  1. act like a CompositExplictAutograd, by patching py_impl
  2. act like a CompositImplicitAutograd, via python impl
  3. act like a CompositImplicitAutograd, via C++ impl

@github-actions
Copy link
Copy Markdown
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Feb 18, 2023
ezyang added a commit that referenced this pull request Feb 21, 2023
This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

* torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test
* I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc), so we should eventually actually get #91081 going
* I switch `reshape` to use the Python implementation, which is easier to debug than the C++ one. Previously we couldn't easily do this as it was composite, but now we can with Python dispatcher.
* I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether or not the size is one (in which case we give the predicted stride) versus non-one (in which case we MUST give zero.) In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Feb 21, 2023
This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

* torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test
* I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc), so we should eventually actually get #91081 going
* I switch `reshape` to use the Python implementation, which is easier to debug than the C++ one. Previously we couldn't easily do this as it was composite, but now we can with Python dispatcher.
* I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether or not the size is one (in which case we give the predicted stride) versus non-one (in which case we MUST give zero.) In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Feb 22, 2023
This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

* torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test
* I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc), so we should eventually actually get #91081 going
* I add a direct Python transcription of the reshape composite adapted from #84584 . I cannot use the PrimTorch composite as it has problems when I register it pre-autograd. It has the same implementation as regular reshape, but at the beginning there is one more test for trivial reshapes, which is sufficient for the matmul example.
* I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether or not the size is one (in which case we give the predicted stride) versus non-one (in which case we MUST give zero.) In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang <ezyangmeta.com>

[ghstack-poisoned]
@github-actions github-actions bot closed this Mar 20, 2023
@facebook-github-bot facebook-github-bot deleted the gh/ezyang/1677/head branch June 8, 2023 16:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants