This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Allow people to arbitrarily add dispatch keys between DynamicLayer{Front,Back}#843

Merged
zou3519 merged 1 commit into main from fix_key_move
May 31, 2022

Conversation

Contributor

@zou3519 zou3519 commented May 31, 2022

Fixes #842

The Diagnosis

As Brian pointed out:

For jvp(sub, ...), the chain of dispatch should be:

```
DynamicLayerFrontMode -> at::sub autograd kernel -> DynamicLayerBackMode
```

Instead, what we're doing today is

```
JVP dynamic layer -> at::sub autograd kernel -> at::sub zero_kernel
```

(the zero_tensor kernel errors out, because the inputs are
BatchedTensorImpl objects)

The Problem

functorch's behavior on dispatch keys between DynamicLayerFrontMode and
DynamicLayerBackMode should be:

  • upon entering a dynamic layer (aka Interpreter), we zero out all
    dispatch keys* between FrontMode and BackMode
  • then, the dynamic layer (aka Interpreter) decides to re-enable some
    dispatch keys. For example, JVPInterpreter decides to re-enable the
    autograd keys
  • next, we do a dispatcher call, which will end up hitting one of the
    Autograd keys (in the JVPInterpreter case).

The bug is that functorch has a hardcoded list of dispatch keys that it
zeros out. This list does not include ZeroTensor, because before
pytorch/pytorch#77132, the ZeroTensor key was
not between DynamicLayer{Front,Back}Mode.

*There is an exception for Autocast and VmapMode, described in the next section.
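To make the three steps above concrete, here is a toy model of the bug (illustrative only: the real implementation is a 64-bit `DispatchKeySet` in C++, and names like `dispatch_chain`, `enter_jvp_layer`, and the hardcoded exclude set are simplifications, not functorch's actual API):

```python
# Toy model of the dispatcher: kernels run from highest-priority key down,
# each one redispatching to the next remaining key.
PRIORITY = [  # highest priority first (simplified subset of real keys)
    "DynamicLayerFrontMode",
    "AutogradCPU",
    "ZeroTensor",            # sits between Front/Back after pytorch/pytorch#77132
    "DynamicLayerBackMode",
    "CPU",
]

def dispatch_chain(keyset):
    """Order in which kernels run for a given set of active keys."""
    return [k for k in PRIORITY if k in keyset]

# The hardcoded list predates the ZeroTensor move, so ZeroTensor is missing.
HARDCODED_ZEROED_KEYS = {"AutogradCPU"}

def enter_jvp_layer(keyset):
    keyset = keyset - HARDCODED_ZEROED_KEYS  # step 1: zero out listed keys
    keyset = keyset | {"AutogradCPU"}        # step 2: JVP re-enables autograd
    return keyset                            # step 3: dispatch on the result

inner = enter_jvp_layer({"CPU", "ZeroTensor", "AutogradCPU", "DynamicLayerBackMode"})
print(dispatch_chain(inner))
# ['AutogradCPU', 'ZeroTensor', 'DynamicLayerBackMode', 'CPU'] -- the autograd
# kernel redispatches into ZeroTensor instead of DynamicLayerBackMode.
```

Because ZeroTensor survives step 1, the autograd kernel's redispatch lands on the zero_tensor kernel rather than DynamicLayerBackMode, which is exactly the broken chain shown in the diagnosis.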

The Solution

Change functorch to programmatically zero out keys between
DynamicLayerBackMode and DynamicLayerFrontMode, with the exception of
Autocast and VmapMode.
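The programmatic approach can be sketched in the same toy model (again hypothetical names, not the real C++ code; the real fix manipulates a `DispatchKeySet` between the two DynamicLayer keys):

```python
# Simplified key ordering, highest priority first. Real autocast keys are
# per-backend (e.g. AutocastCPU); a single "Autocast" stands in for them here.
PRIORITY = [
    "DynamicLayerFrontMode",
    "AutogradCPU",
    "Autocast",
    "VmapMode",
    "ZeroTensor",
    "DynamicLayerBackMode",
    "CPU",
]

# Keys deliberately left alone (see the Autocast/VmapMode caveats).
PRESERVED = {"Autocast", "VmapMode"}

def keys_to_zero():
    """Every key strictly between Front and Back, minus the preserved ones."""
    lo = PRIORITY.index("DynamicLayerFrontMode")
    hi = PRIORITY.index("DynamicLayerBackMode")
    return set(PRIORITY[lo + 1:hi]) - PRESERVED

def enter_jvp_layer(keyset):
    keyset = keyset - keys_to_zero()   # zero out programmatically, not by list
    keyset = keyset | {"AutogradCPU"}  # JVP re-enables the autograd keys
    return keyset

inner = enter_jvp_layer({"CPU", "ZeroTensor", "AutogradCPU", "DynamicLayerBackMode"})
print([k for k in PRIORITY if k in inner])
# ['AutogradCPU', 'DynamicLayerBackMode', 'CPU'] -- autograd now redispatches
# straight to DynamicLayerBackMode, as intended.
```

Since the zeroed set is computed from the key range rather than enumerated, any key later added between the two DynamicLayer keys is swept up automatically.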

This means that in the future, if someone adds a dispatch key between
DynamicLayerBackMode and DynamicLayerFrontMode, we will (probably) be
handling it "correctly": the model for dispatch is:

  • [functorch] -> [regular pytorch dispatcher]
  • a key like ZeroTensor gets handled in the [regular pytorch dispatcher]
    section.
  • functorch transforms get handled in the [functorch] section.

We do not change the autocast <-> functorch interaction in this PR
(i.e. functorch does not zero it out) because I'm not sure what the
correct thing to do here is.

We do not change how kVmapMode works because... it needs to be active
to ban random operations in transforms later down the line :/

Test Plan

Wait for tests

Contributor

@bdhirsh bdhirsh left a comment


Waiting for CI but lgtm. Thanks for the detailed explanation (including the Autocast / VmapModeKey interaction, which I didn't notice)!

@zou3519
Contributor Author

zou3519 commented May 31, 2022

Failures are unrelated; merging

@zou3519 zou3519 merged commit 28d3ef7 into main May 31, 2022
zou3519 added a commit to zou3519/pytorch that referenced this pull request Jul 20, 2022
bigfootjon pushed a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022


Development

Successfully merging this pull request may close these issues.

vmap x jvp bug resulting from DispatchKey::DynamicLayerBack move

3 participants