Add Autocast Support for FakeTensors / use fake device dispatch keys #82449

eellison wants to merge 8 commits into gh/eellison/301/base
Conversation
[ghstack-poisoned]
✅ No Failures (0 Pending) as of commit f81d77c (more details on the Dr. CI page). Looks good so far! There are no failures yet. This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.
…patch keys"

From PR:

```
Note: [Fake Tensor Dispatch Keys]
In order to model the behavior of device-specific autocast and autograd logic,
we update the dispatch keys of FakeTensors to reflect their fake device. This
includes the BackendComponent (DispatchKey::Meta -> DispatchKey::CUDA), and
also the BackendComponent-related Autocast and Autograd keys.
__torch_dispatch__ sits below Autocast and Autograd, and is only invoked when
we are at the kernel for the BackendComponent. Then, we add Meta to the
thread-local dispatch include set to hit the meta kernel instead of the kernel
of the BackendComponent for the fake device.
```

Also adds the `conv1/2/3d.padding` operators to the Autocast rule set. Without that fix, the FakeTensor dtype would diverge.

[ghstack-poisoned]
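The note above can be sketched as a toy model. This is not the real c10 dispatcher; the key names mirror real `DispatchKey` values, but the sets and resolution order here are simplified purely for illustration. A FakeTensor with a fake `cuda` device carries CUDA-flavoured keys (the Meta -> CUDA promotion), so the Autocast and Autograd layers for that device run first, and the thread-local include set then redirects the backend step to the meta kernel:

```python
# Toy model of the dispatch order described in the note above -- NOT the
# real c10 implementation. Autocast sits above Autograd, which sits above
# the backend kernel; Meta in the TLS include set wins at the backend step.

def dispatch_trace(tensor_keys, tls_include):
    """Return the sequence of dispatch keys an op call would pass through."""
    trace = []
    if "AutocastCUDA" in tensor_keys:
        trace.append("AutocastCUDA")   # device-specific autocast logic runs
    if "AutogradCUDA" in tensor_keys:
        trace.append("AutogradCUDA")   # device-specific autograd logic runs
    # Backend step: Meta in the thread-local include set redirects the call
    # to the meta kernel instead of the fake device's real backend kernel.
    trace.append("Meta" if "Meta" in tls_include else "CUDA")
    return trace

# A FakeTensor's keys after the Meta -> CUDA promotion:
fake_keys = {"CUDA", "AutogradCUDA", "AutocastCUDA"}
```

With `{"Meta"}` in the TLS include set, the trace is `["AutocastCUDA", "AutogradCUDA", "Meta"]`: autocast and autograd behave as they would on a real CUDA tensor, but the actual kernel that runs is the meta one.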
```cpp
auto k = key_set.highestBackendKey();
local_keyset.included_ = local_keyset.included_.remove_backend(k);
c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
});
```
feels like it would be better to bind one of the local TLS RAII objects. It's pretty quick and easy
I'm trying to fix the functionalization <> dynamic shape issues, and I ended up adding something pretty similar here: `_change_backend_component_keys`. If this looks good to you, we can either stamp those changes into this PR or fix this code up after my PR lands.
If you don't mind fixing up this code after my PR lands, that would be great. Thanks, Brian.
```python
# the call here
# because it doesn't go through the dispatcher, we run into errors
# when attempting to compute an output in meta, so
# we compute the real tensor then convert to meta
```
But this means that `new` is no longer memory efficient, is that right?
Yeah, this is sort of a deprecated API that doesn't go through the dispatcher, so it would make sense that it might be fragile to changes. As with fallbacks, it seemed okay to allocate a real tensor temporarily.
OK. It's just odd that although we have explicit handling for it, we aren't able to just go ahead and write the "correct" fake rule for it. But I suppose it might be annoying to do.
```cpp
KERNEL_CPU(ADD_NS(conv2d), "conv2d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp)
KERNEL_CPU(ADD_NS(conv2d), "conv2d.padding", Tensor (const Tensor&, const Tensor&, const c10::optional<Tensor>&, IntArrayRef, c10::string_view, IntArrayRef, int64_t groups), lower_precision_fp)
KERNEL_CPU(ADD_NS(conv3d), "conv3d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp)
KERNEL_CPU(ADD_NS(conv3d), "conv3d.padding", Tensor (const Tensor&, const Tensor&, const c10::optional<Tensor>&, IntArrayRef, c10::string_view, IntArrayRef, int64_t groups), lower_precision_fp)
```
Do you have a more detailed explanation for these? It's a bit confusing.
aten::conv2d.padding redispatches to aten::convolution, which doesn't have an autocast kernel registered for it; _convolution does. It's only by invoking _convolution inside the CPU kernel that autocast gets applied. Since the meta kernel for aten::convolution won't call _convolution, without these changes no autocasting happens.
So I mirrored the existing pattern for conv1d and added the padding overloads to the registration list.
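The divergence being fixed here can be modeled with a toy rule set. This is not the real `KERNEL_CPU` machinery, just an illustration of why a missing registration makes the FakeTensor dtype diverge: ops registered with the `lower_precision_fp` policy have their float32 inputs cast down before the kernel runs, while unregistered ops pass the dtype through unchanged:

```python
# Toy model of the autocast rule set -- NOT the real KERNEL_CPU machinery.
# Ops registered with lower_precision_fp get float32 inputs cast down;
# unregistered ops run with their inputs' dtype untouched.

AUTOCAST_RULES = {
    "conv1d.padding": "lower_precision_fp",   # the registrations this PR adds
    "conv2d.padding": "lower_precision_fp",
    "conv3d.padding": "lower_precision_fp",
}


def run_with_autocast(op, input_dtype, lower_dtype="float16"):
    """Return the dtype the op's output would have under autocast."""
    if AUTOCAST_RULES.get(op) == "lower_precision_fp" and input_dtype == "float32":
        return lower_dtype   # inputs cast down, output is lower precision
    return input_dtype       # no rule: dtype passes through unchanged
```

Before the fix, a real CUDA run would reach the lower-precision rule (via _convolution) and produce float16, while the meta kernel, never hitting the rule, kept float32: exactly the dtype divergence the PR description mentions.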
oof. I guess if we turned off autocast after we got past the autocast layer this would prevent that, but we don't really want to pay for it. :/ There may be other latent bugs like this.
At the risk of micro-optimizing: one thing we could do, if we're worried about the boxing/unboxing cost of the fallback (although we'd still have to pay for the redispatch + TLS), is teach the dispatcher not to box/unbox anything when the registered boxed fallback doesn't take in an argument stack (indicating that it doesn't use the arguments).
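The suggested optimization can be sketched with a toy dispatcher. All names here are invented for illustration; the point is only the shape of the idea: if the registered fallback declares that it ignores its argument stack, the dispatcher skips materializing (boxing) the arguments entirely:

```python
# Toy sketch of the suggested micro-optimization -- names invented here.
# A dispatcher that only boxes arguments when the registered boxed
# fallback actually declares that it needs the argument stack.

def make_dispatcher(fallback, fallback_needs_args):
    def dispatch(op, *args):
        if fallback_needs_args:
            stack = list(args)       # boxing: pack every argument into a stack
            return fallback(op, stack)
        return fallback(op, None)    # no boxing -- the stack is never built
    return dispatch
```

A fallback that only flips some TLS state and redispatches, like the one discussed above, would register with `fallback_needs_args=False` and never pay for boxing, though the redispatch and TLS costs remain.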
…patch keys"

From PR:

```
Note: [Fake Tensor Dispatch Keys]
In order to model the behavior of device-specific autocast and autograd logic,
we update the dispatch keys of FakeTensors to reflect their fake device. This
includes the BackendComponent (DispatchKey::Meta -> DispatchKey::CUDA), and
also the BackendComponent-related Autocast and Autograd keys.
__torch_dispatch__ sits below Autocast and Autograd, and is only invoked when
we are at the kernel for the BackendComponent. Then, we add Meta to the
thread-local dispatch include set to hit the meta kernel instead of the kernel
of the BackendComponent for the fake device.
```

Also adds the `conv1/2/3d.padding` operators to the Autocast rule set. Without that fix, the FakeTensor dtype would diverge.

See: #81608

[ghstack-poisoned]
@pytorchbot merge -g

@pytorchbot successfully started a merge job. Check the current status here.

Hey @eellison.
…82449) (#82449)

Summary:
From PR:

```
Note: [Fake Tensor Dispatch Keys]
In order to model the behavior of device-specific autocast and autograd logic,
we update the dispatch keys of FakeTensors to reflect their fake device. This
includes the BackendComponent (DispatchKey::Meta -> DispatchKey::CUDA), and
also the BackendComponent-related Autocast and Autograd keys.
__torch_dispatch__ sits below Autocast and Autograd, and is only invoked when
we are at the kernel for the BackendComponent. Then, we add Meta to the
thread-local dispatch include set to hit the meta kernel instead of the kernel
of the BackendComponent for the fake device.
```

Also adds the `conv1/2/3d.padding` operators to the Autocast rule set. Without that fix, the FakeTensor dtype would diverge. See: #81608

Pull Request resolved: #82449
Approved by: https://github.com/ezyang
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/642aed8b99bb35ab7029e7e839350a6d06464882
Reviewed By: kit1980
Differential Revision: D38330007
Pulled By: eellison
fbshipit-source-id: db1865a5a59c81ca66533de091f7ae1865efacd7
Stack from ghstack (oldest at bottom):
From PR:
Also adds the `conv1/2/3d.padding` operators to the Autocast rule set. Without that fix, the FakeTensor dtype would diverge.

See: #81608