Refactor TensorIterator to do allocations via MetaBase::set_output #48659

ezyang wants to merge 4 commits into gh/ezyang/875/base
Conversation
Detailed RFC at https://github.com/pytorch/rfcs/blob/rfc-0005/RFC-0005-structured-kernel-definitions.md#handling-tensoriterator

What this diff does:

* Refactor allocation of outputs in TensorIterator into a call to a single function `set_output`. This nicely centralizes restriding logic and eliminates the need for a separate named tensor propagation pass.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
💊 CI failures summary and remediations. As of commit 6969196 (more details on the Dr. CI page): ✅ None of the CI failures appear to be your fault 💚
Detailed RFC at https://github.com/pytorch/rfcs/blob/rfc-0005/RFC-0005-structured-kernel-definitions.md#handling-tensoriterator

What this diff does:

* Refactor allocation of outputs in TensorIterator into a call to a single function `TensorIterator::set_output`. This nicely centralizes restriding logic and mostly eliminates the need for a separate named tensor propagation pass. The one exception is inplace operations (`add_`), where previously we never actually call `set_output` when we determine resizing is not necessary; there's an extra name propagation in `allocate_or_resize_outputs` to handle this case (I audited all other `set_output` sites and found that we always hit this path in that situation). Although hypothetically this could cause problems for structured kernels (which require a `set_output` call in all cases), this codepath is irrelevant for structured kernels, as a TensorIterator will never be constructed with an explicit out argument (remember, structured kernels handle out/functional/inplace variants). There's also a tricky case in `compute_types`; check the comments there for more details.
* Split TensorIterator into a TensorIteratorBase, which contains most of the logic but doesn't define `set_output`. A decent chunk of the diff is just the mechanical rename of TensorIterator to TensorIteratorBase. However, there are a few cases where we create fresh TensorIterator objects from another TensorIterator. In those cases, we always construct a fresh TensorIterator (rather than preserving the subclass of TensorIteratorBase that induced this construction). This makes sense, because a structured function class will contain metadata that isn't relevant for these downstream uses. This is done by *intentionally* permitting object slicing with the `TensorIterator(const TensorIteratorBase&)` constructor.
* Introduce a new `MetaBase` class which contains the canonical virtual method definition for `set_output`. This will allow structured classes to make use of it directly without going through TensorIterator (not in this PR).

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
bhosmer left a comment:
Looks good!
Only real question is about some minor redundant logic that I think fell out of the refactoring straddle between `allocate_or_resize_outputs` and `fast_set_up` call sites. But I'd buy that this is the right sweet spot, at least for now.
```cpp
if (common_device == kCPU) {
  // Casts to outputs by creating temporaries of the correct dtype (if needed)
  if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_) {
    TORCH_INTERNAL_ASSERT(op.tensor.defined());
```
I'm seeing the `continue` on line 372... not seeing how we get here with `op.tensor` undefined, but maybe I'm missing it.
You can't get here with `op.tensor()` being undefined, that's why it's an assert :)

Actually, I missed the `continue` on line 372, but I was fairly confident (for other reasons) that it was impossible, but not 100% so.
```cpp
auto& op = operands_[output_idx];
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
if (!op.tensor.defined()) {
  TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", output_idx);
```
On a quick read, it looks like a less aggressive factoring would avoid some redundant tests when this is called from `allocate_or_resize_outputs`, at the cost of leaving more code in `fast_set_up` - but I'm assuming you figure the slow path can afford it (or that there's some API requirement that it be done in `set_output` that I'm missing).
This is a good question. When I wrote this to start, I tried to assume as few preconditions on the call-site as possible, ergo the redundant tests. If I understand you correctly, it's just this assert that could be removed (I don't think any of the other asserts are removable). Seems reasonable.
```diff
   // of a lot of boilerplate above
   TensorIterator build() {
-    return TensorIterator(*this);
+    TensorIterator iter;
```
did you kill the TensorIterator(TensorIteratorConfig&) constructor just to clean things up?
You are right, I don't think I was supposed to delete this constructor, and I ended up just refactoring around the problem. Not sure if it's worth restoring it though!
Yeah, I think this is better - more transparent. The other just hid the build call.
@bhosmer I merged this because I got a clean merge window, but I still owe you updates for your comments. Coming soon.
Summary: Pull Request resolved: pytorch#48659
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: bhosmer
Differential Revision: D25261844
Pulled By: ezyang
fbshipit-source-id: 34a9830cccbc07eaaf7c4f75114cd00953e3db7d
Summary: Pull Request resolved: #48731 Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D25278034 Pulled By: ezyang fbshipit-source-id: 73652311b48d8d80c06e9385b7ff18ef3a158ae8
Stack from ghstack:
Detailed RFC at
https://github.com/pytorch/rfcs/blob/rfc-0005/RFC-0005-structured-kernel-definitions.md#handling-tensoriterator
What this diff does:
* Refactor allocation of outputs in TensorIterator into a call to a single function `TensorIterator::set_output`. This nicely centralizes restriding logic and mostly eliminates the need for a separate named tensor propagation pass. The one exception is inplace operations (`add_`), where previously we never actually call `set_output` when we determine resizing is not necessary; there's an extra name propagation in `allocate_or_resize_outputs` to handle this case (I audited all other `set_output` sites and found that we always hit this path in that situation). Although hypothetically this could cause problems for structured kernels (which require a `set_output` call in all cases), this codepath is irrelevant for structured kernels, as a TensorIterator will never be constructed with an explicit out argument (remember, structured kernels handle out/functional/inplace variants). There's also a tricky case in `compute_types`; check the comments there for more details.
* Split TensorIterator into a TensorIteratorBase, which contains most of the logic but doesn't define `set_output`. A decent chunk of the diff is just the mechanical rename of TensorIterator to TensorIteratorBase. However, there are a few cases where we create fresh TensorIterator objects from another TensorIterator. In those cases, we always construct a fresh TensorIterator (rather than preserving the subclass of TensorIteratorBase that induced this construction). This makes sense, because a structured function class will contain metadata that isn't relevant for these downstream uses. This is done by *intentionally* permitting object slicing with the `TensorIterator(const TensorIteratorBase&)` constructor.
* Introduce a new `MetaBase` class which contains the canonical virtual method definition for `set_output`. This will allow structured classes to make use of it directly without going through TensorIterator (not in this PR).

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D25261844