generate out= and functional variants of NativeFunctions, get functionalization to work for all mutable ops #76320
bdhirsh wants to merge 17 commits into gh/bdhirsh/219/base
Conversation
[ghstack-poisoned]
❌ 1 new failure as of commit 8e2b6db (more details on the Dr. CI page)
🕵️ 1 new failure recognized by patterns. The following CI failures do not appear to be due to upstream breakages.
"mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
"div.Scalar(Tensor self, Scalar other) -> Tensor",
"div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
"_fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, Tensor output, Tensor mask)"  # only used by the functionalization pass  # noqa: B950
We don't want to expose this to python since it's really just used for functionalization.
It also happens to be impossible to expose to python, though: the mutable vs. functional versions of this op look identical w.r.t. python, since they have the same base name, args, and kwargs (the mutability info isn't part of the python schema information).
fake_quant_on,
// Careful - the functional and non-functional version of this operator have the same C++ API name.
// The only difference in their C++ schemas is that the non-functional version takes in non-const tensors.
// (If we call into the functional version, we'll infinite loop.)
This feels kinda unfortunate. I figure it's alright though because:
- This comes up only in very uncommon cases specific to functionalization
- We're already tied into using `const Tensor` vs. `Tensor` as a way to distinguish function schemas in other parts of the codebase (I remember this coming up in the faithful vs. non-faithful C++ API overloads, where the out= argument can come either first or last).
IMO, this seems bad, especially because we want to STOP using mutable arguments to distinguish between mutable and non-mutable inputs. I think we should come up with some naming convention (similar to how we do `_out` for out overloads) to make sure these are distinguished in C++.
That seems reasonable (although I feel a little bad adding some special handling just for functionalization in the code that spits out C++ API names).
What do you think of foo.functional -> we check if the overload name is functional and make the name foo_functional? (Basically specializing on both out and functional now, instead of just out)
What do you think of foo.functional -> we check if the overload name is functional and make the name foo_functional? (Basically specializing on both out and functional now, instead of just out)
Yes, this seems reasonable.
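The naming rule agreed on here is simple enough to sketch. This is a hypothetical helper, not the real torchgen code, shown only to pin down the convention: specialize the generated C++ API name on both the `out` and `functional` overload names.

```python
# Hypothetical sketch of the naming rule discussed above - not the actual
# torchgen implementation. Unique-ify generated C++ API names by
# specializing on both the "out" and "functional" overload names.
def cpp_api_name(base_name: str, overload_name: str) -> str:
    if overload_name == "out":
        return f"{base_name}_out"
    if overload_name == "functional":
        return f"{base_name}_functional"
    return base_name
```

So `_fused_moving_avg_obs_fq_helper.functional` would get the C++ name `_fused_moving_avg_obs_fq_helper_functional`, avoiding the const-ness-only overload collision discussed above.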
# Then if a corresponding out-of-place version exists, we expect it to have the following schema:
# foo.functional(Tensor input) -> (Tensor, Tensor)
# The first tensor(s) in the output should all correspond to the newly updated inputs.
# The last tensor(s) in the output should correspond to the original outputs of the function.
So... I don't know of any easy way to assert the "first / last" part of this comment. The assert below only checks that the # of returns is correct.
I think I'm ok with this though mostly because you would only add one of these functional ops in the first place because you wanted it to work with functionalization, and if you returned arguments in the wrong order then functionalization would immediately give you garbage results.
I guess if we added first class support for OpInfo <> functionalization testing, and just ran it automatically on all OpInfos, then we could catch it.
The fix is easy: put names in the returns.
Well technically, that still doesn't prevent the person implementing the kernel from returning a tuple of tensors with some of the tensors swizzled around (not obeying whatever ordering they implied in the schema based on the names).
The names we give to the returned tensors in the schema can't be the same as the input tensors - since they're fresh, non-aliased tensors (although we could maybe enforce some convention like `foo.functional(Tensor input1, Tensor input2) -> (Tensor input1_out, Tensor input2_out, Tensor)`).
Well technically, that still doesn't prevent the person implementing the kernel from returning a tuple of tensors with some of the tensors swizzled around (not obeying whatever ordering they implied in the schema based on the names).
Well, sure, but there's no way to solve this. I'll settle for making sure the functional and regular variants agree haha.
The names we give to the returned tensors in the schema can't be the same as the input tensors - since they're fresh, non-aliased tensors (although we could maybe enforce some convention like `foo.functional(Tensor input1, Tensor input2) -> (Tensor input1_out, Tensor input2_out, Tensor)`).
A naming convention sounds like the right thing here.
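As noted in the thread, the codegen can only assert the return *count*, not the ordering. A minimal sketch of that invariant (a hypothetical helper, not the actual assert in the codegen):

```python
def check_functional_returns(num_mutated_inputs: int,
                             num_original_returns: int,
                             num_functional_returns: int) -> None:
    # The generated foo.functional must return one fresh tensor per
    # mutated input, plus the op's original returns. The ordering
    # convention (updated inputs first, original outputs last) can't be
    # checked here - the kernel author has to uphold it.
    expected = num_mutated_inputs + num_original_returns
    if num_functional_returns != expected:
        raise ValueError(
            f"expected {expected} returns, got {num_functional_returns}")
```

For `_fused_moving_avg_obs_fq_helper` that's 4 mutated inputs plus 2 original outputs, so 6 returns.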
…l-only args"

Our JIT data model currently allows for a class of schemas that:
(1) mutate some of their inputs (based on the aliasing info)
(2) potentially return *new* outputs (unrelated to the mutated inputs)
(3) the mutated inputs are not `self` or `out` kwargs, so the op is neither inplace nor out=

This PR adds support to be able to functionalize that class of ops, and also adds support for `_fused_moving_avg_obs_fq_helper` to ensure that it all works. (This op is needed for torchdynamo, as it's used in some resnet models on torchbench. See pytorch/torchdynamo#88 (comment))

The majority of the work in this PR consisted of:
(1) Getting the functionalization codegen to detect "schemas that have any mutable args", instead of special-casing directly on `SchemaKind.inplace/out`
(2) Ensuring that we properly group mutable ops with their corresponding functional variants (and like the above, you can't rely on `SchemaKind` anymore because the mutable op is neither inplace nor out=)
(3) Removing some assumptions that the codegen made about mutable ops. For example, I used to assume it was always ok to return the `self` or `out=` args - but you can't always do that. Mutable ops are allowed to mutate their inputs by side effect, and return totally different output tensors (that then need to be wrapped by functionalization).

Here's what the codegen'd kernel for `_fused_moving_avg_obs_fq_helper` looks like:
```
::std::tuple<at::Tensor,at::Tensor> _fused_moving_avg_obs_fq_helper(
    c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on,
    const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max,
    at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min,
    int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant) {
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    at::functionalization::impl::sync(self);
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }
  at::Tensor observer_on_;
  if (at::functionalization::impl::isFunctionalTensor(observer_on)) {
    at::functionalization::impl::sync(observer_on);
    observer_on_ = at::functionalization::impl::from_functional_tensor(observer_on);
  } else {
    observer_on_ = observer_on;
  }
  at::Tensor fake_quant_on_;
  if (at::functionalization::impl::isFunctionalTensor(fake_quant_on)) {
    at::functionalization::impl::sync(fake_quant_on);
    fake_quant_on_ = at::functionalization::impl::from_functional_tensor(fake_quant_on);
  } else {
    fake_quant_on_ = fake_quant_on;
  }
  at::Tensor running_min_;
  if (at::functionalization::impl::isFunctionalTensor(running_min)) {
    at::functionalization::impl::sync(running_min);
    running_min_ = at::functionalization::impl::from_functional_tensor(running_min);
  } else {
    running_min_ = running_min;
  }
  at::Tensor running_max_;
  if (at::functionalization::impl::isFunctionalTensor(running_max)) {
    at::functionalization::impl::sync(running_max);
    running_max_ = at::functionalization::impl::from_functional_tensor(running_max);
  } else {
    running_max_ = running_max;
  }
  at::Tensor scale_;
  if (at::functionalization::impl::isFunctionalTensor(scale)) {
    at::functionalization::impl::sync(scale);
    scale_ = at::functionalization::impl::from_functional_tensor(scale);
  } else {
    scale_ = scale;
  }
  at::Tensor zero_point_;
  if (at::functionalization::impl::isFunctionalTensor(zero_point)) {
    at::functionalization::impl::sync(zero_point);
    zero_point_ = at::functionalization::impl::from_functional_tensor(zero_point);
  } else {
    zero_point_ = zero_point;
  }
  if (!(true && at::functionalization::impl::isFunctionalTensor(running_min) && at::functionalization::impl::isFunctionalTensor(running_max) && at::functionalization::impl::isFunctionalTensor(scale) && at::functionalization::impl::isFunctionalTensor(zero_point))) {
    if ((false || at::functionalization::impl::isFunctionalTensor(self) || at::functionalization::impl::isFunctionalTensor(observer_on) || at::functionalization::impl::isFunctionalTensor(fake_quant_on))) {
      // case 1: trying to mutate a non functional tensor with a functional tensor is an error
      TORCH_INTERNAL_ASSERT(false,
        "mutating a non-functional tensor with a functional tensor is not allowed.",
        " Please ensure that all of your inputs are wrapped inside of a functionalize() call.");
    } else {
      // case 2: arguments are not functional tensors, so we no-op and redispatch.
      at::AutoDispatchSkipFunctionalize guard;
      ::std::tuple<at::Tensor,at::Tensor> tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant);
      auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<0>(tmp_output));
      auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<1>(tmp_output));
      return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);
    }
  } else {
    ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> tmp_output;
    {
      at::AutoDispatchSkipFunctionalize guard;
      tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper_functional::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant);
    }
    at::functionalization::impl::replace_(running_min, std::get<0>(tmp_output));
    at::functionalization::impl::commit_update(running_min);
    at::functionalization::impl::replace_(running_max, std::get<1>(tmp_output));
    at::functionalization::impl::commit_update(running_max);
    at::functionalization::impl::replace_(scale, std::get<2>(tmp_output));
    at::functionalization::impl::commit_update(scale);
    at::functionalization::impl::replace_(zero_point, std::get<3>(tmp_output));
    at::functionalization::impl::commit_update(zero_point);
    auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<4>(tmp_output));
    auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<5>(tmp_output));
    return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);
  }
}
```
[ghstack-poisoned]
inner_ret: str = f'std::get<{i}>({inner_out_name})' if return_is_tuple else inner_out_name
updates.append(f"""\
auto output_{i} = at::functionalization::impl::to_functional_tensor({inner_ret});""")
return_names.append(f'output_{i}')
Based on inspecting the generated code, a std::move should be possible here
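The suggested tweak, sketched against the snippet above (variable names taken from the snippet; this is a guess at what the emitted line would become, not the actual change):

```python
def emit_wrap_output(i: int, inner_out_name: str, return_is_tuple: bool) -> str:
    # tmp_output is a local that is never read again after this point,
    # so moving out of the tuple element should be safe.
    inner_ret = f"std::get<{i}>({inner_out_name})" if return_is_tuple else inner_out_name
    return (f"auto output_{i} = at::functionalization::impl::"
            f"to_functional_tensor(std::move({inner_ret}));")
```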
Generated code is quite a handful, but looks good
test/test_functionalization.py
Outdated
logs = self.get_logs(f, torch.ones(1))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1, $2, $3, $4, $5, $6 = torch._ops.aten._fused_moving_avg_obs_fq_helper.functional($0, $0, $0, $0, $0, $0, $0, 1.0, 1, 1, 0)""")
cc @anijain2305, this is what the graph will look like when you functionalize _fused_moving_avg_obs_fq_helper (all of the mutable args in the original op become outputs)
"""
is_inplace = self.name.name.inplace
is_out = bool(self.arguments.out)
is_inplace = self.name.name.inplace
does this... actually do anything?
looks like it's just used below:
if is_inplace:
    return SchemaKind.inplace
but agreed it feels unnecessary
*,
strip_default: bool = False,
convert_mutable_inputs_to_returns: bool = False,
strip_view_copy_name: bool = False,
this is my "too many kwargs" grumpy face
torchgen/model.py
Outdated
base_name = base_name.replace("_copy", "")

if convert_mutable_inputs_to_returns:
    # find mutable inputs that are not originally returned, and convert them to returns
torchgen/model.py
Outdated
if convert_mutable_inputs_to_returns:
    # find mutable inputs that are not originally returned, and convert them to returns
    returns_from_mutable_inputs = tuple(
        Return(name=None, type=a.type, annotation=None)
Sure you want no name here? It sure seems like it would be useful for the test case you gave above
oh yeah - although I think I actually want no names here, that way I can properly pair up the functional + mutable ops and emit a proper error if someone didn't name the outputs properly. (If I add names here, then we would just silently not pair up the two ops instead of emitting an error)
torchgen/model.py
Outdated
def view_signature(self) -> "FunctionSchema":
    return self.signature(strip_view_copy_name=True)

def self_to_out_signature(self) -> "FunctionSchema":
These methods are very long but they are only called from one place, so it would make me happier if they were more private rather than being directly on the FunctionSchema object. If you take my advice and move generate_function out of model.py these will naturally go too.
pre_tensor_options_kwarg_only=self.pre_tensor_options_kwarg_only,
tensor_options=self.tensor_options,
post_tensor_options_kwarg_only=self.post_tensor_options_kwarg_only,
out=tuple(outs),
ah, the pain of purely functional programming lol
I'd be OK with using dataclasses.replace to make this less painful (you'll need to shut up mypy about it, it won't be able to figure out the typing)
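The `dataclasses.replace` suggestion, sketched with a stand-in dataclass (the real `Arguments` has many more fields, and as noted a `# type: ignore` may be needed to quiet mypy in the actual codegen):

```python
import dataclasses
from typing import Tuple

@dataclasses.dataclass(frozen=True)
class Arguments:  # hypothetical stand-in for torchgen's Arguments
    pre_self_positional: Tuple[str, ...]
    tensor_options: Tuple[str, ...]
    out: Tuple[str, ...]

args = Arguments(pre_self_positional=("self",), tensor_options=(), out=())
# Only the changed field is spelled out; everything else is copied over.
with_out = dataclasses.replace(args, out=("out",))
```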
torchgen/model.py
Outdated
returns=self.returns,
)

def mutable_to_out_signature(self) -> "FunctionSchema":
An example before/after in the code here would also be helpful.
torchgen/gen.py
Outdated
ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])


def pre_group_native_functions(
This was pre-existing and I moved it further up in the file, but you're right it deserves a comment either way
torchgen/gen.py
Outdated
# (1) Allows us to more consistently group together variants into NativeFunctionsGroup objects.
# (2) Gives the functionalization pass functional variants to work with, so we don't need to
#     manually implement functional variants of all operators to get support for all mutable operators.
def add_generated_native_functions(
when combined with the model.py helpers, consider giving this its own module
torchgen/gen.py
Outdated
#     manually implement functional variants of all operators to get support for all mutable operators.
def add_generated_native_functions(
    rs: List[NativeFunction],
    indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
Warning: this input gets mutated! It might be less misleading to mutate both inputs, or none of them.
good point - the mutating feels ok to do since it should basically be the first thing that happens after yaml parsing (so anyone using codegen sees the final result), but I'll also mutate rs for consistency (and return void)
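The convention settled on here (mutate both inputs, return nothing) might look like this. This is a hypothetical, heavily simplified sketch with strings standing in for the real `NativeFunction` and `BackendIndex` objects:

```python
from typing import Dict, List

def add_generated_native_functions(
    rs: List[str],
    indices: Dict[str, Dict[str, str]],
) -> None:
    # Both arguments are mutated in place; returning None makes that
    # explicit at the call site instead of leaving one input looking
    # read-only.
    for op in [op for op in rs if op.endswith("_")]:
        generated = f"{op}.functional"
        rs.append(generated)
        indices.setdefault("CompositeExplicitAutograd", {})[generated] = "generated"
```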
ezyang
left a comment
I only skimmed torchgen/gen_functionalization_type.py but everything looks good
albanD
left a comment
(only looked at the autograd codegen)
torchgen/api/autograd.py
Outdated
    return functional_info_by_signature[f_sig], False

# (3) Some operators have a derivative explicitly defined for the mutable
# variant, but get a code-generated out-of-place variant.
I'll make this comment a bit clearer. I mean that:
- there's a mutable variant of the operator in native_functions.yaml, that has a derivative entry
- there's a code-generated functional variant
- that code-generated functional variant does not have a derivative formula, so we want to re-use the existing mutable variant's formula
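The lookup described in the bullets above can be sketched as follows (hypothetical dict-based stand-ins for the real signature-indexed tables; the returned bool is the exact-match flag):

```python
from typing import Dict, Optional, Set, Tuple

def find_derivative_info(
    f_sig: str,
    tags: Set[str],
    functional_info_by_signature: Dict[str, str],
    non_functional_info_by_signature: Dict[str, str],
) -> Tuple[Optional[str], bool]:
    # a variant with a derivative formula registered under this signature
    if f_sig in functional_info_by_signature:
        return functional_info_by_signature[f_sig], False
    # a code-generated functional variant: fall back to the formula
    # registered for the existing mutable variant
    if "generated" in tags and f_sig in non_functional_info_by_signature:
        return non_functional_info_by_signature[f_sig], False
    return None, False
```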
torchgen/api/autograd.py
Outdated
# variant, but get a code-generated out-of-place variant.
# Use that if available
if "generated" in f.tags and f_sig in non_functional_info_by_signature:
    return non_functional_info_by_signature[f_sig], False
We might want to be a bit careful here.
In particular, the later code will do some adjustment to the formula if it is not an exact match.
This code is assuming that we go functional formula -> inplace formula right now.
So I think there are two things here:
- We want the second return to say if the not-exact match is one way or the other
- When going from an inplace -> out-of-place formula, we should make sure that `self` is not used (ignore out variants here). Otherwise, we need to replace `self` by `result` in the formula.
thanks!
I checked, and at least today there are no derivative formulas that rely on the new logic and make use of `self`. I'm going to make this an error for now.
We want the second return to say if the not-exact match is one way or the other
It looks like today, the logic for is_exact_match=False only actually runs when the function's schema kind is SchemaKind.inplace.
Why do we want the second return (is_exact_match) to be able to specify if we're going functional->mutable or mutable->functional? Is there an actual change that we need to make to the derivative formula to get it to work for the functional variant? Or do you just think it's clearer at call-site to be able to distinguish those cases, even if we don't need to today (I'd buy that argument too).
def should_generate_py_binding(f: NativeFunction) -> bool:
    # So far, all NativeFunctions that are entirely code-generated do not get python bindings.
    if 'generated' in f.tags:
        return False
The main use case I can see right now is to build functional/output layout modes. And both would work at the torch.ops level, which doesn't require these bindings.
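The filter above can be illustrated with a small stand-in. `FakeNativeFunction` below is a hypothetical stand-in for torchgen's `NativeFunction`, and `zero.functional` is used as an example of a code-generated op:

```python
from dataclasses import dataclass, field

@dataclass
class FakeNativeFunction:
    # hypothetical stand-in for torchgen's NativeFunction
    name: str
    tags: set = field(default_factory=set)

def should_generate_py_binding(f: FakeNativeFunction) -> bool:
    # Code-generated ops stay reachable via torch.ops, but get no torch.* binding.
    return "generated" not in f.tags

ops = [FakeNativeFunction("add"), FakeNativeFunction("zero.functional", {"generated"})]
bound = [f.name for f in ops if should_generate_py_binding(f)]
print(bound)  # ['add']
```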
…get functionalization to work for all mutable ops"

This PR is pretty large, but it's motivated by the following idea:
- **every** mutable operator in aten should be functionalize-able
- **every** mutable operator should have a functional + out= variant, so our codegen can operate on it in a more structured way (and full out= coverage is probably useful for mobile, for memory planning)

### The main changes
- Introduce a new `SchemaKind.mutable` enum in the codegen
- Update the codegen grouping logic to properly group all functional/inplace/out=/mutable variants today (and add a bunch of error checks and restrictions to tighten up the set of schemas that we allow into native_functions.yaml)
- Automatically generate some new `NativeFunctions` in the codegen (!!). Under certain conditions, we generate `functional` and `out=` variants of some existing aten operators
- Code-generate `mutable` -> `functional` kernels for any of the newly generated `functional` NativeFunction objects
- Clean up functionalization codegen, now that it can rely on the existing grouping logic
- Clean up LTC to only write lowerings for functional ops (we can do this now that every mutable operator has a functional equivalent)

Generating all of these new `NativeFunction`s is a pretty big change - up until now, every operator in aten was explicitly spelled out in `native_functions.yaml`. This seems more ok to do now, because:
- we now have a `torchgen` package that you can install, and use to dynamically inspect all of the aten ops used in code generation
- there are just so many functional / out= ops missing that adding them all manually would be a massive amount of boilerplate

A lot of the work in this PR involved figuring out why certain operators were/were not getting grouped properly, and classifying edge-case op schemas that we should fix, vs. acceptable operators that the grouping logic should account for.
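The grouping idea above can be sketched with a toy that groups variants by a normalized key. The real torchgen logic compares full `FunctionSchema.signature()` objects (and inspects alias annotations rather than names); this name-based version is only an illustration:

```python
from collections import defaultdict

def schema_kind(name: str) -> str:
    # Toy classification by naming convention (real torchgen inspects alias
    # annotations, not names): foo_ is inplace, *.out / *_out overloads are out=.
    base, _, overload = name.partition(".")
    if base.endswith("_"):
        return "inplace"
    if overload == "out" or overload.endswith("_out"):
        return "out"
    return "functional"

def group_key(name: str) -> str:
    # Normalize a name so all variants of one op share a key, mimicking how
    # FunctionSchema.signature() lets functional/inplace/out= variants group up.
    base, _, overload = name.partition(".")
    base = base.rstrip("_")
    if overload == "out":
        overload = ""
    elif overload.endswith("_out"):
        overload = overload[: -len("_out")]
    return f"{base}.{overload}" if overload else base

groups = defaultdict(dict)
for name in ["add.Scalar", "add_.Scalar", "add.Scalar_out", "zero_", "zero.out"]:
    groups[group_key(name)][schema_kind(name)] = name

print(dict(groups))
```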
I listed out the full set of new `NativeFunctions` at the bottom of this description. It also shouldn't be too hard to add generated `foo.scratch` variants of out= operators on top of this, if we decide that's useful.

### Enumeration of changes / what order to look at things

(1) I would recommend starting with the updated versions of `FunctionSchema.signature()` and `Arguments.signature()` in `model.py`. This is the main, core change to our operator grouping logic, which lets us always group `functional/inplace/mutable/out=` ops together; a lot of the other changes follow from it. In it, we:
- convert **mutable** (`post_self_positional`) args to returns (which come **after** any of the original returns)
- drop `TensorOptions` args (this lets us properly group the existing out= factory ops)

In `FunctionSchema.__post_init__()`, I added a bunch of new restrictions on what kind of aliasing guarantees we can assume about newly added schemas. This made it much easier for me to reason about the grouping logic, and I'm hoping they aren't too restrictive (since none of the restrictions broke any existing NativeFunctions).

(2) Next, the code for generating `functional` + `out=` NativeFunctions.

In `gen.py`, `add_generated_native_functions()` has the logic for deciding when to generate new `NativeFunction` objects. For now, we only generate anything for mutable, non-composite ops that are missing a functional/out= variant. We could probably generate stuff for composite ops, but that isn't really necessary for backends/tracers, since we can rely on the decompositions. There are also a handful of `functional` ops that don't have `out=` variants; I didn't add them in this PR because they're not important to functionalization, but they would be pretty easy to add.

Note: there were a total of 5 operators today that are mutable and don't "work" with the new grouping logic. In all cases, it's because there are some issues with their schemas that would be BC-breaking to fix (all called out in the code comments). I added them to an allow-list, and long term I think we can either fix their schemas, or manually write functionalization kernels for them.

The code that actually generates new NativeFunctions is `generate_function` in `model.py`. Given a "base function" of one `SchemaKind` and a target `SchemaKind`, it generates a new `NativeFunction` with the target schema. For now, we only actually use it with functional / out= as the target schema.

(3) Generating functional kernels in terms of their existing mutable variants. This happens in `gen_composite_functional_kernel` in `gen_functionalization.py`. I had to modify `translate()` to be able to remove const-ness when calling a mutable op from a functional op.

(4) Updating the functionalization codegen in a few ways:
- We now have full support for all mutable -> functional op transformations, including weird `SchemaKind.mutable` ops like `_fused_moving_avg_obs_fq_helper` -> `_fused_moving_avg_obs_fq_helper.functional`, and out= factory ops like `range.start_out` -> `range.start_step`. For `SchemaKind.mutable` ops, the codegen needs to know that mutable positional args are converted into returns in the functional schema. For out= factory ops, I had to update `translate()` to know that it could grab TensorOptions arguments from the `out` tensor in the calling context.
- I removed the side-car mapping of mutable -> functional ops, so we now rely fully on the normal `NativeFunctionsGroup` groupings.
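The mutable -> functional kernel pattern described in (3) can be sketched in plain Python: clone the would-be-mutated argument, run the existing mutable kernel on the clone, and return the clone. This is a conceptual sketch of the generated C++ kernels, not the actual codegen output; `zero_` below is a toy stand-in for an inplace aten kernel:

```python
import copy

def functional_from_mutable(mutable_kernel):
    # Sketch of what a generated foo.functional kernel does: clone the
    # mutated input, run the existing mutable kernel on the clone, and
    # return the clone, leaving the caller's value untouched.
    def functional_kernel(self, *args, **kwargs):
        self_clone = copy.deepcopy(self)
        mutable_kernel(self_clone, *args, **kwargs)
        return self_clone
    return functional_kernel

def zero_(buf):  # toy stand-in for an inplace kernel
    for i in range(len(buf)):
        buf[i] = 0

zero_functional = functional_from_mutable(zero_)
src = [1, 2, 3]
out = zero_functional(src)
print(src, out)  # [1, 2, 3] [0, 0, 0]
```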
I still ended up passing ungrouped `NativeFunctions` into the functionalization codegen, for 2 reasons:
(a) We need to register `CompositeImplicitAutograd` kernels directly to functionalization, even if they were ungrouped (we could in theory unwind this if/when we eventually get a dispatch key dedicated to decompositions)
(b) I defensively error if functionalization ever encounters a non-grouped, mutable operator. I could also probably just move that error check outside of the functionalization codegen though.

(5) Updating the LazyTensor codegen.

LTC has some special logic to handle mutable ops that it lowers directly. I ended up breaking it as part of this change. Instead of debugging what broke, I figured it would be better long-term to just get LTC to only lower functional operators, and remove a bunch of the special handling for mutable operators. I'll probably need to run these changes by the LTC team.

### Full list of newly generated `NativeFunction` objects

new functional ops count: 74
new out= ops count: 97
total new ops: 171

```
_add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
bernoulli.Tensor_out(Tensor self, Tensor p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
bernoulli.Tensor_functional(Tensor self, Tensor p, *, Generator? generator=None) -> Tensor
bernoulli.float_out(Tensor self, float p=0.5, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
bernoulli.float_functional(Tensor self, float p=0.5, *, Generator? generator=None) -> Tensor
copy.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)
div.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
div.Scalar_mode_out(Tensor self, Scalar other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)
embedding_renorm.out(Tensor self, Tensor indices, float max_norm, float norm_type, *, Tensor(a!) out) -> Tensor(a!)
embedding_renorm.functional(Tensor self, Tensor indices, float max_norm, float norm_type) -> Tensor
resize.out(Tensor self, int[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
resize.functional(Tensor self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor
fill.Scalar_out(Tensor self, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
fill.Tensor_out(Tensor self, Tensor value, *, Tensor(a!) out) -> Tensor(a!)
index_put.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!)
_index_put_impl.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False, *, Tensor(a!) out) -> Tensor(a!)
_index_put_impl.functional(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor
mul.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
relu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
celu.out(Tensor self, Scalar alpha=1.0, *, Tensor(a!) out) -> Tensor(a!)
_mkldnn_transpose.out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!)
resize_as.out(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
resize_as.functional(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor
resize_as_sparse.out(Tensor self, Tensor the_template, *, Tensor(a!) out) -> Tensor(a!)
resize_as_sparse.functional(Tensor self, Tensor the_template) -> Tensor
zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
zero.functional(Tensor self) -> Tensor
sub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
sparse_resize.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!)
sparse_resize.functional(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor
sparse_resize_and_clear.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!)
sparse_resize_and_clear.functional(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor
_coalesced.out(Tensor self, bool coalesced, *, Tensor(a!) out) -> Tensor(a!)
_coalesced.functional(Tensor self, bool coalesced) -> Tensor
copy_sparse_to_sparse.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)
copy_sparse_to_sparse.functional(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
_fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!))
_fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out)
set.source_Storage_out(Tensor self, Storage source, *, Tensor(a!) out) -> Tensor(a!)
set.source_Storage_functional(Tensor self, Storage source) -> Tensor
set.source_Storage_storage_offset_out(Tensor self, Storage source, int storage_offset, int[] size, int[] stride=[], *, Tensor(a!) out) -> Tensor(a!)
set.source_Storage_storage_offset_functional(Tensor self, Storage source, int storage_offset, int[] size, int[] stride=[]) -> Tensor
set.source_Tensor_out(Tensor self, Tensor source, *, Tensor(a!) out) -> Tensor(a!)
set.source_Tensor_functional(Tensor self, Tensor source) -> Tensor
set.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
set.functional(Tensor self) -> Tensor
masked_fill.Scalar_out(Tensor self, Tensor mask, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
masked_fill.Tensor_out(Tensor self, Tensor mask, Tensor value, *, Tensor(a!) out) -> Tensor(a!)
masked_scatter.out(Tensor self, Tensor mask, Tensor source, *, Tensor(a!) out) -> Tensor(a!)
put.out(Tensor self, Tensor index, Tensor source, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!)
index_fill.int_Scalar_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
index_fill.int_Tensor_out(Tensor self, int dim, Tensor index, Tensor value, *, Tensor(a!) out) -> Tensor(a!)
__lshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
__lshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
__rshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
__rshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
random.from_out(Tensor self, int from, int? to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
random.from_functional(Tensor self, int from, int? to, *, Generator? generator=None) -> Tensor
random.to_out(Tensor self, int to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
random.to_functional(Tensor self, int to, *, Generator? generator=None) -> Tensor
random.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
random.functional(Tensor self, *, Generator? generator=None) -> Tensor
uniform.out(Tensor self, float from=0, float to=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
uniform.functional(Tensor self, float from=0, float to=1, *, Generator? generator=None) -> Tensor
cauchy.out(Tensor self, float median=0, float sigma=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
cauchy.functional(Tensor self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor
log_normal.out(Tensor self, float mean=1, float std=2, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
log_normal.functional(Tensor self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor
exponential.out(Tensor self, float lambd=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
exponential.functional(Tensor self, float lambd=1, *, Generator? generator=None) -> Tensor
geometric.out(Tensor self, float p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
geometric.functional(Tensor self, float p, *, Generator? generator=None) -> Tensor
normal.out(Tensor self, float mean=0, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
normal.functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor
_amp_foreach_non_finite_check_and_unscale.out(Tensor[] self, Tensor(b!) found_inf, Tensor inv_scale, *, Tensor(a!)[] out) -> ()
_amp_foreach_non_finite_check_and_unscale.functional(Tensor[] self, Tensor found_inf, Tensor inv_scale) -> (Tensor[] self_out, Tensor found_inf_out)
_amp_update_scale.out(Tensor self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval, *, Tensor(a!) out) -> Tensor(a!)
_amp_update_scale.functional(Tensor self, Tensor growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> (Tensor, Tensor growth_tracker_out)
_foreach_add.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
_foreach_add.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out
_foreach_sub.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
_foreach_sub.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out
_foreach_mul.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
_foreach_mul.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out
_foreach_div.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
_foreach_div.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out
_foreach_add.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()
_foreach_add.List_functional(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] self_out
_foreach_sub.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()
_foreach_sub.List_functional(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] self_out
_foreach_mul.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
_foreach_mul.List_functional(Tensor[] self, Tensor[] other) -> Tensor[] self_out
_foreach_div.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
_foreach_div.List_functional(Tensor[] self, Tensor[] other) -> Tensor[] self_out
_foreach_add.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
_foreach_add.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out
_foreach_sub.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
_foreach_sub.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out
_foreach_div.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
_foreach_div.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out
_foreach_mul.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
_foreach_mul.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out
_foreach_zero.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_zero.functional(Tensor[] self) -> Tensor[] self_out
_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_exp.functional(Tensor[] self) -> Tensor[] self_out
_foreach_sqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_sqrt.functional(Tensor[] self) -> Tensor[] self_out
_foreach_abs.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_abs.functional(Tensor[] self) -> Tensor[] self_out
_foreach_acos.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_acos.functional(Tensor[] self) -> Tensor[] self_out
_foreach_asin.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_asin.functional(Tensor[] self) -> Tensor[] self_out
_foreach_atan.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_atan.functional(Tensor[] self) -> Tensor[] self_out
_foreach_ceil.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_ceil.functional(Tensor[] self) -> Tensor[] self_out
_foreach_cos.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_cos.functional(Tensor[] self) -> Tensor[] self_out
_foreach_cosh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_cosh.functional(Tensor[] self) -> Tensor[] self_out
_foreach_erf.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_erf.functional(Tensor[] self) -> Tensor[] self_out
_foreach_erfc.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_erfc.functional(Tensor[] self) -> Tensor[] self_out
_foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_expm1.functional(Tensor[] self) -> Tensor[] self_out
_foreach_floor.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_floor.functional(Tensor[] self) -> Tensor[] self_out
_foreach_log.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_log.functional(Tensor[] self) -> Tensor[] self_out
_foreach_log10.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_log10.functional(Tensor[] self) -> Tensor[] self_out
_foreach_log1p.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_log1p.functional(Tensor[] self) -> Tensor[] self_out
_foreach_log2.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_log2.functional(Tensor[] self) -> Tensor[] self_out
_foreach_neg.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_neg.functional(Tensor[] self) -> Tensor[] self_out
_foreach_tan.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_tan.functional(Tensor[] self) -> Tensor[] self_out
_foreach_tanh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_tanh.functional(Tensor[] self) -> Tensor[] self_out
_foreach_sin.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_sin.functional(Tensor[] self) -> Tensor[] self_out
_foreach_sinh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_sinh.functional(Tensor[] self) -> Tensor[] self_out
_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_round.functional(Tensor[] self) -> Tensor[] self_out
_foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_lgamma.functional(Tensor[] self) -> Tensor[] self_out
_foreach_frac.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_frac.functional(Tensor[] self) -> Tensor[] self_out
_foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_reciprocal.functional(Tensor[] self) -> Tensor[] self_out
_foreach_sigmoid.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_sigmoid.functional(Tensor[] self) -> Tensor[] self_out
_foreach_trunc.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_trunc.functional(Tensor[] self) -> Tensor[] self_out
_foreach_addcdiv.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> ()
_foreach_addcdiv.Scalar_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] self_out
_foreach_addcmul.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> ()
_foreach_addcmul.Scalar_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] self_out
_foreach_addcdiv.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
_foreach_addcdiv.ScalarList_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] self_out
_foreach_addcmul.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
_foreach_addcmul.ScalarList_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] self_out
_linalg_inv_out_helper.out(Tensor self, Tensor(b!) infos_lu, Tensor(c!) infos_getri, *, Tensor(a!) out) -> Tensor(a!)
_linalg_inv_out_helper.functional(Tensor self, Tensor infos_lu, Tensor infos_getri) -> (Tensor, Tensor infos_lu_out, Tensor infos_getri_out)
```
[ghstack-poisoned]
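The naming scheme visible in the list above can be summarized with a small sketch: given an inplace op name, the generated variants append `functional` / `out` to the overload (or use them as the overload name when there is none). This is an illustration of the observed convention, not torchgen's actual name-derivation code:

```python
def generated_variant_names(inplace_name: str):
    # foo_            -> foo.functional,            foo.out
    # foo_.overload   -> foo.overload_functional,   foo.overload_out
    base, _, overload = inplace_name.partition(".")
    base = base.rstrip("_")
    if overload:
        return f"{base}.{overload}_functional", f"{base}.{overload}_out"
    return f"{base}.functional", f"{base}.out"

print(generated_variant_names("zero_"))        # ('zero.functional', 'zero.out')
print(generated_variant_names("random_.from"))  # ('random.from_functional', 'random.from_out')
```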
_amp_update_scale.functional(Tensor self, Tensor growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> (Tensor, Tensor growth_tracker_out) _foreach_add.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () _foreach_add.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out _foreach_sub.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () _foreach_sub.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out _foreach_mul.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () _foreach_mul.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out _foreach_div.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () _foreach_div.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out _foreach_add.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> () _foreach_add.List_functional(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] self_out _foreach_sub.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> () _foreach_sub.List_functional(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] self_out _foreach_mul.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () _foreach_mul.List_functional(Tensor[] self, Tensor[] other) -> Tensor[] self_out _foreach_div.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () _foreach_div.List_functional(Tensor[] self, Tensor[] other) -> Tensor[] self_out _foreach_add.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_add.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out _foreach_sub.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_sub.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out _foreach_div.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) 
-> () _foreach_div.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out _foreach_mul.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_mul.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out _foreach_zero.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_zero.functional(Tensor[] self) -> Tensor[] self_out _foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_exp.functional(Tensor[] self) -> Tensor[] self_out _foreach_sqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_sqrt.functional(Tensor[] self) -> Tensor[] self_out _foreach_abs.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_abs.functional(Tensor[] self) -> Tensor[] self_out _foreach_acos.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_acos.functional(Tensor[] self) -> Tensor[] self_out _foreach_asin.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_asin.functional(Tensor[] self) -> Tensor[] self_out _foreach_atan.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_atan.functional(Tensor[] self) -> Tensor[] self_out _foreach_ceil.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_ceil.functional(Tensor[] self) -> Tensor[] self_out _foreach_cos.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_cos.functional(Tensor[] self) -> Tensor[] self_out _foreach_cosh.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_cosh.functional(Tensor[] self) -> Tensor[] self_out _foreach_erf.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_erf.functional(Tensor[] self) -> Tensor[] self_out _foreach_erfc.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_erfc.functional(Tensor[] self) -> Tensor[] self_out _foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_expm1.functional(Tensor[] self) -> Tensor[] self_out _foreach_floor.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_floor.functional(Tensor[] self) -> Tensor[] self_out _foreach_log.out(Tensor[] self, *, 
Tensor(a!)[] out) -> () _foreach_log.functional(Tensor[] self) -> Tensor[] self_out _foreach_log10.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_log10.functional(Tensor[] self) -> Tensor[] self_out _foreach_log1p.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_log1p.functional(Tensor[] self) -> Tensor[] self_out _foreach_log2.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_log2.functional(Tensor[] self) -> Tensor[] self_out _foreach_neg.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_neg.functional(Tensor[] self) -> Tensor[] self_out _foreach_tan.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_tan.functional(Tensor[] self) -> Tensor[] self_out _foreach_tanh.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_tanh.functional(Tensor[] self) -> Tensor[] self_out _foreach_sin.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_sin.functional(Tensor[] self) -> Tensor[] self_out _foreach_sinh.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_sinh.functional(Tensor[] self) -> Tensor[] self_out _foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_round.functional(Tensor[] self) -> Tensor[] self_out _foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_lgamma.functional(Tensor[] self) -> Tensor[] self_out _foreach_frac.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_frac.functional(Tensor[] self) -> Tensor[] self_out _foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_reciprocal.functional(Tensor[] self) -> Tensor[] self_out _foreach_sigmoid.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_sigmoid.functional(Tensor[] self) -> Tensor[] self_out _foreach_trunc.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_trunc.functional(Tensor[] self) -> Tensor[] self_out _foreach_addcdiv.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> () _foreach_addcdiv.Scalar_functional(Tensor[] self, Tensor[] tensor1, 
Tensor[] tensor2, Scalar value=1) -> Tensor[] self_out _foreach_addcmul.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> () _foreach_addcmul.Scalar_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] self_out _foreach_addcdiv.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_addcdiv.ScalarList_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] self_out _foreach_addcmul.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_addcmul.ScalarList_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] self_out _linalg_inv_out_helper.out(Tensor self, Tensor(b!) infos_lu, Tensor(c!) infos_getri, *, Tensor(a!) out) -> Tensor(a!) _linalg_inv_out_helper.functional(Tensor self, Tensor infos_lu, Tensor infos_getri) -> (Tensor, Tensor infos_lu_out, Tensor infos_getri_out) ``` [ghstack-poisoned]
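Each generated pair in the list above shares a "base" schema once the overload suffix, alias annotations like `(a!)`, and the trailing `out=` argument are stripped away - that is the intuition behind grouping `functional`/`inplace`/`out=` variants together. A toy sketch of such a grouping key in plain Python (this is an illustration, not the real `torchgen` grouping logic; the function name and the simplified schema strings are made up for the example):

```python
import re

def signature_key(schema: str) -> str:
    # Toy reduction of a schema string to a grouping key: drop the
    # overload name and any trailing underscore from the op name,
    # strip alias annotations like "(a!)", and drop a trailing
    # ", *, Tensor out" argument. Not the real torchgen logic.
    name, rest = schema.split("(", 1)
    base = name.split(".")[0].rstrip("_")              # "zero.out" -> "zero", "zero_" -> "zero"
    args = rest.rsplit(")", 1)[0]                      # text between the outer parens
    args = re.sub(r"\([a-z]!\)", "", args)             # "Tensor(a!)" -> "Tensor"
    args = re.sub(r",\s*\*\s*,\s*Tensor\s+out$", "", args)  # drop the out= kwarg
    args = re.sub(r"\s+", " ", args).strip()
    return f"{base}({args})"

schemas = [
    "zero.out(Tensor self, *, Tensor(a!) out)",        # generated out= variant
    "zero.functional(Tensor self)",                    # generated functional variant
    "zero_(Tensor(a!) self)",                          # existing in-place op
]
keys = {signature_key(s) for s in schemas}
print(keys)  # all three variants collapse to the same key
```

All three variants of `zero` reduce to the single key `zero(Tensor self)`, which is the property that lets them be collected into one group.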
|
Ok, I think I've addressed the latest round of feedback - the biggest change was moving all of the logic associated with generating NativeFunctions into a new file.

Test failures:
…get functionalization to work for all mutable ops"

This PR is pretty large, but it's motivated by the following idea:
- **every** mutable operator in aten should be functionalize-able
- **every** mutable operator should have a functional + out= variant, so our codegen can operate on it in a more structured way (and full out= coverage is probably useful for mobile, for memory planning)

### The main changes
- Introduce a new `SchemaKind.mutable` enum in the codegen
- Update the codegen grouping logic to properly group all functional/inplace/out=/mutable variants today (and add a bunch of error checks and restrictions to tighten up the set of schemas that we allow into native_functions.yaml)
- Automatically generate some new `NativeFunctions` in the codegen (!!). Under certain conditions, we generate `functional` and `out=` variants of some existing aten operators
- Code-generate `mutable` -> `functional` kernels for any of the newly generated `functional` NativeFunction objects
- Clean up functionalization codegen, now that it can rely on the existing grouping logic
- Clean up LTC to only write lowerings for functional ops (we can do this now that every mutable operator has a functional equivalent)

Generating all of these new `NativeFunction`s is a pretty big change - up until now, every operator in aten was explicitly spelled out in `NativeFunctions.yaml`. This seems more reasonable to do now, because:
- we now have a `torchgen` package that you can install, and use to dynamically inspect all of the aten ops used in code generation
- there are just so many functional / out= ops that are missing, and adding them all manually would be a massive amount of boilerplate

A lot of the work in this PR involved figuring out why certain operators were/were not getting grouped properly, and classifying edge-case op schemas that we should fix vs. acceptable operators that we should update the grouping logic to account for.
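The grouping idea above - compute a mutability-erased "signature" for each schema and bucket the functional/inplace/out= variants that share it - can be sketched in plain Python. This is a toy model, not the real `FunctionSchema.signature()` from torchgen; the string munging here only handles the simple schema shapes shown:

```python
import re
from collections import defaultdict

def signature(schema: str) -> str:
    """Toy mutability-erased signature, loosely mimicking what
    FunctionSchema.signature() is described as doing: normalize the
    op name, drop alias annotations, and drop the trailing out= arg."""
    name, rest = schema.split("(", 1)
    base = name.rstrip("_").split(".")[0]                 # zero_ / zero.out -> zero
    rest = re.sub(r"\([a-z]!\)", "", rest)                # strip (a!)-style annotations
    rest = re.sub(r",\s*\*,\s*Tensor out\)", ")", rest)   # strip ', *, Tensor out)'
    args = rest.split(") ->")[0]
    return f"{base}({args})"

ops = [
    "zero_(Tensor(a!) self) -> Tensor(a!)",
    "zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)",
    "zero.functional(Tensor self) -> Tensor",
]
groups = defaultdict(list)
for op in ops:
    groups[signature(op)].append(op)
# all three variants land in one group, keyed by "zero(Tensor self)"
```

The real logic additionally converts mutable positional args to returns and drops `TensorOptions` args, but the bucketing step is the same: variants that erase to the same signature get grouped into one `NativeFunctionsGroup`.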
I listed out the full set of new `NativeFunctions` at the bottom of this description. It also shouldn't be too hard to add generated `foo.scratch` variants of out= operators on top of this, if we decide that's useful.

### Enumeration of changes / what order to look at things

(1) I would recommend starting with the updated versions of `FunctionSchema.signature()` and `Arguments.signature()` in `model.py`. This is the main, core change to our operator grouping logic that lets us always group `functional/inplace/mutable/out=` ops together; a lot of the other changes follow from it. In it, we:
- convert **mutable** args (`post_self_positional`) to returns (which come **after** any of the original returns)
- drop `TensorOptions` args (this lets us properly group the existing out= factory ops)

In `FunctionSchema.__post_init__()`, I added a bunch of new restrictions on what kind of aliasing guarantees we can assume about newly added schemas. This made it much easier for me to reason about the grouping logic, and I'm hoping they aren't too restrictive (since none of the restrictions broke any existing NativeFunctions).

(2) Next, the code for generating `functional` + `out=` NativeFunctions. In `gen.py`, `add_generated_native_functions()` has the logic for deciding when to generate new `NativeFunction` objects. For now, we only generate anything for mutable, non-composite ops that are missing a functional/out= variant. We could probably generate stuff for composite ops, but that isn't really necessary for backends/tracers, since we can rely on the decompositions. There are also a handful of `functional` ops that don't have `out=` variants; I didn't add them in this PR because they're not important to functionalization, but they would be pretty easy to add.

Note: there were a total of 5 operators today that are mutable and don't "work" with the new grouping logic.
In all cases, it's because there are some issues with their schemas that would be BC-breaking to fix (all called out in the code comments). I added them to an allow-list, and long term I think we can either fix their schemas or manually write functionalization kernels for them.

The code that actually generates new NativeFunctions is `generate_function` in `model.py`. Given a "base function" of one `SchemaKind` and a target `SchemaKind`, it generates a new `NativeFunction` with the target schema. For now, we only actually use it with functional / out= as the target schema.

(3) Generating functional kernels in terms of their existing mutable variants. This happens in `gen_composite_functional_kernel` in `gen_functionalization.py`. I had to modify `translate()` to be able to remove const-ness when calling a mutable op from a functional op.

(4) Updating the functionalization codegen in a few ways:
- We now have full support for all mutable -> functional op transformations, including weird `SchemaKind.mutable` ops like `_fused_moving_avg_obs_fq_helper` -> `_fused_moving_avg_obs_fq_helper.functional`, and out= factory ops like `range.start_out` -> `range.start_step`. For `SchemaKind.mutable` ops, the codegen needs to know that mutable positional args are converted into returns in the functional schema. For out= factory ops, I had to update `translate()` to know that it could grab TensorOptions arguments from the `out` tensor in the calling context.
- I removed the side-car mapping of mutable -> functional ops, so we now rely fully on the normal `NativeFunctionsGroup` groupings.
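The kernels that `gen_composite_functional_kernel` emits follow one simple shape: clone everything the mutable op would write to, run the mutable op on the clones, and return the clones. A minimal Python model of that shape (the real generated kernels are C++ calling aten ops; `zero_` here is a stand-in list-mutating function, not the aten operator):

```python
import copy

def functional_from_mutable(mutable_op):
    """Wrap an in-place op as a pure function: mutate fresh copies of
    the inputs instead of the caller's data, and return the copies."""
    def functional_op(*inputs):
        clones = [copy.deepcopy(x) for x in inputs]
        mutable_op(*clones)          # mutates only the clones
        return tuple(clones)
    return functional_op

def zero_(t):                        # stand-in for an in-place aten op
    for i in range(len(t)):
        t[i] = 0

zero_functional = functional_from_mutable(zero_)
```

Under this shape, the original inputs are untouched, which is exactly the property the functionalization pass needs from a `.functional` variant.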
I still ended up passing ungrouped `NativeFunctions` into the functionalization codegen for 2 reasons:
(a) We need to register `CompositeImplicitAutograd` kernels directly to functionalization, even if they were ungrouped (we could in theory unwind this if/when we eventually get a dispatch key dedicated to decompositions)
(b) I defensively error if functionalization ever encounters a non-grouped, mutable operator. I could also probably just move that error check outside of the functionalization codegen though.

(5) Updating the LazyTensor codegen

LTC has some special logic to handle mutable ops that it lowers directly. I ended up breaking it as part of this change. Instead of debugging what broke, I figured it would be better long-term to just get LTC to only lower functional operators and remove a bunch of the special handling for mutable operators. I'll probably need to run these changes by the LTC team.

### Full list of newly generated `NativeFunction` objects

new functional ops count: 74
new out= ops count: 97
total new ops: 171

```
_add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) bernoulli.Tensor_out(Tensor self, Tensor p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) bernoulli.Tensor_functional(Tensor self, Tensor p, *, Generator? generator=None) -> Tensor bernoulli.float_out(Tensor self, float p=0.5, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) bernoulli.float_functional(Tensor self, float p=0.5, *, Generator? generator=None) -> Tensor copy.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!) div.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) div.Scalar_mode_out(Tensor self, Scalar other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) embedding_renorm.out(Tensor self, Tensor indices, float max_norm, float norm_type, *, Tensor(a!)
out) -> Tensor(a!) embedding_renorm.functional(Tensor self, Tensor indices, float max_norm, float norm_type) -> Tensor resize.out(Tensor self, int[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) resize.functional(Tensor self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor fill.Scalar_out(Tensor self, Scalar value, *, Tensor(a!) out) -> Tensor(a!) fill.Tensor_out(Tensor self, Tensor value, *, Tensor(a!) out) -> Tensor(a!) index_put.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!) _index_put_impl.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False, *, Tensor(a!) out) -> Tensor(a!) _index_put_impl.functional(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor mul.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) relu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) celu.out(Tensor self, Scalar alpha=1.0, *, Tensor(a!) out) -> Tensor(a!) _mkldnn_transpose.out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!) resize_as.out(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) resize_as.functional(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor resize_as_sparse.out(Tensor self, Tensor the_template, *, Tensor(a!) out) -> Tensor(a!) resize_as_sparse.functional(Tensor self, Tensor the_template) -> Tensor zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) zero.functional(Tensor self) -> Tensor sub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) sparse_resize.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!) sparse_resize.functional(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor sparse_resize_and_clear.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) 
out) -> Tensor(a!) sparse_resize_and_clear.functional(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor _coalesced.out(Tensor self, bool coalesced, *, Tensor(a!) out) -> Tensor(a!) _coalesced.functional(Tensor self, bool coalesced) -> Tensor copy_sparse_to_sparse.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!) copy_sparse_to_sparse.functional(Tensor self, Tensor src, bool non_blocking=False) -> Tensor _fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out) set.source_Storage_out(Tensor self, Storage source, *, Tensor(a!) out) -> Tensor(a!) set.source_Storage_functional(Tensor self, Storage source) -> Tensor set.source_Storage_storage_offset_out(Tensor self, Storage source, int storage_offset, int[] size, int[] stride=[], *, Tensor(a!) out) -> Tensor(a!) set.source_Storage_storage_offset_functional(Tensor self, Storage source, int storage_offset, int[] size, int[] stride=[]) -> Tensor set.source_Tensor_out(Tensor self, Tensor source, *, Tensor(a!) out) -> Tensor(a!) set.source_Tensor_functional(Tensor self, Tensor source) -> Tensor set.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
set.functional(Tensor self) -> Tensor masked_fill.Scalar_out(Tensor self, Tensor mask, Scalar value, *, Tensor(a!) out) -> Tensor(a!) masked_fill.Tensor_out(Tensor self, Tensor mask, Tensor value, *, Tensor(a!) out) -> Tensor(a!) masked_scatter.out(Tensor self, Tensor mask, Tensor source, *, Tensor(a!) out) -> Tensor(a!) put.out(Tensor self, Tensor index, Tensor source, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!) index_fill.int_Scalar_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!) index_fill.int_Tensor_out(Tensor self, int dim, Tensor index, Tensor value, *, Tensor(a!) out) -> Tensor(a!) __lshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) __lshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) __rshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) __rshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) random.from_out(Tensor self, int from, int? to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) random.from_functional(Tensor self, int from, int? to, *, Generator? generator=None) -> Tensor random.to_out(Tensor self, int to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) random.to_functional(Tensor self, int to, *, Generator? generator=None) -> Tensor random.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) random.functional(Tensor self, *, Generator? generator=None) -> Tensor uniform.out(Tensor self, float from=0, float to=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) uniform.functional(Tensor self, float from=0, float to=1, *, Generator? generator=None) -> Tensor cauchy.out(Tensor self, float median=0, float sigma=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) cauchy.functional(Tensor self, float median=0, float sigma=1, *, Generator? 
generator=None) -> Tensor log_normal.out(Tensor self, float mean=1, float std=2, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) log_normal.functional(Tensor self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor exponential.out(Tensor self, float lambd=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) exponential.functional(Tensor self, float lambd=1, *, Generator? generator=None) -> Tensor geometric.out(Tensor self, float p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) geometric.functional(Tensor self, float p, *, Generator? generator=None) -> Tensor normal.out(Tensor self, float mean=0, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) normal.functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor _amp_foreach_non_finite_check_and_unscale.out(Tensor[] self, Tensor(b!) found_inf, Tensor inv_scale, *, Tensor(a!)[] out) -> () _amp_foreach_non_finite_check_and_unscale.functional(Tensor[] self, Tensor found_inf, Tensor inv_scale) -> (Tensor[] self_out, Tensor found_inf_out) _amp_update_scale.out(Tensor self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval, *, Tensor(a!) out) -> Tensor(a!) 
_amp_update_scale.functional(Tensor self, Tensor growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> (Tensor, Tensor growth_tracker_out) _foreach_add.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () _foreach_add.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out _foreach_sub.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () _foreach_sub.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out _foreach_mul.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () _foreach_mul.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out _foreach_div.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () _foreach_div.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out _foreach_add.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> () _foreach_add.List_functional(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] self_out _foreach_sub.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> () _foreach_sub.List_functional(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] self_out _foreach_mul.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () _foreach_mul.List_functional(Tensor[] self, Tensor[] other) -> Tensor[] self_out _foreach_div.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () _foreach_div.List_functional(Tensor[] self, Tensor[] other) -> Tensor[] self_out _foreach_add.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_add.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out _foreach_sub.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_sub.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out _foreach_div.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) 
-> () _foreach_div.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out _foreach_mul.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_mul.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out _foreach_zero.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_zero.functional(Tensor[] self) -> Tensor[] self_out _foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_exp.functional(Tensor[] self) -> Tensor[] self_out _foreach_sqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_sqrt.functional(Tensor[] self) -> Tensor[] self_out _foreach_abs.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_abs.functional(Tensor[] self) -> Tensor[] self_out _foreach_acos.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_acos.functional(Tensor[] self) -> Tensor[] self_out _foreach_asin.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_asin.functional(Tensor[] self) -> Tensor[] self_out _foreach_atan.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_atan.functional(Tensor[] self) -> Tensor[] self_out _foreach_ceil.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_ceil.functional(Tensor[] self) -> Tensor[] self_out _foreach_cos.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_cos.functional(Tensor[] self) -> Tensor[] self_out _foreach_cosh.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_cosh.functional(Tensor[] self) -> Tensor[] self_out _foreach_erf.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_erf.functional(Tensor[] self) -> Tensor[] self_out _foreach_erfc.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_erfc.functional(Tensor[] self) -> Tensor[] self_out _foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_expm1.functional(Tensor[] self) -> Tensor[] self_out _foreach_floor.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_floor.functional(Tensor[] self) -> Tensor[] self_out _foreach_log.out(Tensor[] self, *, 
Tensor(a!)[] out) -> () _foreach_log.functional(Tensor[] self) -> Tensor[] self_out _foreach_log10.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_log10.functional(Tensor[] self) -> Tensor[] self_out _foreach_log1p.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_log1p.functional(Tensor[] self) -> Tensor[] self_out _foreach_log2.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_log2.functional(Tensor[] self) -> Tensor[] self_out _foreach_neg.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_neg.functional(Tensor[] self) -> Tensor[] self_out _foreach_tan.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_tan.functional(Tensor[] self) -> Tensor[] self_out _foreach_tanh.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_tanh.functional(Tensor[] self) -> Tensor[] self_out _foreach_sin.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_sin.functional(Tensor[] self) -> Tensor[] self_out _foreach_sinh.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_sinh.functional(Tensor[] self) -> Tensor[] self_out _foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_round.functional(Tensor[] self) -> Tensor[] self_out _foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_lgamma.functional(Tensor[] self) -> Tensor[] self_out _foreach_frac.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_frac.functional(Tensor[] self) -> Tensor[] self_out _foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_reciprocal.functional(Tensor[] self) -> Tensor[] self_out _foreach_sigmoid.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_sigmoid.functional(Tensor[] self) -> Tensor[] self_out _foreach_trunc.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_trunc.functional(Tensor[] self) -> Tensor[] self_out _foreach_addcdiv.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> () _foreach_addcdiv.Scalar_functional(Tensor[] self, Tensor[] tensor1, 
Tensor[] tensor2, Scalar value=1) -> Tensor[] self_out _foreach_addcmul.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> () _foreach_addcmul.Scalar_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] self_out _foreach_addcdiv.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_addcdiv.ScalarList_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] self_out _foreach_addcmul.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_addcmul.ScalarList_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] self_out _linalg_inv_out_helper.out(Tensor self, Tensor(b!) infos_lu, Tensor(c!) infos_getri, *, Tensor(a!) out) -> Tensor(a!) _linalg_inv_out_helper.functional(Tensor self, Tensor infos_lu, Tensor infos_getri) -> (Tensor, Tensor infos_lu_out, Tensor infos_getri_out) ``` [ghstack-poisoned]
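The generated `.functional` schemas in the list above follow a mechanical convention: mutability is stripped from every argument, and each formerly mutable argument reappears as an extra `<name>_out` return after the op's original returns. A toy sketch of that convention (Tensor-only args, non-Tensor args omitted; `to_functional` is a made-up helper for illustration, not torchgen's `generate_function`):

```python
def to_functional(base_name, overload, args, num_returns):
    """Derive a '.functional' schema string from a mutable op's
    (arg_name, is_mutable) list. Simplified: every arg is a Tensor,
    and the returns are always rendered as a parenthesized tuple."""
    suffix = f"{overload}_functional" if overload else "functional"
    arg_str = ", ".join(f"Tensor {name}" for name, _mutable in args)
    extra = [f"Tensor {name}_out" for name, mutable in args if mutable]
    rets = ["Tensor"] * num_returns + extra
    return f"{base_name}.{suffix}({arg_str}) -> ({', '.join(rets)})"

# e.g. _amp_update_scale(Tensor self, Tensor(b!) growth_tracker, ...) -> Tensor
# (other args dropped for brevity)
schema = to_functional(
    "_amp_update_scale", "", [("self", False), ("growth_tracker", True)], 1
)
```

This matches the naming pattern visible in entries like `_amp_update_scale.functional(...) -> (Tensor, Tensor growth_tracker_out)`.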
@pytorchbot merge this please
Merge failed due to a failed command. Raised by https://github.com/pytorch/pytorch/actions/runs/2349137049
_amp_update_scale.functional(Tensor self, Tensor growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> (Tensor, Tensor growth_tracker_out) _foreach_add.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () _foreach_add.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out _foreach_sub.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () _foreach_sub.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out _foreach_mul.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () _foreach_mul.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out _foreach_div.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () _foreach_div.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out _foreach_add.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> () _foreach_add.List_functional(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] self_out _foreach_sub.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> () _foreach_sub.List_functional(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] self_out _foreach_mul.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () _foreach_mul.List_functional(Tensor[] self, Tensor[] other) -> Tensor[] self_out _foreach_div.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () _foreach_div.List_functional(Tensor[] self, Tensor[] other) -> Tensor[] self_out _foreach_add.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_add.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out _foreach_sub.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_sub.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out _foreach_div.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) 
-> () _foreach_div.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out _foreach_mul.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_mul.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out _foreach_zero.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_zero.functional(Tensor[] self) -> Tensor[] self_out _foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_exp.functional(Tensor[] self) -> Tensor[] self_out _foreach_sqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_sqrt.functional(Tensor[] self) -> Tensor[] self_out _foreach_abs.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_abs.functional(Tensor[] self) -> Tensor[] self_out _foreach_acos.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_acos.functional(Tensor[] self) -> Tensor[] self_out _foreach_asin.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_asin.functional(Tensor[] self) -> Tensor[] self_out _foreach_atan.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_atan.functional(Tensor[] self) -> Tensor[] self_out _foreach_ceil.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_ceil.functional(Tensor[] self) -> Tensor[] self_out _foreach_cos.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_cos.functional(Tensor[] self) -> Tensor[] self_out _foreach_cosh.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_cosh.functional(Tensor[] self) -> Tensor[] self_out _foreach_erf.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_erf.functional(Tensor[] self) -> Tensor[] self_out _foreach_erfc.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_erfc.functional(Tensor[] self) -> Tensor[] self_out _foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_expm1.functional(Tensor[] self) -> Tensor[] self_out _foreach_floor.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_floor.functional(Tensor[] self) -> Tensor[] self_out _foreach_log.out(Tensor[] self, *, 
Tensor(a!)[] out) -> () _foreach_log.functional(Tensor[] self) -> Tensor[] self_out _foreach_log10.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_log10.functional(Tensor[] self) -> Tensor[] self_out _foreach_log1p.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_log1p.functional(Tensor[] self) -> Tensor[] self_out _foreach_log2.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_log2.functional(Tensor[] self) -> Tensor[] self_out _foreach_neg.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_neg.functional(Tensor[] self) -> Tensor[] self_out _foreach_tan.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_tan.functional(Tensor[] self) -> Tensor[] self_out _foreach_tanh.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_tanh.functional(Tensor[] self) -> Tensor[] self_out _foreach_sin.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_sin.functional(Tensor[] self) -> Tensor[] self_out _foreach_sinh.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_sinh.functional(Tensor[] self) -> Tensor[] self_out _foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_round.functional(Tensor[] self) -> Tensor[] self_out _foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_lgamma.functional(Tensor[] self) -> Tensor[] self_out _foreach_frac.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_frac.functional(Tensor[] self) -> Tensor[] self_out _foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_reciprocal.functional(Tensor[] self) -> Tensor[] self_out _foreach_sigmoid.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_sigmoid.functional(Tensor[] self) -> Tensor[] self_out _foreach_trunc.out(Tensor[] self, *, Tensor(a!)[] out) -> () _foreach_trunc.functional(Tensor[] self) -> Tensor[] self_out _foreach_addcdiv.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> () _foreach_addcdiv.Scalar_functional(Tensor[] self, Tensor[] tensor1, 
Tensor[] tensor2, Scalar value=1) -> Tensor[] self_out _foreach_addcmul.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> () _foreach_addcmul.Scalar_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] self_out _foreach_addcdiv.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_addcdiv.ScalarList_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] self_out _foreach_addcmul.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> () _foreach_addcmul.ScalarList_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] self_out _linalg_inv_out_helper.out(Tensor self, Tensor(b!) infos_lu, Tensor(c!) infos_getri, *, Tensor(a!) out) -> Tensor(a!) _linalg_inv_out_helper.functional(Tensor self, Tensor infos_lu, Tensor infos_getri) -> (Tensor, Tensor infos_lu_out, Tensor infos_getri_out) ``` [ghstack-poisoned]
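To make the pattern behind these generated `.functional` variants concrete, here is a hand-written Python sketch of the clone-then-mutate transformation the codegen emits (the real kernels are generated C++; `zero_functional` is an illustrative name, not an API this PR exposes): clone every tensor the mutable op would write to, run the existing in-place kernel on the clones, and return the clones as fresh outputs.

```python
import torch

def zero_functional(self: torch.Tensor) -> torch.Tensor:
    # Clone the would-be-mutated input so the caller's tensor is untouched.
    self_clone = self.clone()
    # Call the existing mutable kernel on the clone.
    self_clone.zero_()
    # Return the mutated clone as a fresh, non-aliasing output.
    return self_clone

x = torch.ones(3)
y = zero_functional(x)
# x keeps its original values; y holds the functional result.
```

Ops with several mutable arguments (e.g. `_amp_update_scale`) follow the same recipe, with each cloned-and-mutated tensor appended to the returns, which is why the generated functional schemas above grow extra `*_out` return names.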
…get functionalization to work for all mutable ops" This PR is pretty large, but it's motivated by the following idea: - **every** mutable operators in aten should be functionalize-able - **every** mutable operator should have a functional + out= variant, so our codegen can operate on it in a more structured way (and full out= coverage support is probably useful for mobile, for memory planning) ### The main changes - Introduce a new `SchemaKind.mutable` enum in the codegen - Update the codegen grouping logic to properly group all functional/inplace/out=/mutable variants today (and add a bunch of error checks and restrictions to tighten up the set of schemas that we allow into native_functions.yaml) - automatically generate some new `NativeFunctions` in the codegen (!!). Under certain conditions, we generate `functional` and `out=` variants of some existing aten operators - code-generate `mutable` -> `functional` kernels for any of the newly generated `functional` NativeFunction objects. - Clean up functionalization codegen, now that it can rely on the existing grouping logic - clean up LTC to only write lowerings for functional ops (we can do this now that every mutable operator has a functional equivalent. Generating all of these new `NativeFunction`'s is a pretty big change - up until now, every operator in aten was explicitly spelled out in `NativeFunctions.yaml`. This seems more ok to do now, because - we now have a `torchgen` package that you can install, and use to dynamically inspect all of the aten ops used in code generation - There are just so many functional / out= ops that are missing, and adding them all manually would be a massive amount of boilerplate A lot of the work in this PR involved figuring out why certain operators were/were not getting grouped properly, and classifying edge case op schemas that we should fix, vs. acceptable operators that we should update the grouping logic to account for. 
I listed out the full set of new `NativeFunctions` at the bottom of this description. It also shouldn't be too hard to add generated `foo.scratch` variants of out= operators on top of this, if we decide that's useful.

### Enumeration of changes / what order to look at things

(1) I would recommend starting with the updated versions of `FunctionSchema.signature()` and `Arguments.signature()` in `model.py`. This is the main, core change to our operator grouping logic, which lets us always group `functional/inplace/mutable/out=` ops together; a lot of the other changes follow from it. In it, we:

- convert **mutable** args (`post_self_positional` args) to returns (which come **after** any of the original returns)
- drop `TensorOptions` args (this lets us properly group the existing out= factory ops)

In `FunctionSchema.__post_init__()`, I added a bunch of new restrictions on what kind of aliasing guarantees we can assume about newly added schemas. This made it much easier for me to reason about the grouping logic, and I'm hoping they aren't too restrictive (since none of the restrictions broke any existing NativeFunctions).

(2) Next, the code for generating `functional` + `out=` NativeFunctions.

- In `gen.py`, `add_generated_native_functions()` has the logic for deciding when to generate new `NativeFunction` objects. For now, we only generate anything for mutable, non-composite ops that are missing a functional/out= variant. We could probably generate stuff for composite ops too, but that isn't really necessary for backends/tracers, since we can rely on the decompositions. There are also a handful of `functional` ops that don't have `out=` variants; I didn't add them in this PR because they're not important to functionalization, but they would be pretty easy to add.

Note: there are a total of 5 operators today that are mutable and don't "work" with the new grouping logic. In all cases, it's because there are issues with their schemas that would be BC-breaking to fix (all called out in the code comments). I added them to an allow-list, and long term I think we can either fix their schemas or manually write functionalization kernels for them.

The code that actually generates new NativeFunctions is `generate_function` in `model.py`. Given a "base function" of one `SchemaKind` and a target `SchemaKind`, it generates a new `NativeFunction` with the target schema. For now, we only actually use it with functional / out= as the target schema.

(3) Generating functional kernels in terms of their existing mutable variants. This happens in `gen_composite_functional_kernel` in `gen_functionalization.py`. I had to modify `translate()` to be able to remove const-ness when calling a mutable op from a functional op.

(4) Updating the functionalization codegen in a few ways:

- We now have full support for all mutable -> functional op transformations, including weird `SchemaKind.mutable` ops like `_fused_moving_avg_obs_fq_helper` -> `_fused_moving_avg_obs_fq_helper.functional`, and out= factory ops like `range.start_out` -> `range.start_step`. For `SchemaKind.mutable` ops, the codegen needs to know that mutable positional args are converted into returns in the functional schema. For out= factory ops, I had to update `translate()` to know that it can grab TensorOptions arguments from the `out` tensor in the calling context.
- I removed the side-car mapping of mutable -> functional ops, so we now rely fully on the normal `NativeFunctionsGroup` groupings.
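The "mutable positional args become returns" transformation can be sketched as a toy string rewrite (illustrative only - `to_functional` is a hypothetical stand-in for the schema part of `generate_function`, and it ignores `self`-mutation and non-Tensor edge cases):

```python
import re

def to_functional(schema: str) -> str:
    """Toy mutable -> functional schema transform: each mutated
    positional arg loses its (x!) annotation and is re-emitted as an
    extra `<name>_out` return after the original returns."""
    sig, rets = schema.split(" -> ")
    name, args = sig.split("(", 1)
    new_args, extra_rets = [], []
    for arg in args.rstrip(")").split(", "):
        m = re.match(r"Tensor\([a-z]!\) (\w+)", arg)
        if m and m.group(1) != "self":
            new_args.append(f"Tensor {m.group(1)}")        # drop the annotation
            extra_rets.append(f"Tensor {m.group(1)}_out")  # surface the mutation
        else:
            new_args.append(arg)
    all_rets = [r for r in [rets.strip("()")] if r] + extra_rets
    return f"{name}.functional({', '.join(new_args)}) -> ({', '.join(all_rets)})"

# A trimmed-down _amp_update_scale-style schema, for illustration:
print(to_functional(
    "_amp_update_scale(Tensor self, Tensor(b!) growth_tracker, Tensor found_inf) -> Tensor"))
# -> _amp_update_scale.functional(Tensor self, Tensor growth_tracker, Tensor found_inf) -> (Tensor, Tensor growth_tracker_out)
```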
I still ended up passing ungrouped `NativeFunctions` into the functionalization codegen, for 2 reasons:

(a) We need to register `CompositeImplicitAutograd` kernels directly to functionalization, even if they were ungrouped (we could in theory unwind this if/when we eventually get a dispatch key dedicated to decompositions)

(b) I defensively error if functionalization ever encounters a non-grouped, mutable operator. I could also probably just move that error check outside of the functionalization codegen, though.

(5) Updating the LazyTensor codegen

LTC has some special logic to handle mutable ops that it lowers directly. I ended up breaking it as part of this change. Instead of debugging what broke, I figured it would be better long-term to just get LTC to only lower functional operators, and remove a bunch of the special handling for mutable operators. I'll probably need to run these changes by the LTC team.

### Full list of newly generated `NativeFunction` objects

new functional ops count: 74
new out= ops count: 97
total new ops: 171

```
_add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
bernoulli.Tensor_out(Tensor self, Tensor p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
bernoulli.Tensor_functional(Tensor self, Tensor p, *, Generator? generator=None) -> Tensor
bernoulli.float_out(Tensor self, float p=0.5, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
bernoulli.float_functional(Tensor self, float p=0.5, *, Generator? generator=None) -> Tensor
copy.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)
div.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
div.Scalar_mode_out(Tensor self, Scalar other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)
embedding_renorm.out(Tensor self, Tensor indices, float max_norm, float norm_type, *, Tensor(a!) out) -> Tensor(a!)
embedding_renorm.functional(Tensor self, Tensor indices, float max_norm, float norm_type) -> Tensor
resize.out(Tensor self, int[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
resize.functional(Tensor self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor
fill.Scalar_out(Tensor self, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
fill.Tensor_out(Tensor self, Tensor value, *, Tensor(a!) out) -> Tensor(a!)
index_put.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!)
_index_put_impl.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False, *, Tensor(a!) out) -> Tensor(a!)
_index_put_impl.functional(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor
mul.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
relu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
celu.out(Tensor self, Scalar alpha=1.0, *, Tensor(a!) out) -> Tensor(a!)
_mkldnn_transpose.out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!)
resize_as.out(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
resize_as.functional(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor
resize_as_sparse.out(Tensor self, Tensor the_template, *, Tensor(a!) out) -> Tensor(a!)
resize_as_sparse.functional(Tensor self, Tensor the_template) -> Tensor
zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
zero.functional(Tensor self) -> Tensor
sub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
sparse_resize.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!)
sparse_resize.functional(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor
sparse_resize_and_clear.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!)
sparse_resize_and_clear.functional(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor
_coalesced.out(Tensor self, bool coalesced, *, Tensor(a!) out) -> Tensor(a!)
_coalesced.functional(Tensor self, bool coalesced) -> Tensor
copy_sparse_to_sparse.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)
copy_sparse_to_sparse.functional(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
_fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!))
_fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out)
set.source_Storage_out(Tensor self, Storage source, *, Tensor(a!) out) -> Tensor(a!)
set.source_Storage_functional(Tensor self, Storage source) -> Tensor
set.source_Storage_storage_offset_out(Tensor self, Storage source, int storage_offset, int[] size, int[] stride=[], *, Tensor(a!) out) -> Tensor(a!)
set.source_Storage_storage_offset_functional(Tensor self, Storage source, int storage_offset, int[] size, int[] stride=[]) -> Tensor
set.source_Tensor_out(Tensor self, Tensor source, *, Tensor(a!) out) -> Tensor(a!)
set.source_Tensor_functional(Tensor self, Tensor source) -> Tensor
set.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
set.functional(Tensor self) -> Tensor
masked_fill.Scalar_out(Tensor self, Tensor mask, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
masked_fill.Tensor_out(Tensor self, Tensor mask, Tensor value, *, Tensor(a!) out) -> Tensor(a!)
masked_scatter.out(Tensor self, Tensor mask, Tensor source, *, Tensor(a!) out) -> Tensor(a!)
put.out(Tensor self, Tensor index, Tensor source, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!)
index_fill.int_Scalar_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
index_fill.int_Tensor_out(Tensor self, int dim, Tensor index, Tensor value, *, Tensor(a!) out) -> Tensor(a!)
__lshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
__lshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
__rshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
__rshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
random.from_out(Tensor self, int from, int? to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
random.from_functional(Tensor self, int from, int? to, *, Generator? generator=None) -> Tensor
random.to_out(Tensor self, int to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
random.to_functional(Tensor self, int to, *, Generator? generator=None) -> Tensor
random.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
random.functional(Tensor self, *, Generator? generator=None) -> Tensor
uniform.out(Tensor self, float from=0, float to=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
uniform.functional(Tensor self, float from=0, float to=1, *, Generator? generator=None) -> Tensor
cauchy.out(Tensor self, float median=0, float sigma=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
cauchy.functional(Tensor self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor
log_normal.out(Tensor self, float mean=1, float std=2, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
log_normal.functional(Tensor self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor
exponential.out(Tensor self, float lambd=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
exponential.functional(Tensor self, float lambd=1, *, Generator? generator=None) -> Tensor
geometric.out(Tensor self, float p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
geometric.functional(Tensor self, float p, *, Generator? generator=None) -> Tensor
normal.out(Tensor self, float mean=0, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
normal.functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor
_amp_foreach_non_finite_check_and_unscale.out(Tensor[] self, Tensor(b!) found_inf, Tensor inv_scale, *, Tensor(a!)[] out) -> ()
_amp_foreach_non_finite_check_and_unscale.functional(Tensor[] self, Tensor found_inf, Tensor inv_scale) -> (Tensor[] self_out, Tensor found_inf_out)
_amp_update_scale.out(Tensor self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval, *, Tensor(a!) out) -> Tensor(a!)
_amp_update_scale.functional(Tensor self, Tensor growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> (Tensor, Tensor growth_tracker_out)
_foreach_add.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
_foreach_add.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out
_foreach_sub.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
_foreach_sub.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out
_foreach_mul.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
_foreach_mul.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out
_foreach_div.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()
_foreach_div.Scalar_functional(Tensor[] self, Scalar scalar) -> Tensor[] self_out
_foreach_add.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()
_foreach_add.List_functional(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] self_out
_foreach_sub.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()
_foreach_sub.List_functional(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] self_out
_foreach_mul.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
_foreach_mul.List_functional(Tensor[] self, Tensor[] other) -> Tensor[] self_out
_foreach_div.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()
_foreach_div.List_functional(Tensor[] self, Tensor[] other) -> Tensor[] self_out
_foreach_add.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
_foreach_add.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out
_foreach_sub.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
_foreach_sub.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out
_foreach_div.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
_foreach_div.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out
_foreach_mul.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
_foreach_mul.ScalarList_functional(Tensor[] self, Scalar[] scalars) -> Tensor[] self_out
_foreach_zero.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_zero.functional(Tensor[] self) -> Tensor[] self_out
_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_exp.functional(Tensor[] self) -> Tensor[] self_out
_foreach_sqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_sqrt.functional(Tensor[] self) -> Tensor[] self_out
_foreach_abs.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_abs.functional(Tensor[] self) -> Tensor[] self_out
_foreach_acos.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_acos.functional(Tensor[] self) -> Tensor[] self_out
_foreach_asin.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_asin.functional(Tensor[] self) -> Tensor[] self_out
_foreach_atan.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_atan.functional(Tensor[] self) -> Tensor[] self_out
_foreach_ceil.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_ceil.functional(Tensor[] self) -> Tensor[] self_out
_foreach_cos.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_cos.functional(Tensor[] self) -> Tensor[] self_out
_foreach_cosh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_cosh.functional(Tensor[] self) -> Tensor[] self_out
_foreach_erf.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_erf.functional(Tensor[] self) -> Tensor[] self_out
_foreach_erfc.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_erfc.functional(Tensor[] self) -> Tensor[] self_out
_foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_expm1.functional(Tensor[] self) -> Tensor[] self_out
_foreach_floor.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_floor.functional(Tensor[] self) -> Tensor[] self_out
_foreach_log.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_log.functional(Tensor[] self) -> Tensor[] self_out
_foreach_log10.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_log10.functional(Tensor[] self) -> Tensor[] self_out
_foreach_log1p.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_log1p.functional(Tensor[] self) -> Tensor[] self_out
_foreach_log2.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_log2.functional(Tensor[] self) -> Tensor[] self_out
_foreach_neg.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_neg.functional(Tensor[] self) -> Tensor[] self_out
_foreach_tan.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_tan.functional(Tensor[] self) -> Tensor[] self_out
_foreach_tanh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_tanh.functional(Tensor[] self) -> Tensor[] self_out
_foreach_sin.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_sin.functional(Tensor[] self) -> Tensor[] self_out
_foreach_sinh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_sinh.functional(Tensor[] self) -> Tensor[] self_out
_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_round.functional(Tensor[] self) -> Tensor[] self_out
_foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_lgamma.functional(Tensor[] self) -> Tensor[] self_out
_foreach_frac.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_frac.functional(Tensor[] self) -> Tensor[] self_out
_foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_reciprocal.functional(Tensor[] self) -> Tensor[] self_out
_foreach_sigmoid.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_sigmoid.functional(Tensor[] self) -> Tensor[] self_out
_foreach_trunc.out(Tensor[] self, *, Tensor(a!)[] out) -> ()
_foreach_trunc.functional(Tensor[] self) -> Tensor[] self_out
_foreach_addcdiv.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> ()
_foreach_addcdiv.Scalar_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] self_out
_foreach_addcmul.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> ()
_foreach_addcmul.Scalar_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] self_out
_foreach_addcdiv.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
_foreach_addcdiv.ScalarList_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] self_out
_foreach_addcmul.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> ()
_foreach_addcmul.ScalarList_functional(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] self_out
_linalg_inv_out_helper.out(Tensor self, Tensor(b!) infos_lu, Tensor(c!) infos_getri, *, Tensor(a!) out) -> Tensor(a!)
_linalg_inv_out_helper.functional(Tensor self, Tensor infos_lu, Tensor infos_getri) -> (Tensor, Tensor infos_lu_out, Tensor infos_getri_out)
```
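The generated `.functional` kernels all follow the same shape: clone the would-be-mutated inputs, run the existing mutable kernel on the clones, and return the clones as extra outputs. Here is a minimal Python sketch of that pattern (the names and the dict-based "tensor" are illustrative, not the actual generated C++):

```python
from copy import deepcopy

def fused_helper_(state: dict) -> None:
    """Stand-in for an existing mutable kernel that updates its args in place."""
    state["running_min"] = min(state["running_min"], 0.0)
    state["step"] += 1

def fused_helper_functional(state: dict) -> dict:
    """Pattern used by the generated .functional kernels: clone the
    would-be-mutated inputs, reuse the mutable kernel on the clones,
    and hand the clones back as outputs."""
    state_out = deepcopy(state)  # clone instead of mutating the caller's copy
    fused_helper_(state_out)     # reuse the existing in-place implementation
    return state_out

s = {"running_min": 1.0, "step": 0}
out = fused_helper_functional(s)
assert s == {"running_min": 1.0, "step": 0}    # caller's input untouched
assert out == {"running_min": 0.0, "step": 1}  # mutation captured in the output
```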
@pytorchbot merge this please
Hey @bdhirsh. |
…#76320)

Summary: Pull Request resolved: #76320
Approved by: https://github.com/ezyang
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/0161e9eb00eeacb54389309a6b53f2c97b655921
Reviewed By: seemethere
Differential Revision: D36537697
Pulled By: bdhirsh
fbshipit-source-id: a9cf8dfb87f13c5be0db4dcf825f378bb942227a