Make FakeTensors return meta device within kernel invocation, add FakeTensor op tests #78972
eellison wants to merge 16 commits into gh/eellison/306/base
Conversation
… op tests [ghstack-poisoned]
✅ No Failures (0 Pending) — as of commit 37852f4 (more details on the Dr. CI page). 💚 Looks good so far! There are no failures yet. This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.
… FakeTensor op tests"

Add FakeTensor op tests, as well as the changes necessary to make them pass.

- Previously, `FakeTensor(cpu, ...)` would always return `cpu` when `device` was called. While this is the behavior you want in userland code, within a kernel itself you want to compute as if everything were a meta tensor. For instance, TensorIterator contains a number of `is_meta()` checks that before this PR would return False and proceed as if the FakeTensors were on CPU instead of meta, which led to segfaults among other incorrect behavior. [Another example is in _linalg_check_errors](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/BatchLinearAlgebra.cpp#L1298). This PR changes `FakeTensorMode` to track whether we are in a kernel invocation, and while that is True, `device` returns `meta`.
- Extends the error checking in `VariableType_*` to ignore Tensors with torch_dispatch defined, in addition to the existing check for when a torch_mode is defined.
- A few small miscellaneous changes (adding a couple of new ops to be special-handled in FakeTensor, a couple of meta registrations).

[ghstack-poisoned]
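The device-switching behavior described above can be sketched in plain Python. This is a simplified stand-in, not the actual PyTorch implementation: the `fake_device` attribute and `kernel_invocation` context manager here are hypothetical names chosen for illustration; only the `in_kernel_invocation` flag and the "report `meta` inside the kernel" rule come from the PR description.

```python
import contextlib

class FakeTensorMode:
    """Tracks whether we are currently inside a kernel invocation."""
    def __init__(self):
        self.in_kernel_invocation = False

    @contextlib.contextmanager
    def kernel_invocation(self):
        # Set the flag for the duration of a kernel call, restoring it on exit.
        self.in_kernel_invocation = True
        try:
            yield
        finally:
            self.in_kernel_invocation = False

class FakeTensor:
    def __init__(self, mode, fake_device):
        self.mode = mode
        self.fake_device = fake_device  # the user-visible device, e.g. "cpu"

    @property
    def device(self):
        # Inside the kernel, behave like a meta tensor so checks such as
        # is_meta() take the meta path; outside, report the userland device.
        return "meta" if self.mode.in_kernel_invocation else self.fake_device

mode = FakeTensorMode()
t = FakeTensor(mode, "cpu")
print(t.device)                # cpu   (userland view)
with mode.kernel_invocation():
    print(t.device)            # meta  (inside the kernel)
```

The key design point is that the flag lives on the mode, not on each tensor, so a single toggle flips the device answer for every FakeTensor participating in the kernel call.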
This doesn't invalidate the pull request, but I want to make a higher-level point: we are going to need support for dynamic shapes soon, and that means we won't actually be using the C++ meta function implementations; for example, we are not going to make TensorIterator symbolic-aware. Instead we will be using Python-side implementations to get dynamic-shape-aware meta tensor support. As a result it's not that important to make sure fake tensor works with TensorIterator (although it certainly is nice not to segfault).
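A Python-side meta implementation of the kind described here computes only output metadata (shape, dtype, device) and never touches data, which is what makes it amenable to symbolic shapes later. A minimal sketch for the broadcast-shape piece of a binary op — `broadcast_shapes` is a hypothetical helper written for illustration, not actual PyTorch code:

```python
from itertools import zip_longest

def broadcast_shapes(a, b):
    """Compute the broadcast output shape of two shapes, NumPy/PyTorch style.

    Works purely on shape tuples, so the same logic could later accept
    symbolic integers instead of concrete ones.
    """
    out = []
    # Align shapes from the trailing dimension, padding the shorter with 1s.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        out.append(max(x, y))
    return tuple(reversed(out))

print(broadcast_shapes((8, 1, 6), (7, 1)))  # (8, 7, 6)
```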
```python
class ComplexInputException(Exception):
    pass
```
Did you run black on this file? It would be nice to have the formatting change on its own.
```python
if run_impl_check(func):
    return op_impl(self, func, *args, **kwargs)
```
```python
self.in_kernel_invocation = True
```
Naïvely I would've expected these to pair up with the `no_dispatch` call.
Do you mean put it inside the `no_dispatch`, or use `_DisableTorchDispatch` instead of `self.in_kernel_invocation`? I haven't thought too hard about it... it might be a good cleanup, although I still need to do a little more thinking/understanding of composability.
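The pairing suggested in this exchange could look like a single context manager that toggles the flag and the dispatch guard together, so the two can never get out of sync. This is a hypothetical sketch of that cleanup, not code from the PR; `no_dispatch` and `Mode` here are minimal stand-ins for the real PyTorch guards.

```python
import contextlib

@contextlib.contextmanager
def no_dispatch():
    # Stand-in for the real dispatch-disabling guard
    # (e.g. _DisableTorchDispatch); here it does nothing.
    yield

@contextlib.contextmanager
def in_kernel_invocation_manager(mode):
    # Toggle the flag and enter the dispatch guard as one unit,
    # so every kernel call site gets both or neither.
    mode.in_kernel_invocation = True
    try:
        with no_dispatch():
            yield
    finally:
        mode.in_kernel_invocation = False

class Mode:  # minimal stand-in for FakeTensorMode
    in_kernel_invocation = False

m = Mode()
with in_kernel_invocation_manager(m):
    assert m.in_kernel_invocation   # inside: flag set, dispatch disabled
assert not m.in_kernel_invocation   # outside: both restored
```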
@ezyang good point. There were a lot of issues, though, not just within TensorIterator. Other seemingly innocuous code like `self.reshape({self.size(0), 1}) * vec2` within `no_dispatch` would break. Relatedly, it would be great to not have to special-handle something so simple.
I'm not sure what the best way to get the test passing is, but our plan of record is to get all of the composites rewritten in Python.
@pytorchbot merge
@pytorchbot successfully started a merge job. Check the current status here.
Hey @eellison. |
… op tests (#78972)

Summary: Pull Request resolved: #78972
Approved by: https://github.com/ezyang
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/3c5a3ca9e89183ff3b9274fbe589fa205dc86be4
Reviewed By: osalpekar
Differential Revision: D37030335
Pulled By: eellison
fbshipit-source-id: c246709d73bca62690fa86c20d44816907da2541
Stack from ghstack (oldest at bottom):
Add FakeTensor op tests, as well as the changes necessary to make them pass.

- Previously, `FakeTensor(cpu, ...)` would always return `cpu` when `device` was called. While this is the behavior you want in userland code, within a kernel itself you want to compute as if everything were a meta tensor. For instance, TensorIterator contains a number of `is_meta()` checks that before this PR would return False and proceed as if the FakeTensors were on CPU instead of meta, which led to segfaults among other incorrect behavior. Another example is in `_linalg_check_errors`. This PR changes `FakeTensorMode` to track whether we are in a kernel invocation, and while that is True, `device` returns `meta`.
- Extends the error checking in `VariableType_*` to ignore Tensors with torch_dispatch defined, in addition to the existing check for when a torch_mode is defined.
- A few small miscellaneous changes (adding a couple of new ops to be special-handled in FakeTensor).