
Make FakeTensors return meta device within kernel invocation, add FakeTensor op tests#78972

Closed
eellison wants to merge 16 commits intogh/eellison/306/basefrom
gh/eellison/306/head

Conversation

@eellison
Contributor

@eellison eellison commented Jun 6, 2022

Stack from ghstack (oldest at bottom):

Add FakeTensor Op tests, as well as necessary changes to make the tests pass.

  • Previously, FakeTensor(cpu, ...) would always return cpu when device was queried. While this is the behavior you want in userland code, within the kernel itself you want to compute as if everything is a meta tensor. For instance, TensorIterator contains a number of is_meta() checks that before this PR would return False, so it would proceed as if the FakeTensors were on CPU instead of meta, which led to segfaults among other incorrect behavior. Another example is in _linalg_check_errors. This PR changes FakeTensorMode to track whether we are inside a kernel invocation and, while that is True, report meta.

  • Extends the error checking in VariableType_* to also ignore Tensors with torch_dispatch defined, in addition to the existing check for when a torch_mode is defined.

  • A few small miscellaneous changes (adding a couple of new ops to be special-handled in FakeTensors, and a couple of meta registrations)
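The device-reporting behavior described in the first bullet can be sketched roughly as follows. This is a hypothetical, heavily simplified toy (the names mirror the PR's `FakeTensor`/`FakeTensorMode`, but none of this is the real PyTorch implementation): userland sees the faked device, while code running inside a kernel invocation sees `meta`.

```python
# Hypothetical sketch of the "report meta inside kernel invocations" idea.
# Not real PyTorch code: FakeTensor/FakeTensorMode here are toy stand-ins.
from contextlib import contextmanager


class FakeTensorMode:
    def __init__(self):
        # True only while we are executing inside a kernel
        # (e.g. while TensorIterator is running is_meta()-style checks).
        self.in_kernel_invocation = False

    @contextmanager
    def kernel_invocation(self):
        prev = self.in_kernel_invocation
        self.in_kernel_invocation = True
        try:
            yield
        finally:
            self.in_kernel_invocation = prev


class FakeTensor:
    def __init__(self, mode, fake_device):
        self.mode = mode
        self._fake_device = fake_device  # e.g. "cpu" or "cuda:0"

    @property
    def device(self):
        # Userland sees the faked device; kernels see "meta" so that
        # meta-device code paths are taken during computation.
        if self.mode.in_kernel_invocation:
            return "meta"
        return self._fake_device


mode = FakeTensorMode()
t = FakeTensor(mode, "cpu")
print(t.device)              # cpu  (userland view)
with mode.kernel_invocation():
    print(t.device)          # meta (inside the kernel)
print(t.device)              # cpu  (restored on exit)
```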

@facebook-github-bot
Contributor

facebook-github-bot commented Jun 6, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit 37852f4 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@eellison eellison requested review from ezyang and samdow and removed request for mruberry, ngimel and soulitzer June 6, 2022 21:39
eellison added a commit that referenced this pull request Jun 6, 2022
… op tests

ghstack-source-id: d8c02c8
Pull Request resolved: #78972
@eellison eellison changed the title Make FakeTensors return meta within kerenl invocation, add FakeTensor op tests Make FakeTensors return meta device within kernel invocation, add FakeTensor op tests Jun 6, 2022
@ezyang
Contributor

ezyang commented Jun 7, 2022

This doesn't invalidate the pull request, but I want to make a higher-level point: we are going to need support for dynamic shapes soon, which means we are not actually going to be using the C++ meta function implementations; for example, we are not going to make TensorIterator symbolic-shape aware. Instead we will be using Python-side implementations to get dynamic-shape-aware meta tensor support. As a result it's not that important to make sure fake tensor works with TensorIterator (although it certainly is nice not to segfault).



class ComplexInputException(Exception):
    pass
Contributor


Did you run black on this file? It would be nice to have the formatting changes on their own.

if run_impl_check(func):
    return op_impl(self, func, *args, **kwargs)

self.in_kernel_invocation = True
Contributor


Naïvely I would've expected these to pair up with the `no_dispatch` call

Contributor Author


Do you mean put it inside the no_dispatch, or use _DisableTorchDispatch instead of self.in_kernel_invocation? I haven't thought too hard about it... it might be a good cleanup, although I still need to do a little more thinking about composability.
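One possible cleanup along the lines the reviewer raises (hypothetical, not the PR's actual code) is to pair flipping `in_kernel_invocation` with the dispatch guard in a single context manager, so the flag can never be left set if the kernel raises. `in_kernel` and the `SimpleNamespace` mode object below are illustrative stand-ins.

```python
# Hypothetical cleanup sketch: tie the flag's lifetime to a context manager
# so set/restore can never get out of sync with the dispatch guard.
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def in_kernel(mode):
    prev = mode.in_kernel_invocation
    mode.in_kernel_invocation = True
    try:
        # with no_dispatch():  # the real dispatch guard would wrap here
        yield
    finally:
        # Restored even if the kernel body raises.
        mode.in_kernel_invocation = prev


# Usage: the flag is always reset, including on error paths.
mode = SimpleNamespace(in_kernel_invocation=False)
try:
    with in_kernel(mode):
        assert mode.in_kernel_invocation
        raise RuntimeError("kernel failed")
except RuntimeError:
    pass
print(mode.in_kernel_invocation)  # False
```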

@eellison
Contributor Author

eellison commented Jun 7, 2022

@ezyang good point. There were a lot of issues not just within TensorIterator, though, such as copy_/clone, the linalg test failures, Distributions error checking, etc.

Other seemingly innocuous code like aten::outer could also lead to issues:

Running self.reshape({self.size(0), 1}) * vec2 within no_dispatch would break, because self.reshape would get confused by self having a different device (cpu) than its dispatch key.

Relatedly, it would be great not to have to handle something as simple as aten::outer in a TorchDispatchMode, but right now there is no way to decompose it (I will write up a fuller issue). If a decomposition exists for it in Python, you can just run that, but there is no equivalent way to take a composite kernel in C++ and run each operator with the mode enabled. If you don't wrap the kernel with no_dispatch, it will just recurse endlessly. I had some luck running the decompositions defined in Python for this, but coverage isn't all of the way there, and it proved unnecessary, at least for now, to get the tests passing.
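The endless-recursion problem described above can be illustrated with a toy model (hypothetical, not real PyTorch dispatch): with a mode enabled, naively "just running" an op re-dispatches, which re-enters the mode, forever; a Python-side decomposition breaks the cycle because the mode never re-dispatches the composite op itself. All names here (`dispatch`, `mode_handler`, `DECOMPS`) are invented for the sketch.

```python
# Toy model of the re-entrancy problem with composite ops under a dispatch
# mode. Hypothetical; not real PyTorch dispatcher code.
import sys

calls = {}


def dispatch(op, *args):
    # With the mode enabled, every dispatch lands in the mode handler first.
    return mode_handler(op, *args)


def mode_handler(op, *args):
    calls[op] = calls.get(op, 0) + 1
    # Naive handling: run the op "normally" -> re-dispatch -> re-enter mode.
    return dispatch(op, *args)


old_limit = sys.getrecursionlimit()
sys.setrecursionlimit(200)
try:
    dispatch("outer", "a", "b")
except RecursionError:
    print("endless recursion without no_dispatch")
finally:
    sys.setrecursionlimit(old_limit)


# With a Python decomposition registered, the mode runs the decomposed ops
# directly and never re-dispatches `outer` itself, so there is no cycle.
DECOMPS = {"outer": lambda a, b: ("mul", ("reshape", a), b)}


def mode_handler_with_decomp(op, *args):
    if op in DECOMPS:
        return DECOMPS[op](*args)
    return ("prim", op, args)


print(mode_handler_with_decomp("outer", "a", "b"))
```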

@ezyang
Contributor

ezyang commented Jun 7, 2022

I'm not sure what the best way to get the tests passing is, but our plan of record is to get all of the composites rewritten in Python.

eellison added a commit that referenced this pull request Jun 8, 2022
… op tests

ghstack-source-id: 692a600
Pull Request resolved: #78972
eellison added a commit that referenced this pull request Jun 8, 2022
… op tests

ghstack-source-id: 2a2f344
Pull Request resolved: #78972
eellison added a commit that referenced this pull request Jun 8, 2022
… op tests

ghstack-source-id: c19a309
Pull Request resolved: #78972
eellison added a commit that referenced this pull request Jun 8, 2022
… op tests

ghstack-source-id: bc8977f
Pull Request resolved: #78972
eellison added a commit that referenced this pull request Jun 8, 2022
… op tests

ghstack-source-id: 623a132
Pull Request resolved: #78972
@eellison
Contributor Author

eellison commented Jun 9, 2022

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@github-actions
Contributor

github-actions bot commented Jun 9, 2022

Hey @eellison.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Jun 10, 2022
… op tests (#78972)

Summary:
Pull Request resolved: #78972

Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/3c5a3ca9e89183ff3b9274fbe589fa205dc86be4

Reviewed By: osalpekar

Differential Revision: D37030335

Pulled By: eellison

fbshipit-source-id: c246709d73bca62690fa86c20d44816907da2541
@facebook-github-bot facebook-github-bot deleted the gh/eellison/306/head branch June 12, 2022 14:20