Conversation
Signed-off-by: Edward Z. Yang <ezyang@fb.com> [ghstack-poisoned]
This PR needs a label. If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`. If not, please add the `topic: not user facing` label. For more information, see https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.
```python
from torch._dispatch.python import enable_python_dispatcher, patch_py_impls

with enable_python_dispatcher(), patch_py_impls({
    torch.ops.aten.matmul.default: {torch._C.DispatchKey.AutogradCPU: torch._C.DispatchKey.Autograd}
}):
```
this is the annoying thing where I can't just say Autograd, because the C++-side CompositeImplicitAutograd registration overrides me
Hmmm why?
Do we define somewhere how alias keys work with python dispatch?
No, it's not defined, it's an emergent property.
Essentially, take the dispatch key registrations from C++, and mash them up with the Python registrations (Python overwriting C++ if they had exactly the same dispatch key), and then compute the dispatch table from the mashup. So if you override Autograd from Python, that will only work if there wasn't a higher priority key overridden from C++.
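That merge-then-compute behavior can be sketched in a few lines. This is a hedged, illustrative model only (`compute_dispatch_table`, the `PRIORITY` list, and the string kernel names are all hypothetical stand-ins, not PyTorch internals):

```python
# Illustrative precedence order: earlier keys win when the dispatch
# table is computed. This mirrors the thread's observation that a C++
# CompositeImplicitAutograd kernel outranks the Autograd alias key.
PRIORITY = ["AutogradCPU", "CompositeImplicitAutograd", "Autograd", "CPU"]

def compute_dispatch_table(cpp_regs, py_regs):
    # Python overwrites C++ only when the dispatch key matches exactly.
    merged = dict(cpp_regs)
    merged.update(py_regs)
    # Then pick the highest-priority key that has a kernel.
    for key in PRIORITY:
        if key in merged:
            return key, merged[key]
    raise RuntimeError("no kernel registered")

# C++ registered CompositeImplicitAutograd; Python overrides Autograd.
cpp = {"CompositeImplicitAutograd": "cpp_composite"}
py = {"Autograd": "py_autograd"}
# The C++ kernel still wins, because its key sits above Autograd.
winner = compute_dispatch_table(cpp, py)
```

Under this model, the Python `Autograd` override only takes effect when no higher-priority key was registered from C++, which is exactly the emergent property described above.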
Ok makes sense!
The current behavior sounds ok then? The user should be well aware of whether the implementation they provide is meant to override the alias key or the specific key?
I guess what you want here is a KeySet to be able to override all the autograd keys at once?
Yeah, so one way to hack this is to introduce another Autograd alias key that has higher precedence than CompositeImplicitAutograd lol. It's just kind of a pain haha
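Short of adding a new alias key, another workaround consistent with the precedence discussed above is to patch every concrete per-backend autograd key instead of the `Autograd` alias. A minimal sketch with plain strings standing in for `torch._C.DispatchKey` members (`expand_autograd_override` and the key list are hypothetical, not a real API):

```python
# Illustrative subset of per-backend autograd keys; the real set in
# torch._C.DispatchKey is larger.
AUTOGRAD_BACKEND_KEYS = ["AutogradCPU", "AutogradCUDA", "AutogradXLA"]

def expand_autograd_override(py_kernel):
    # Instead of one entry under the "Autograd" alias (which a C++
    # CompositeImplicitAutograd registration outranks), register the
    # same kernel under every concrete backend key, each of which
    # outranks both alias keys.
    return {key: py_kernel for key in AUTOGRAD_BACKEND_KEYS}

table = expand_autograd_override("py_matmul_autograd")
```

This is essentially the "KeySet to override all the autograd keys at once" idea, done by enumeration.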
```python
return self._schema.name.split("::")[0]


def decompose(self, *args, **kwargs):
    return NotImplemented
```
need an api for this...
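For illustration, a `decompose` that returns `NotImplemented` can act as a fallback protocol: try the decomposition first, and only dispatch to a real kernel when the op declines. All names here (`Op`, `run`) are hypothetical sketches, not the actual PyTorch API:

```python
class Op:
    # Toy operator: optionally carries a Python decomposition.
    def __init__(self, name, decomp=None):
        self.name = name
        self._decomp = decomp

    def decompose(self, *args, **kwargs):
        if self._decomp is None:
            return NotImplemented  # no decomposition registered
        return self._decomp(*args, **kwargs)

def run(op, *args):
    # Prefer the decomposition; fall through to a backend kernel
    # when decompose() declines with NotImplemented.
    out = op.decompose(*args)
    if out is NotImplemented:
        return f"dispatched {op.name} to a backend kernel"
    return out
```

The sentinel matters: `NotImplemented` is a value (not an exception), so callers can probe cheaply without try/except.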
lol... Matmul now has three execution paths
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as `Stale`.
This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

* torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test.
* I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc), so we should eventually actually get #91081 going.
* I switch `reshape` to use the Python implementation, which is easier to debug than the C++ one. Previously we couldn't easily do this as it was composite, but now we can with the Python dispatcher.
* I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard-free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether the size is one (in which case we give the predicted stride) versus non-one (in which case we MUST give zero). In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
[ghstack-poisoned]
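The first bullet's idea (don't guard on an unbacked symbol when testing broadcastability) can be sketched as follows. This is a hedged toy model: strings like `"u0"` stand in for unbacked SymInts, and `is_unbacked` is a hypothetical predicate (in PyTorch the shape environment would answer this):

```python
def is_unbacked(dim):
    # Stand-in predicate: a string marks an unbacked symbol.
    return isinstance(dim, str)

def broadcast_dim(a, b):
    # Statically known 1s broadcast without any guard.
    if not is_unbacked(a) and a == 1:
        return b
    if not is_unbacked(b) and b == 1:
        return a
    # Otherwise assume the dims match rather than emit a guard like
    # u0 == 1 on an unbacked symbol; mismatches surface at runtime.
    return a

def broadcast_shapes(s1, s2):
    # Right-align the shapes (NumPy-style) and combine dim by dim.
    s1, s2 = tuple(s1), tuple(s2)
    a_pad = (1,) * (len(s2) - len(s1)) + s1
    b_pad = (1,) * (len(s1) - len(s2)) + s2
    return tuple(broadcast_dim(a, b) for a, b in zip(a_pad, b_pad))
```

The key point: the `== 1` test is only ever asked of dims whose value is statically known, so no guard on an unbacked SymInt is introduced.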
This PR gets `reflect @ R @ reflect` working, where R has unbacked batch size. This pattern occurred in CrystalDPR. The billing of changes:

* torch.broadcast_shapes avoids guarding on unbacked SymInts when testing for broadcastable dims. I extracted this to #95217 for separate review; it's repeated in this PR as it is necessary for the E2E test.
* I disable matrix multiply folding when there is an unbacked SymInt on any input. Folding is strictly a performance optimization and can be omitted. Also, I believe export would prefer to get matmul (rather than bmm/etc), so we should eventually actually get #91081 going.
* I add a direct Python transcription of the reshape composite, adapted from #84584. I cannot use the PrimTorch composite as it has problems when I register it pre-autograd. It has the same implementation as regular reshape, but at the beginning there is one more test for trivial reshapes, which is sufficient for the matmul example.
* I hand-write a meta function for expand, rather than using the PrimTorch decomposition. I couldn't really figure out how to make the PrimTorch decomposition guard-free, but with the hand-written meta it is clear where the divergence lies: we cannot easily choose the correct stride for the unbacked dim, as we need to know whether the size is one (in which case we give the predicted stride) versus non-one (in which case we MUST give zero). In composability sync, we agreed that changes to striding behavior are fair game with unbacked SymInts, so I just unconditionally give these zero stride.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
[ghstack-poisoned]
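The expand-meta stride rule from the last bullet can also be sketched. Again a hedged toy model, not the PR's actual meta function: strings like `"u0"` stand in for unbacked SymInts, and `expand_meta_strides` is a hypothetical name:

```python
def is_unbacked(sz):
    # Stand-in predicate: a string marks an unbacked symbol.
    return isinstance(sz, str)

def expand_meta_strides(in_sizes, in_strides, out_sizes):
    # Right-align input dims against the (possibly longer) output.
    pad = len(out_sizes) - len(in_sizes)
    strides = []
    for i, out_sz in enumerate(out_sizes):
        j = i - pad
        if j < 0:
            strides.append(0)              # new leading dim: stride 0
        elif in_sizes[j] != 1:
            strides.append(in_strides[j])  # non-broadcast dim: keep stride
        elif is_unbacked(out_sz):
            # Can't tell 1 from non-1 without guarding on the unbacked
            # symbol, so unconditionally give stride 0 (per the
            # composability-sync agreement above).
            strides.append(0)
        elif out_sz != 1:
            strides.append(0)              # ordinary broadcast: stride 0
        else:
            strides.append(in_strides[j])  # size stays 1: keep stride
    return strides
```

The divergence from the backed case is confined to the `is_unbacked` branch: a backed size-1 output would keep the input stride, while the unbacked one always gets zero.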
Stack from ghstack (oldest at bottom):
Signed-off-by: Edward Z. Yang ezyang@fb.com