This repository was archived by the owner on Aug 21, 2025. It is now read-only.
Allow people to arbitrarily add dispatch keys between DynamicLayer{Front,Back} #843
Merged
Conversation
Fixes #842

The Diagnosis
=============
As Brian pointed out, for jvp(sub, ...) the chain of dispatch should be:
```
DynamicLayerFrontMode -> at::sub autograd kernel -> DynamicLayerBackMode
```
Instead, what we're doing today is:
```
JVP dynamic layer -> at::sub autograd kernel -> at::sub zero_kernel
```
(The ZeroTensor kernel errors out, because the inputs are BatchedTensorImpl objects.)

The Problem
=============
functorch's behavior for dispatch keys between DynamicLayerFrontMode and DynamicLayerBackMode should be:
- Upon entering a dynamic layer (aka Interpreter), we zero out all dispatch keys* between FrontMode and BackMode.
- Then, the dynamic layer (aka Interpreter) decides to re-enable some dispatch keys. For example, JVPInterpreter decides to re-enable the autograd keys.
- Next, we do a dispatcher call, which will end up hitting one of the Autograd keys (in the JVPInterpreter case).

The bug is that functorch has a hardcoded list of dispatch keys that it zeros out. This list does not include ZeroTensor, because before pytorch/pytorch#77132 the ZeroTensor key was not between DynamicLayer{Front,Back}Mode.

*There is an exception for Autocast and VmapMode, described in the next section.

The Solution
============
Change functorch to programmatically zero out keys between DynamicLayerBackMode and DynamicLayerFrontMode, with the exception of Autocast and VmapMode. This means that if someone adds a dispatch key between DynamicLayerBackMode and DynamicLayerFrontMode in the future, we will (probably) be handling it "correctly". The model for dispatch is:
- [functorch] -> [regular pytorch dispatcher]
- A key like ZeroTensor gets handled in the [regular pytorch dispatcher] section.
- functorch transforms get handled in the [functorch] section.

We do not change the autocast <-> functorch interaction in this PR (i.e., functorch does not zero it out), because I'm not sure what the correct thing to do here is. We do not change how kVmapMode works because it needs to stay active to ban random operations in transforms later down the line :/

Test Plan
============
Wait for tests
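The mechanism described above can be sketched in a few lines. The following is a minimal Python model, not functorch's actual C++ code: the key table, the `JVPInterpreter` class, and the function names are all illustrative. It shows the difference the PR makes: the exclude set is computed from whatever lies between the two DynamicLayer keys (minus Autocast and VmapMode), rather than read from a hardcoded list, so a newly inserted key like ZeroTensor is zeroed out automatically.

```python
# Toy ordered dispatch-key table (lowest priority first). ZeroTensor sits
# between the two DynamicLayer keys, as it does after pytorch/pytorch#77132.
# Names are illustrative, not PyTorch's full key list.
DISPATCH_KEYS = [
    "CPU",
    "DynamicLayerBackMode",
    "ZeroTensor",
    "AutogradCPU",
    "Autocast",
    "VmapMode",
    "DynamicLayerFrontMode",
]

# Keys a dynamic layer must NOT zero out (the exception in "The Solution").
NEVER_ZEROED = {"Autocast", "VmapMode"}


def keys_between_dynamic_layers():
    """All keys strictly between DynamicLayerBackMode and DynamicLayerFrontMode."""
    lo = DISPATCH_KEYS.index("DynamicLayerBackMode")
    hi = DISPATCH_KEYS.index("DynamicLayerFrontMode")
    return set(DISPATCH_KEYS[lo + 1:hi])


def default_exclude_set():
    """Programmatic exclude set: everything between the DynamicLayer keys,
    except Autocast and VmapMode. No hardcoded list to fall out of date."""
    return keys_between_dynamic_layers() - NEVER_ZEROED


class JVPInterpreter:
    """On entry, zero out the computed set, then re-enable the autograd keys."""
    RE_ENABLED = {"AutogradCPU"}

    def active_keys(self, enabled):
        return (enabled - default_exclude_set()) | self.RE_ENABLED


enabled = {"CPU", "ZeroTensor", "AutogradCPU", "Autocast"}
active = JVPInterpreter().active_keys(enabled)
# ZeroTensor is zeroed out automatically; Autocast survives; autograd is
# re-enabled, so dispatch reaches the at::sub autograd kernel, not zero_kernel.
assert "ZeroTensor" not in active
assert "Autocast" in active
assert "AutogradCPU" in active
```

In this model, adding a hypothetical new key to `DISPATCH_KEYS` between the two DynamicLayer entries requires no other change: `default_exclude_set()` picks it up, which is the forward-compatibility claim of the PR.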
bdhirsh
approved these changes
May 31, 2022
bdhirsh (Contributor) left a comment:
Waiting for CI but lgtm. Thanks for the detailed explanation (including the Autocast / VmapModeKey interaction, which I didn't notice)!
Contributor (Author)
Failures are unrelated; merging
zou3519
added a commit
to zou3519/pytorch
that referenced
this pull request
Jul 20, 2022
…amicLayer{Front,Back} (pytorch/functorch#843)
bigfootjon
pushed a commit
to pytorch/pytorch
that referenced
this pull request
Jul 21, 2022
…amicLayer{Front,Back} (pytorch/functorch#843)