
fix nested grad(functionalize(f)) transforms #76318

Closed
bdhirsh wants to merge 6 commits into gh/bdhirsh/217/base from gh/bdhirsh/217/head

Conversation

@bdhirsh
Collaborator

@bdhirsh bdhirsh commented Apr 25, 2022

More discussion / context can be found at the comment here: pytorch/functorch#732 (comment)

I couldn't easily test this in core, but I'll be landing the tests in the linked PR above in functorch.

The tl;dr: `FunctionalTensorImpl` should copy most of the dispatch keys from the inner tensor (for example, the autograd, autocast, and PythonTLSSnapshot keys). However, it should not copy over *all* keys. Specifically, any keys that don't know how to handle `FunctionalTensorImpl` directly (usually because they expect a specific subclass, like `BatchedTensorImpl`) shouldn't be included.

For now, this only really applies to functorch transforms and specific backends (like Sparse / XLA). In practice, though, it only matters for transforms, since we always unwrap the `FunctionalTensorWrapper` before hitting a backend.

That's why I went the route of using `FuncTorchDynamicLayerFrontMode - FuncTorchDynamicLayerBackMode` to figure out the keys to ignore, instead of hardcoding a list somewhere.

Stack from ghstack:

@facebook-github-bot
Contributor

facebook-github-bot commented Apr 25, 2022

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 5d02fed (more details on the Dr. CI page):

  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge) (1/1)

Step: "Test"

2022-04-27T02:37:36.0135997Z     test_exception_single succeeded - num_retries_left: 0
2022-04-27T02:37:36.0143942Z   test_first_argument_index (__main__.SpawnTest) ... skip: Test is disabled because an issue exists disabling it: https://github.com/pytorch/pytorch/issues/73266 for platform(s) linux. If you're seeing this on your local machine and would like to enable this test, please make sure IN_CI is not set and you are not using the flag --import-disabled-tests. (0.001s)
2022-04-27T02:37:36.0212175Z   test_signal_raises (__main__.SpawnTest) ... ok (0.007s)
2022-04-27T02:37:36.0219075Z   test_success (__main__.SpawnTest) ... skip: Test is disabled because an issue exists disabling it: https://github.com/pytorch/pytorch/issues/72298 for allplatform(s) . If you're seeing this on your local machine and would like to enable this test, please make sure IN_CI is not set and you are not using the flag --import-disabled-tests. (0.001s)
2022-04-27T02:37:36.0223811Z   test_success_first_then_exception (__main__.SpawnTest) ... skip: Test is disabled because an issue exists disabling it: https://github.com/pytorch/pytorch/issues/72625 for platform(s) linux. If you're seeing this on your local machine and would like to enable this test, please make sure IN_CI is not set and you are not using the flag --import-disabled-tests. (0.000s)
2022-04-27T02:37:36.0228556Z   test_success_non_blocking (__main__.SpawnTest) ... skip: Test is disabled because an issue exists disabling it: https://github.com/pytorch/pytorch/issues/72926 for platform(s) linux. If you're seeing this on your local machine and would like to enable this test, please make sure IN_CI is not set and you are not using the flag --import-disabled-tests. (0.000s)
2022-04-27T02:37:36.0232806Z   test_terminate_exit (__main__.SpawnTest) ... skip: Test is disabled because an issue exists disabling it: https://github.com/pytorch/pytorch/issues/72624 for platform(s) linux. If you're seeing this on your local machine and would like to enable this test, please make sure IN_CI is not set and you are not using the flag --import-disabled-tests. (0.000s)
2022-04-27T02:37:36.0237562Z   test_terminate_signal (__main__.SpawnTest) ... skip: Test is disabled because an issue exists disabling it: https://github.com/pytorch/pytorch/issues/73341 for platform(s) linux. If you're seeing this on your local machine and would like to enable this test, please make sure IN_CI is not set and you are not using the flag --import-disabled-tests. (0.000s)
2022-04-27T02:37:36.0238381Z 
2022-04-27T02:37:36.0238532Z ======================================================================
2022-04-27T02:37:36.0239041Z FAIL [0.088s]: test_exception_single (__main__.SpawnTest)
2022-04-27T02:37:36.0239572Z ----------------------------------------------------------------------
2022-04-27T02:37:36.0239959Z torch.multiprocessing.spawn.ProcessExitedException: process 1 terminated with signal SIGBUS
2022-04-27T02:37:36.0240186Z 
2022-04-27T02:37:36.0240323Z During handling of the above exception, another exception occurred:
2022-04-27T02:37:36.0240574Z 
2022-04-27T02:37:36.0240761Z Traceback (most recent call last):
2022-04-27T02:37:36.0241249Z   File "test_multiprocessing_spawn.py", line 117, in test_exception_single
2022-04-27T02:37:36.0241825Z     mp.start_processes(_test_exception_single_func, args=(i,), nprocs=nprocs, start_method=self.start_method)
2022-04-27T02:37:36.0242334Z AssertionError: "
2022-04-27T02:37:36.0242718Z ValueError: legitimate exception from process 0$" does not match "process 1 terminated with signal SIGBUS"

This comment was automatically generated by Dr. CI.

- c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::FuncTorchDynamicLayerBackMode)
// We still want to copy the Python TLS key - if the inner tensor is a Python tensor, then the wrapper
// should do a TLS snapshot.
.remove(c10::DispatchKey::PythonTLSSnapshot);
Collaborator Author

@albanD just curious - is it documented somewhere why the PythonTLSSnapshot key lives inside the range of the functorch keys? I thought it used to be higher priority than everything else.

Contributor

PythonTLSSnapshot caused problems with the AOTAutograd <-> functorch interaction. We then decided that functorch "runs first in the dispatcher", and the regular PyTorch dispatcher runs after. PythonTLSSnapshot is part of the "regular PyTorch dispatcher", so we put the key after FuncTorchDynamicLayerFrontMode so that functorch is able to interpose first.

Comment on lines +26 to +34
// What's happening here? In general, functional wrappers *should* copy the keys from their inner tensor.
// We need this for the backend use case (LTC/XLA), where any functionalities that the backend uses,
// like autograd or autocast, need to run *directly* on the wrapper.
// However, all of the keys corresponding to functorch transforms should not be copied over.
// Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
// to participate in the functorch transforms.
auto keys_to_not_copy =
c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::FuncTorchDynamicLayerFrontMode)
- c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::FuncTorchDynamicLayerBackMode)
Contributor

I think you want an explicit denylist with e.g. FuncTorchBatched, FuncTorchGradWrapper, and maybe the sparse keys(?) and the Lazy key (those are also 1:1 with their respective TensorImpl subclasses, right?)

The Autograd keys are also in between DynamicLayerFront and DynamicLayerBack and I'm not sure why you wanted to exclude them

Comment on lines +29 to +31
// However, all of the keys corresponding to functorch transforms should not be copied over.
// Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
// to participate in the functorch transforms.
Contributor

@zou3519 zou3519 Apr 25, 2022

To clarify: the problem is when there exists an invariant that if a Tensor has a Dispatch Key, then it must be a specific TensorImpl. This occurs a lot in functorch (BatchedTensorImpl <-> FuncTorchBatched key, TensorWrapper <-> FuncTorchGradWrapper key), but is not a functorch-specific problem, right?

Here are some things that I thought were the case, but I might be wrong about them since it's been a while since I looked into them:

  • SparseTensorImpl <-> {Some sparse keys}
  • SparseCsrTensorImpl <-> {the csr sparse keys}
  • LTCTensorImpl <-> Lazy

Collaborator Author

Yep that's right - although in practice it isn't a problem for the backend keys, because we're guaranteed to have unwrapped the FunctionalTensorWrapper before hitting the backend.

I used `FuncTorchDynamicLayerFrontMode - FuncTorchDynamicLayerBackMode` because I figured it would be unlikely that we'd add another non-backend, non-functorch key that this logic would apply to, and it lets us avoid hardcoding a new list of keys somewhere.

But I totally forgot that the autograd keys fall in that range. So agreed - I'll create a new keyset corresponding to "tensor impl subclass" keys and use that instead (and at that point I may as well lump in the backend keys, even if they don't directly cause problems in this case).

Contributor

Oh I see, I didn't realize FunctionalTensorWrapper doesn't have any problems with the backend keys that correspond 1:1 with tensor impl subclasses.

The new keyset including all "tensor impl subclass" keys sounds reasonable. If you wanted to continue the original train of thought, we could instead create a "FuncTorchKeySet" keyset, put all the dispatch keys that begin with "FuncTorch" into it, and use that.

bdhirsh added 2 commits April 25, 2022 14:50
// All of the keys corresponding to functorch transforms should not be copied over.
// Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect
// to participate in the functorch transforms.
key_set_ = key_set_ & c10::functorch_transforms_ks;
Contributor

Should this be `key_set_ & ~c10::functorch_transforms_ks`, or am I reading this wrong? Aren't we trying to remove the functorch keys from `key_set_`?

Collaborator Author

yes 😛 thanks. I over-confidently yolo'd CI last night without testing locally

Contributor

@zou3519 zou3519 left a comment

LGTM, minus the change I commented on. Approving for now for developer velocity

bdhirsh added 2 commits April 26, 2022 15:13
@bdhirsh
Collaborator Author

bdhirsh commented Apr 27, 2022

@pytorchbot merge this please

@github-actions
Contributor

Hey @bdhirsh.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Apr 27, 2022
Summary:
Pull Request resolved: #76318

Approved by: https://github.com/zou3519

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/aae7b00f7c7b5e28d1fe8974cd02a538f5be0913

Reviewed By: osalpekar

Differential Revision: D35971227

Pulled By: bdhirsh

fbshipit-source-id: 30cb520370dc91d3a9be4c59df7cf294b0ecc7f5
@facebook-github-bot facebook-github-bot deleted the gh/bdhirsh/217/head branch May 1, 2022 14:17