Optional dtype for reduce functions #15133
tugrulates wants to merge 22 commits into pytorch:master from tugrulates:optional-scalartype
Conversation
Summary: For #6593 and #9515. This completes the support for `optional<ScalarType>` in native, JIT and autograd.

Note: Mostly following the existing implementation for `optional<Scalar>` that was added in #12582.

This PR introduces a way to make functions accept an optional dtype, and it will unblock #9515 by allowing the `dtype` param for the type promotion interface:

```
func: name(inputs, *, ScalarType? dtype=None, Casting casting=same_kind)
```

An alternative approach could have been to use `ScalarType::Undefined` for the same purpose without optional, though it would have been a bit hacky:

```
func: name(inputs, *, ScalarType dtype=Undefined, Casting casting=same_kind)
```

Here's an example use of this in action: 971f69e

There are already a bunch of native functions that get an optional `dtype` through function overloading. #15133 is the attempt to migrate all of those. I will send those changes separately after this, since some functions (e.g. `sum`) need quite a bit of change in the codebase. See the commits over there.

Pull Request resolved: #15154
Differential Revision: D13457760
Pulled By: tugrulates
fbshipit-source-id: 706134f0bd578683edd416b96329b49a1ba8ab48
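The core idea — an omitted `dtype` falling back to the input's own type — can be sketched in plain Python. `resolve_dtype` and the trimmed-down `ScalarType` enum below are illustrative stand-ins, not the actual ATen code:

```python
from enum import Enum
from typing import Optional

class ScalarType(Enum):
    FLOAT = "float"
    DOUBLE = "double"
    LONG = "long"

def resolve_dtype(self_dtype: ScalarType,
                  dtype: Optional[ScalarType] = None) -> ScalarType:
    # With Optional, "not passed" is an unambiguous None rather than a
    # sentinel like ScalarType::Undefined smuggled through the regular enum.
    return self_dtype if dtype is None else dtype

# Falls back to the input's dtype when the kwarg is omitted.
assert resolve_dtype(ScalarType.FLOAT) is ScalarType.FLOAT
assert resolve_dtype(ScalarType.FLOAT, ScalarType.DOUBLE) is ScalarType.DOUBLE
```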
```diff
 d_intermediates = [d_i for intermediates_batch in group(intermediates, shard_size)
-                   for d_i in torch.autograd.grad(loss, intermediates_batch)]
+                   for d_i in torch.autograd.grad(loss, intermediates_batch,
+                                                  retain_graph=True)]
```
This is because the graph is one level deeper due to the type cast.
as above, if we can avoid the casts and avoid doing this that seems ideal.
The type cast will be no-op when dtype is not passed, but still adds another node to the graph. To avoid it altogether, we'll need to provide an overload to the native function without the dtype kwarg.
The surprising behavior is that `retain_graph` is not needed for the simplest backward cases. The documentation says the grad graph will be freed, yet this test still manages to use it more than once. This is not a problem for just this PR; the quirk will surface with any change to any derivative whenever the node count goes from one to two.
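One way to picture the quirk is a toy model (not PyTorch's real bookkeeping) in which only nodes that saved buffers for their gradient are invalidated by graph freeing; a trivial one-node graph then survives repeated backward calls, while a chain with a buffer-saving node needs `retain_graph=True`:

```python
# Toy model of autograd graph freeing (NOT PyTorch's actual mechanism):
# only nodes that saved intermediate buffers are invalidated.
class Node:
    def __init__(self, saved=None, next_node=None):
        self.saved = saved          # intermediate buffers, if any
        self.next_node = next_node  # next function in the backward chain

    def backward(self, retain_graph=False):
        if self.saved == "freed":
            raise RuntimeError("Trying to backward through the graph a second time")
        if self.next_node is not None:
            self.next_node.backward(retain_graph)
        if not retain_graph and self.saved is not None:
            self.saved = "freed"    # buffers are freed after backward

# A one-node graph with nothing saved tolerates repeated backward calls,
# which is the behavior the docs don't promise.
single = Node()
single.backward()
single.backward()

# A two-node chain where one node saves buffers (e.g. an inserted cast)
# fails on the second call unless retain_graph=True is passed.
chain = Node(saved="ctx", next_node=Node())
chain.backward()
try:
    chain.backward()
    assert False, "expected the freed graph to raise"
except RuntimeError:
    pass

retained = Node(saved="ctx", next_node=Node())
retained.backward(retain_graph=True)
retained.backward(retain_graph=True)
```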
it's the addition of the None node that causes you to need retain_graph?
This is now ready to review.
torch/csrc/jit/autodiff.cpp (Outdated)

```cpp
  });
  return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr};

} else if (node->matches("aten::log_softmax(Tensor self, int dim) -> Tensor")) {
```
This is already working fine through dispatch and derivatives.yaml. Rather than updating it here, I'm simply dropping the override.
Nm, this seems to be used in symbolic diff. Will bring it back.
```diff
   variants: function, method

-- func: softmax(Tensor self, int64_t dim) -> Tensor
+- func: softmax(Tensor self, int64_t dim, *, ScalarType? dtype=None) -> Tensor
```
For softmax and log_softmax, why do we want to make it a kw only argument instead of simply merging them together? It is indeed a BC change and I don't know if we want to do that or not.
It's mainly for the sake of consistency with other ops.
Some pro arguments are:
- These two are undocumented functions. The public interface is through `torch.nn.functional.[log_]softmax`, which already has `dtype` as a kwarg.
- The kwarg and non-kwarg forms have the same cpp signature, so it's tricky to make them live side-by-side for a deprecation period.
- If this change is ever to be made, it's better to do it sooner than later.
- It's a better baseline even if more kwargs are added across multiple ops (`casting`, `where`, etc.).
- The fix is trivial if any code is broken by this.
sounds good, as long as it is only exposed from nn.functional I think it would be fine :)
```
%14 : Tensor = aten::gt(%13, %11)
%15 : bool = prim::Bool(%14)
%16 : Tensor, %17 : Tensor, %a : Tensor, %19 : Tensor, %20 : Tensor = prim::Loop(%7, %15, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
%13 : int? = prim::None()
```
these changes are pretty noisy and I'm assuming these tests already existed (so your casts are all no-ops). Can we just avoid actually calling the casts if the casts are no-op to preserve these? Is that possible?
This is not calling a cast, but adding a `prim::None` to the `sum` symbol on line 15 below.
This is because this PR is removing the overloads without the dtypes, which is kind of the purpose of it.
Since the generated graphs include the default values from the schema, this will keep happening with every new kwarg added to ops. Two ways to fix this:
- Keep overloads in `native_functions.yaml` and keep adding them with every new arg. (This makes this PR unneeded.)
- Make JIT apply defaults at execution time. Maybe this behavior could be confined to kwargs only. (Much larger scope than what this PR is trying to achieve.)
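The two options can be contrasted with a small Python sketch (hypothetical helper names; a real JIT graph holds nodes, not dicts): baking schema defaults into the graph at construction time versus filling them in at execution time:

```python
schema_defaults = {"dtype": None, "keepdim": False}  # stand-in for the op schema

def build_graph_baked(call_kwargs):
    # Construction-time defaults: every omitted kwarg is materialized in the
    # recorded graph, so expect files churn whenever a new kwarg is added.
    node = dict(schema_defaults)
    node.update(call_kwargs)
    return node

def build_graph_lazy(call_kwargs):
    # Execution-time defaults: the graph records only what the caller wrote,
    # so adding a new kwarg to the schema leaves old graphs untouched.
    return dict(call_kwargs)

def run(graph):
    # Defaults are applied when the op actually executes.
    args = dict(schema_defaults)
    args.update(graph)
    return args

assert build_graph_baked({}) == {"dtype": None, "keepdim": False}
assert build_graph_lazy({}) == {}  # recorded graph stays stable
assert run(build_graph_lazy({})) == run(build_graph_baked({}))  # same semantics
```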
aten/src/ATen/native/ReduceOps.cpp (Outdated)

```cpp
Tensor prod(const Tensor &self, optional<ScalarType> dtype) {
  Tensor result;
  native::prod_out(result, self, dim, keepdim, dtype);
  return at::native::prod_impl(result, self, {}, false, dtype);
```
you have two returns here?
aten/src/ATen/native/ReduceOps.cpp (Outdated)

```cpp
Tensor prod(const Tensor &self, ScalarType dtype) {
  return at::native::prod(self, {}, false, optional<ScalarType>(dtype));
}
// \ALL REDUCE ################################################################
```
these comments don't seem relevant anymore.
There was an attempt at some point but it seems to have been broken. Let me reorder by primary op.
aten/src/ATen/native/ReduceOps.cpp (Outdated)

```cpp
static Tensor& prod_out(Tensor& result, const Tensor& self, IntList dim,
                        bool keepdim, optional<ScalarType> opt_dtype) {
static Tensor& prod_impl(Tensor& result, const Tensor& self, IntList dim,
```
I don't really understand the organization here. It's not grouped by function name, which makes it difficult to figure out how a specific function is implemented. It's also not really grouped by implementation type, i.e. some functions have fn_impl, some fn_out, etc. I'd just go back to function name grouping.
Why is there a prod_impl but not a sum_impl? What's the distinction?
Because the signatures are different:
- `prod(Tensor self, int64_t dim, ...)`
- `sum(Tensor self, IntList[1] dim, ...)`

One of the sum native impls can be re-used by the others, but prod wasn't so lucky.
Let me take a stab at changing the prod signature to use `IntList[1]`. This should be backwards compatible.
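A hedged sketch of the `IntList[1]` idea: wrap a lone `int` dim into a one-element list so `prod` could share the list-based reduction path while old `prod(x, 0)` calls keep working. The reduction itself is stubbed out; this is not the ATen implementation:

```python
def prod(x, dim=None, keepdim=False):
    # Hypothetical wrapper: normalize the dim argument the way an
    # IntList[1] schema would, then hand off to a shared reduce path.
    if dim is None:
        dims = []            # empty list: reduce over all dimensions
    elif isinstance(dim, int):
        dims = [dim]         # IntList[1]: a bare int wraps into a 1-elem list
    else:
        dims = list(dim)
    # Stub standing in for the shared reduction kernel.
    return ("reduce", tuple(dims), keepdim)

assert prod([1, 2], 0) == prod([1, 2], [0])   # old int-dim calls still work
assert prod([1, 2]) == ("reduce", (), False)  # no dim means full reduce
```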
For the ordering, I tried to keep the existing code where it is except to separate by "dim", "nodim" as before. Let me just reorder everything by function name in a commit of its own.
Actually, the current organization makes sense in its own way. cumsum and cumprod are not really reduce functions, they are cumulative.
I'll still change to order by function, I already spent way too much mental effort for grasping what's going on here.
> Let me take a stab at changing the prod signature to use IntList[1]. This should be backwards compatible.

Punting on this because prod_backward needs to be rewritten to support multi-dim. This needs to be a separate PR.
I would recommend ordering by what you think makes sense (with an explanation); I was just describing why the current ordering didn't make sense to me.
```diff
 if arg.get('default') is not None:
     default = arg['default']
-    if default == 'nullptr' or default == 'nullopt' or default == '{}':
+    if default == 'nullptr' or default == 'c10::nullopt' or default == '{}':
```
I think I asked this before (but perhaps the comment went away because of the force push) -- why do you need the c10?
Ah, responded on #15154 for this.
`native_parse` marks the default as `c10::nullopt` and then that gets propagated everywhere.
Rather than adding a bunch of `using` statements in generated code, I'm replacing the previous check instead, since it isn't used for anything else. Verified by tracing this line in codegen.
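The check under discussion can be isolated as a tiny helper (hypothetical name; the real code is the inline comparison in the codegen snippet above). After the change, only the fully qualified `c10::nullopt` spelling counts as a "no value" default:

```python
def is_sentinel_default(default: str) -> bool:
    # Mirrors the updated codegen check: these defaults mean "no value"
    # and are treated specially when emitting wrapper code.
    return default in ("nullptr", "c10::nullopt", "{}")

assert is_sentinel_default("c10::nullopt")
assert not is_sentinel_default("nullopt")  # unqualified spelling no longer matches
assert not is_sentinel_default("None")
```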
I see; not a fan but you don't have to change it :).
test/test_jit.py (Outdated)

```python
return torch.add(a.cumsum(0, dtype=torch.long).sum(dtype=None),
                 b.cumprod(0, dtype=None).prod(dtype=torch.double))

a = torch.randn(4, 4, dtype=torch.float, requires_grad=True)
```
do a and b do anything here?
test/test_jit.py (Outdated)

```python
                         example_outputs=outputs)

def test_onnx_export_script_module_unsupported_optional(self):
    class ModuleToExport(torch.jit.ScriptModule):
```
how is this different than the test above?
This one is verifying that `dtype` can't be exported to ONNX and raises an exception.
The one above is verifying the overload switch I added to symbolic.py, i.e. `torch.prod(x)` matches the `prod(x, None)` export, but `torch.prod(x, 0)` matches the `prod(x, dim, keepdim, None)` export.
`test_onnx_export_script_module_overload_fail` and `test_onnx_export_script_module_unsupported_optional` are character-for-character identical.
Ah, lol. I wanted to test two different things and they ended up being the exact same code. I remember writing `with self.assertRaisesRegex` twice for this PR with some time in between.
Will remove one of them.
test/test_jit.py (Outdated)

```python
    return x

def test_onnx_export_script_module_overload(self):
    class ModuleToExport(torch.jit.ScriptModule):
```
didn't you clobber the test above?
The above test is verifying the `if` behavior. Did you mean this comment for the test below?
No, I don't see how test_onnx_export_script_module_if actually tests anything anymore. It doesn't instantiate a ModelToExport or run it through onnx, etc.
Ah, right sorry. Saw it now.
tools/autograd/derivatives.yaml (Outdated)

```diff
-- name: cumprod(Tensor self, int64_t dim)
-  self: cumprod_backward(grad, self, dim)
+- name: cumprod(Tensor self, int64_t dim, *, ScalarType? dtype)
+  self: cumprod_backward(grad.to(self.type()), self, dim)
```
For all these functions:
- does this generate an additional cast if we trace it for the case where a dtype isn't passed?
- also, can you use `scalar_type()` instead of `type()`?
It doesn't do a cast if a dtype has not been passed to the original forward call:
pytorch/aten/src/ATen/templates/TensorMethods.h, lines 12 to 16 in 517c7c9
Sorry, by generate an additional cast I meant "trace an additional cast".
Sorry for the noob questions. Do you mean JIT tracing the backwards path? In what circumstance will this happen? I've failed to create a test scenario so far.
Though, I assume it won't be traced if the sample input doesn't execute the copy path, which it won't.
```python
>>> torch.tensor(1.0, requires_grad=True).max().to(torch.double).grad_fn
<CopyBackwards object at 0x7f695dd3f898>
>>> torch.tensor(1.0, requires_grad=True).max().to(torch.float).grad_fn
<MaxBackward1 object at 0x7f695dd3f860>
```
facebook-github-bot left a comment
@tugrulates has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@gchanan This should be ready to go after fixing the conflicts. Let me know if there are more changes you'd like to see.

Giving up on this since.
Fixes #6593

This replaces the separately defined overloads for reduce functions that take `dtype`, so that they take this param by `ScalarType?` now that it is supported in schema. This simplifies much of the logic to support these overloads and introduces the pattern of taking `dtype` by kwargs for future needs.

This PR has some minor backward breaking changes:
- Positional `dtype` for `torch.softmax` and `torch.log_softmax`, i.e. `torch.softmax(x, 0, torch.float)` will fail and needs to be replaced by `torch.softmax(x, 0, dtype=torch.float)`.
- Serialized graphs that reduce with a default dtype (e.g. `sum(%x)`) will not work across versions because kwargs are lost in JIT. See changed expect files for graph changes.
- Cases where `retain_graph=False` and multiple backwards are computed on the same path. See the changed test for a real-life example of this.

It might be easier to review this commit-by-commit.
torch.softmaxandtorch.log_softmax, i.e.torch.softmax(x, 0, torch.float)will fail and need to be replaced bytorch.softmax(x, 0, dtype=torch.float).sum(%x)) will not work across versions because kwargs are lost in JIT. See changed expect files for graph changes.retain_graph=Falseand multiple backwards are computed on the same path. See the changed test for a real-life example of this.It might be easier to review this commit-by-commit.