fix nested grad(functionalize(f)) transforms #76318
bdhirsh wants to merge 6 commits into gh/bdhirsh/217/base
Conversation
💊 CI failures summary and remediations — as of commit 5d02fed (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns. The following CI failures do not appear to be due to upstream breakages.
- c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::FuncTorchDynamicLayerBackMode)
// We still want to copy the Python TLS key - if the inner tensor is a Python tensor, then the wrapper
// should do a TLS snapshot.
.remove(c10::DispatchKey::PythonTLSSnapshot);
@albanD just curious - is it documented somewhere why the PythonTLSSnapshot key lives inside the range of the functorch keys? I thought it used to be higher priority than everything
PythonTLSSnapshot caused problems with the AOTAutograd <-> functorch interaction. We then decided that functorch "runs first in the dispatcher", and then the regular PyTorch dispatcher runs after. PythonTLSSnapshot is a part of the "regular PyTorch dispatcher", so we put the key after FuncTorchDynamicLayerFront so that functorch is able to interpose first.
More discussion / context can be found at the comment here: pytorch/functorch#732 (comment). I couldn't easily test this in core, but I'll be landing the tests in the linked PR above in functorch. The tl;dr: `FunctionalTensorImpl` should copy most of the dispatch keys on the inner tensor (for example: the autograd, autocast, and PythonTLSSnapshot keys). However, it should not copy over *all* keys. Specifically, any keys that don't know how to handle `FunctionalTensorImpl` directly (usually because they expect a specific subclass, like `BatchedTensorImpl`) shouldn't be included. For now, this only really applies to functorch transforms and specific backends (like Sparse / XLA). In practice, though, it only really matters for transforms, since we always unwrap the `FunctionalTensorWrapper` before hitting a backend. That's why I went the route of using `FuncTorchDynamicLayerFrontMode - FuncTorchDynamicLayerBackMode` to figure out the keys to ignore, instead of hardcoding a list somewhere.
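To make the `FULL_AFTER` arithmetic above concrete, here is a minimal sketch modeling `c10::DispatchKeySet` as a plain 64-bit mask. The key ordinals below are invented for the sketch (real `DispatchKey` values differ), but the ordering mirrors the discussion: PythonTLSSnapshot — and, notably, autograd — sit between the two functorch dynamic-layer keys.

```cpp
#include <cassert>
#include <cstdint>

// Toy stand-ins for DispatchKey; ordinals are made up for this sketch.
enum ToyKey : int {
  kCPU                = 0, // backend key
  kFuncTorchBackMode  = 1, // FuncTorchDynamicLayerBackMode
  kAutogradCPU        = 2, // autograd key (note: inside the range!)
  kFuncTorchBatched   = 3, // functorch transform key
  kPythonTLSSnapshot  = 4,
  kFuncTorchFrontMode = 5, // FuncTorchDynamicLayerFrontMode
};

constexpr uint64_t bit(ToyKey k) { return uint64_t{1} << k; }

// FULL_AFTER(k): every key that is dispatched *after* k,
// i.e. all keys strictly below k in the enum.
constexpr uint64_t full_after(ToyKey k) { return bit(k) - 1; }

// The PR's expression: FULL_AFTER(FrontMode) - FULL_AFTER(BackMode),
// with PythonTLSSnapshot carved back out (set difference is `& ~`).
constexpr uint64_t keys_to_not_copy =
    (full_after(kFuncTorchFrontMode) & ~full_after(kFuncTorchBackMode))
    & ~bit(kPythonTLSSnapshot);
```

This captures both halves of the thread: the functorch keys land in `keys_to_not_copy` as intended, but so does the autograd key — which is exactly the objection raised below.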
// What's happening here? In general, functional wrappers *should* copy the keys from their inner tensor.
// We need this for the backend use case (LTC/XLA), where any functionalities that the backend uses,
// like autograd or autocast, need to run *directly* on the wrapper.
// However, all of the keys corresponding to functorch transforms should not be copied over.
// Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
// to participate in the functorch transforms.
auto keys_to_not_copy =
    c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::FuncTorchDynamicLayerFrontMode)
    - c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::FuncTorchDynamicLayerBackMode)
I think you want an explicit denylist with e.g. FuncTorchBatched, FuncTorchGradWrapper, and maybe the sparse keys(?) and the Lazy key (those are also 1:1 with their respective TensorImpl subclasses, right?)
The Autograd keys are also in between DynamicLayerFront and DynamicLayerBack and I'm not sure why you wanted to exclude them
// However, all of the keys corresponding to functorch transforms should not be copied over.
// Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
// to participate in the functorch transforms.
To clarify: the problem is when there exists an invariant that if a Tensor has a Dispatch Key, then it must be a specific TensorImpl. This occurs a lot in functorch (BatchedTensorImpl <-> FuncTorchBatched key, TensorWrapper <-> FuncTorchGradWrapper key), but is not a functorch-specific problem, right?
Here are some things that I thought were the case, but I might be wrong about them since it's been a while since I looked into them:
- SparseTensorImpl <-> {Some sparse keys}
- SparseCsrTensorImpl <-> {the csr sparse keys}
- LTCTensorImpl <-> Lazy
Yep that's right - although in practice it isn't a problem for the backend keys, because we're guaranteed to have unwrapped the FunctionalTensorWrapper before hitting the backend.
I used `FuncTorchDynamicLayerFrontMode - FuncTorchDynamicLayerBackMode` because I figured it would be unlikely that we'd add another non-backend, non-functorch key that this logic would apply to, and it lets us avoid hardcoding a new list of keys somewhere.
But I totally forgot that the autograd keys fall in that range. So agreed - I'll create a new keyset corresponding to "tensor impl subclass" keys and use that instead (and at that point I may as well lump in the backend keys, even if they don't directly cause problems in this case).
Oh I see, I didn't realize FunctionalTensorWrapper doesn't have any problems with the backend keys that correspond 1:1 with tensor impl subclasses.
The new keyset including all "tensor impl subclass keys" sounds reasonable; if you wanted to just continue on the original train of thought we could create a "FuncTorchKeySet" keyset and put all the dispatch keys that begin with "FuncTorch" into it and use that instead
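The explicit-denylist alternative suggested above can be sketched the same way, again with a toy 64-bit mask standing in for `c10::DispatchKeySet`. The key names mirror real functorch dispatch keys, but the ordinals are invented for this sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>

// Toy key ordinals, invented for this sketch.
enum ToyKey : int {
  kFuncTorchDynamicLayerBackMode  = 1,
  kFuncTorchGradWrapper           = 2,
  kFuncTorchBatched               = 3,
  kFuncTorchDynamicLayerFrontMode = 5,
};

constexpr uint64_t bit(int k) { return uint64_t{1} << k; }

constexpr uint64_t make_keyset(std::initializer_list<ToyKey> keys) {
  uint64_t m = 0;
  for (auto k : keys) m |= bit(k);
  return m;
}

// A "FuncTorchKeySet": every dispatch key whose name begins with FuncTorch,
// listed explicitly instead of computed from a range.
constexpr uint64_t functorch_transforms_ks = make_keyset({
    kFuncTorchDynamicLayerBackMode,
    kFuncTorchGradWrapper,
    kFuncTorchBatched,
    kFuncTorchDynamicLayerFrontMode,
});

// A functional wrapper would then drop exactly these keys when copying
// its inner tensor's key set, leaving autograd/autocast untouched:
constexpr uint64_t strip_functorch(uint64_t key_set) {
  return key_set & ~functorch_transforms_ks;
}
```

The upside of the explicit list is that nothing that merely *falls in the enum range* (like the autograd keys) gets swept up by accident; the downside is that the list must be maintained as functorch keys are added.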
// All of the keys corresponding to functorch transforms should not be copied over.
// Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
// to participate in the functorch transforms.
key_set_ = key_set_ & c10::functorch_transforms_ks;
Should this be key_set_ & ~c10::functorch_transforms_ks or am I reading this wrong? Aren't we trying to remove the functorch keys from key_set ?
yes 😛 thanks. I overconfidently yolo'd CI last night without testing locally
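The bug caught above is a one-character mask inversion, which a tiny sketch makes obvious (plain `uint64_t` standing in for `c10::DispatchKeySet`; the bit patterns are arbitrary):

```cpp
#include <cassert>
#include <cstdint>

// Pretend bits 1-2 are functorch transform keys, bits 0 and 3 are other keys.
constexpr uint64_t functorch_transforms_ks = 0b0110;
constexpr uint64_t key_set                 = 0b1111;

// Buggy version from the diff: this *keeps only* the functorch keys.
constexpr uint64_t buggy = key_set & functorch_transforms_ks;

// Intended version: *remove* the functorch keys from key_set.
constexpr uint64_t fixed = key_set & ~functorch_transforms_ks;
```

`buggy` ends up as `0b0110` (exactly the keys that were supposed to be stripped), while `fixed` is `0b1001` (everything else), matching what the review comment asked for.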
zou3519 left a comment:
LGTM, minus the change I commented on. Approving for now for developer velocity
@pytorchbot merge this please
Hey @bdhirsh.
Summary: Pull Request resolved: #76318
Approved by: https://github.com/zou3519
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/aae7b00f7c7b5e28d1fe8974cd02a538f5be0913
Reviewed By: osalpekar
Differential Revision: D35971227
Pulled By: bdhirsh
fbshipit-source-id: 30cb520370dc91d3a9be4c59df7cf294b0ecc7f5