
Optional dtype for reduce functions#15133

Closed
tugrulates wants to merge 22 commits into pytorch:master from tugrulates:optional-scalartype

Conversation


@tugrulates tugrulates commented Dec 12, 2018

Fixes #6593

This replaces the separately defined overloads of the reduce functions that take dtype with a single `ScalarType?` parameter, now that optional scalar types are supported in the schema. This simplifies much of the logic needed to support these overloads and introduces the pattern of taking dtype as a kwarg for future needs.

This PR has some minor backward breaking changes:

  • Passing dtype positionally is removed for torch.softmax and torch.log_softmax, i.e. torch.softmax(x, 0, torch.float) will fail and needs to be replaced by torch.softmax(x, 0, dtype=torch.float).
  • Calls within script to these functions without dtype (e.g. sum(%x)) will not work across versions, because kwargs are lost in the JIT. See the changed expect files for the graph changes.
  • Some autograd graphs will change from depth one to depth two, due to the type cast in the backward functions. This will break code that computes multiple backwards on the same path with retain_graph=False. See the changed test for a real-life example of this.
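The first breaking change above amounts to making dtype keyword-only. A minimal pure-Python sketch of the new calling convention (standalone, not using torch itself):

```python
def softmax(x, dim, *, dtype=None):
    """Sketch of the new schema: dtype is keyword-only.

    A real implementation would cast x to dtype before the reduction;
    here we just report what was requested."""
    return {"dim": dim, "dtype": dtype}

# New-style call works:
out = softmax([1.0, 2.0], 0, dtype="float32")
assert out["dtype"] == "float32"

# Old-style positional dtype now raises TypeError:
try:
    softmax([1.0, 2.0], 0, "float32")
    raised = False
except TypeError:
    raised = True
assert raised
```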

It might be easier to review this commit-by-commit.

@facebook-github-bot facebook-github-bot added the `oncall: jit` label on Dec 12, 2018
facebook-github-bot pushed a commit that referenced this pull request Dec 19, 2018
Summary:
For #6593 and #9515

This completes the support for optional<ScalarType> in native, JIT and autograd.

Note: Mostly following the existing implementation for optional<Scalar> that was added in #12582.

This PR introduces a way to make functions accept an optional dtype and it will unblock #9515 by allowing the `dtype` param for type promotion interface:
```
func: name(inputs, *, ScalarType? dtype=None, Casting casting=same_kind)
```

An alternative approach could have been using `ScalarType::Undefined` for the same purpose but without optional, though it would have been a bit hacky.
```
func: name(inputs, *, ScalarType dtype=Undefined, Casting casting=same_kind)
```
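The trade-off between the two schemas can be sketched in Python terms (the helper names here are illustrative, not the actual codegen): `None` expresses absence through the type system, while an `Undefined` sentinel is an in-band magic value every consumer must remember to check.

```python
from typing import Optional

UNDEFINED = "Undefined"  # stand-in for ScalarType::Undefined

def result_dtype_optional(input_dtype: str, dtype: Optional[str] = None) -> str:
    # Optional approach: "no dtype" is expressed by the type itself.
    return input_dtype if dtype is None else dtype

def result_dtype_sentinel(input_dtype: str, dtype: str = UNDEFINED) -> str:
    # Sentinel approach: "no dtype" is a magic value hiding inside the
    # ordinary value space.
    return input_dtype if dtype == UNDEFINED else dtype

assert result_dtype_optional("float32") == "float32"
assert result_dtype_optional("float32", "float64") == "float64"
assert result_dtype_sentinel("float32") == result_dtype_optional("float32")
```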

Here's an example use of this in action: 971f69e

There are already a bunch of native functions that were getting optional `dtype` through function overloading. #15133 is the attempt to migrate all of those. I will send those changes separately after this since some functions (e.g. sum) need quite a bit of change in the codebase. See the commits over there.
Pull Request resolved: #15154

Differential Revision: D13457760

Pulled By: tugrulates

fbshipit-source-id: 706134f0bd578683edd416b96329b49a1ba8ab48
  d_intermediates = [d_i for intermediates_batch in group(intermediates, shard_size)
-                    for d_i in torch.autograd.grad(loss, intermediates_batch)]
+                    for d_i in torch.autograd.grad(loss, intermediates_batch,
+                                                   retain_graph=True)]
Author

This is because the graph has one more depth due to type cast.

Contributor

as above, if we can avoid the casts and avoid doing this that seems ideal.

Author

The type cast will be no-op when dtype is not passed, but still adds another node to the graph. To avoid it altogether, we'll need to provide an overload to the native function without the dtype kwarg.

What is surprising is that retain_graph is not needed for the simplest backward cases. The documentation says the grad graph will be freed, yet this test still manages to use it more than once. This is not a problem specific to this PR; the quirk will surface with any change to any derivative where the node count goes from one to two.
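One way the quirk can arise (a toy sketch, not PyTorch's actual autograd implementation): a backward function that saves no buffers survives repeated calls even with retain_graph=False, while a deeper graph containing a buffer-saving node does not.

```python
class Node:
    """Toy autograd node; saved buffers are treated as freed after
    backward unless retain_graph is True (loosely mimicking autograd)."""
    def __init__(self, grad_fn, saved=None, next_node=None):
        self.grad_fn = grad_fn
        self.saved = saved
        self.freed = False
        self.next_node = next_node

    def backward(self, grad, retain_graph=False):
        if self.saved is not None and self.freed:
            raise RuntimeError("Trying to backward through the graph a second time")
        out = self.grad_fn(grad, self.saved)
        if not retain_graph:
            self.freed = True
        if self.next_node is not None:
            out = self.next_node.backward(out, retain_graph)
        return out

# A backward like sum's needs no saved buffers: re-running is harmless.
sum_node = Node(lambda g, s: g)
assert sum_node.backward(1.0) == 1.0
assert sum_node.backward(1.0) == 1.0  # still works without retain_graph

# Add a buffer-using node in front (like an extra node that saved state):
deep = Node(lambda g, s: g * s, saved=2.0, next_node=Node(lambda g, s: g))
assert deep.backward(1.0) == 2.0
try:
    deep.backward(1.0)  # buffers freed; second backward now fails
    ok = False
except RuntimeError:
    ok = True
assert ok
```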

Contributor

it's the addition of the None node that causes you to need retain_graph?

Author

This is now ready to review.

@tugrulates tugrulates changed the title from "[Not ready for review] Optional scalartype for native functions" to "Optional dtype for reduce functions" on Dec 20, 2018
});
return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr};

} else if (node->matches("aten::log_softmax(Tensor self, int dim) -> Tensor")) {
Author

This is already working fine through dispatch and derivatives.yaml. Rather than updating it here, I'm simply dropping the override.

Author

Never mind, this seems to be used in symbolic diff. Will bring it back.

variants: function, method

- func: softmax(Tensor self, int64_t dim) -> Tensor
- func: softmax(Tensor self, int64_t dim, *, ScalarType? dtype=None) -> Tensor
Collaborator

For softmax and log_softmax, why do we want to make it a kw only argument instead of simply merging them together? It is indeed a BC change and I don't know if we want to do that or not.

Author

@tugrulates tugrulates commented Dec 20, 2018

It's mainly for the sake of consistency with other ops.

Some pro arguments are:

  • These two are undocumented functions. The public interface is through torch.nn.functional.[log_]softmax which already have dtype as a kwarg.
  • The kwarg and non-kwarg forms share the same C++ signature, so it's tricky to make them live side by side for a deprecation period.
  • If this change is ever to be made, it's better to do it sooner than later.
  • It's a better baseline if more kwargs are ever added across multiple ops (casting, where, etc.).
  • The fix is trivial if any code is broken by this.

Collaborator

Sounds good; as long as it's only exposed from nn.functional, I think it would be fine :)

%14 : Tensor = aten::gt(%13, %11)
%15 : bool = prim::Bool(%14)
%16 : Tensor, %17 : Tensor, %a : Tensor, %19 : Tensor, %20 : Tensor = prim::Loop(%7, %15, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
%13 : int? = prim::None()
Contributor

these changes are pretty noisy and I'm assuming these tests already existed (so your casts are all no-ops). Can we just avoid actually calling the casts if the casts are no-op to preserve these? Is that possible?

Author

@tugrulates tugrulates commented Dec 20, 2018

This is not calling a cast, but adding a prim::None to the sum symbol on line 15 below.

This is because this PR is removing the overloads without the dtypes, which is kind of the purpose of it.

Since the generated graphs include the default values from the schema, this will keep happening with every new kwarg added to ops. Two ways to fix this:

  1. Keep overloads in native_functions.yaml and keep adding them with every new arg. (makes this PR unneeded)
  2. Make JIT apply defaults at execution time, perhaps confined to kwargs only. (much larger scope than what this PR is trying to achieve)
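The serialization problem described above can be sketched in a few lines (the schema table and function names here are hypothetical, purely for illustration): because defaults are expanded at compile time, the emitted graph records arguments the caller never wrote, so adding a kwarg to the schema changes every serialized graph.

```python
SCHEMA_DEFAULTS = {"sum": {"dtype": None}}  # hypothetical schema table

def emit_graph(op, *args, **kwargs):
    # Defaults are applied at compile time: the emitted graph records
    # every argument, including ones the caller never wrote.
    full = dict(SCHEMA_DEFAULTS[op])
    full.update(kwargs)
    return (op, args, tuple(sorted(full.items())))

graph = emit_graph("sum", "%x")
assert graph == ("sum", ("%x",), (("dtype", None),))

# A graph emitted before the dtype kwarg existed would have carried no
# dtype entry, so identical user code yields different serialized graphs
# across schema versions.
old_graph = ("sum", ("%x",), ())
assert old_graph != graph
```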

  d_intermediates = [d_i for intermediates_batch in group(intermediates, shard_size)
-                    for d_i in torch.autograd.grad(loss, intermediates_batch)]
+                    for d_i in torch.autograd.grad(loss, intermediates_batch,
+                                                   retain_graph=True)]
Contributor

as above, if we can avoid the casts and avoid doing this that seems ideal.

Tensor prod(const Tensor &self, optional<ScalarType> dtype) {
Tensor result;
native::prod_out(result, self, dim, keepdim, dtype);
return at::native::prod_impl(result, self, {}, false, dtype);
Contributor

you have two returns here?

Tensor prod(const Tensor &self, ScalarType dtype) {
  return at::native::prod(self, {}, false, optional<ScalarType>(dtype));
}
// \ALL REDUCE ################################################################
Contributor

these comments don't seem relevant anymore.

Author

There was an attempt at some point but it seems to have been broken. Let me reorder by primary op.


static Tensor& prod_out(Tensor& result, const Tensor& self, IntList dim,
bool keepdim, optional<ScalarType> opt_dtype) {
static Tensor& prod_impl(Tensor& result, const Tensor& self, IntList dim,
Contributor

I don't really understand the organization here. It's not grouped by function name, which makes it difficult to figure out how a specific function is implemented. It's also not really grouped by implementation type, i.e. some functions have fn_impl, some fn_out, etc. I'd just go back to function name grouping.

Contributor

Why is there a prod_impl but not a sum_impl? What's the distinction?

Author

@tugrulates tugrulates commented Dec 20, 2018

Because the signatures are different:

  • prod(Tensor self, int64_t dim, ...
  • sum(Tensor self, IntList[1] dim, ...

One of the sum native impls can be reused by the others, but prod wasn't so lucky.

Let me take a stab at changing the prod signature to use IntList[1]. This should be backwards compatible.
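The unification being considered can be sketched as wrapping the scalar dim in a one-element list so prod could reuse a shared multi-dim impl (the function names below are illustrative, not the actual ATen signatures):

```python
def reduce_impl(x, dims, keepdim=False, dtype=None):
    # Shared impl taking a list of dims, like the sum(IntList[1] dim, ...)
    # signature.
    return ("reduce", tuple(dims), keepdim, dtype)

def prod(x, dim, keepdim=False, dtype=None):
    # prod's public signature takes a single int dim; wrapping it in a
    # one-element list lets it reuse the shared impl, which is backwards
    # compatible from the caller's point of view.
    return reduce_impl(x, [dim], keepdim, dtype)

assert prod("x", 0) == ("reduce", (0,), False, None)
```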

Author

@tugrulates tugrulates commented Dec 20, 2018

For the ordering, I tried to keep the existing code where it is except to separate by "dim", "nodim" as before. Let me just reorder everything by function name in a commit of its own.

Author

Actually, the current organization makes sense in its own way. cumsum and cumprod are not really reduce functions, they are cumulative.

I'll still change it to ordering by function name; I already spent way too much mental effort grasping what's going on here.

Author

Let me take a stab at changing the prod signature to use IntList[1]. This should be backwards compatible.

Punting on this because prod_backward needs to be rewritten to support multi-dim. This needs to be a separate PR.

Contributor

I would recommend ordering by what you think makes sense (with an explanation); I was just describing why the current ordering didn't make sense to me.

if arg.get('default') is not None:
default = arg['default']
if default == 'nullptr' or default == 'nullopt' or default == '{}':
if default == 'nullptr' or default == 'c10::nullopt' or default == '{}':
Contributor

I think I asked this before (but perhaps the comment went away because of the force push) -- why do you need the c10?

Author

@tugrulates tugrulates commented Dec 20, 2018

Ah, responded on #15154 for this.

native_parse marks the default as c10::nullopt and then that gets propagated everywhere.

Rather than adding a bunch of using statements to the generated code, I'm replacing the previous check instead, since that spelling isn't used for anything. Verified by tracing this line in the codegen.
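The check being discussed can be sketched as a small predicate (illustrative only; the actual logic lives in the codegen scripts):

```python
def is_empty_default(default: str) -> bool:
    # native_parse emits the qualified spelling 'c10::nullopt', so the
    # predicate matches that rather than the bare 'nullopt'.
    return default in ("nullptr", "c10::nullopt", "{}")

assert is_empty_default("c10::nullopt")
assert is_empty_default("{}")
assert not is_empty_default("nullopt")  # old spelling no longer produced
```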

Contributor

I see; not a fan but you don't have to change it :).

test/test_jit.py Outdated
return torch.add(a.cumsum(0, dtype=torch.long).sum(dtype=None),
b.cumprod(0, dtype=None).prod(dtype=torch.double))

a = torch.randn(4, 4, dtype=torch.float, requires_grad=True)
Contributor

do a and b do anything here?

test/test_jit.py Outdated
example_outputs=outputs)

def test_onnx_export_script_module_unsupported_optional(self):
class ModuleToExport(torch.jit.ScriptModule):
Contributor

how is this different than the test above?

Author

@tugrulates tugrulates commented Dec 20, 2018

This one is verifying that dtype can't be exported to ONNX, and raises an exception.

The one above verifies the overload switch I added to symbolic.py, i.e. torch.prod(x) matches the prod(x, None) export, while torch.prod(x, 0) matches the prod(x, dim, keepdim, None) export.
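The overload switch can be sketched as dispatching on which arguments were provided (a hypothetical Python stand-in; the returned tuples merely label which export path was taken, not real ONNX node construction):

```python
def prod_symbolic(x, dim=None, keepdim=None, dtype=None):
    # Exporting a dtype kwarg to ONNX is unsupported, so it raises.
    if dtype is not None:
        raise RuntimeError("ONNX export with dtype is unsupported")
    if dim is None:
        return ("ReduceProd-all", x)            # torch.prod(x) path
    return ("ReduceProd-dim", x, dim, keepdim)  # torch.prod(x, dim) path

assert prod_symbolic("%x") == ("ReduceProd-all", "%x")
assert prod_symbolic("%x", 0, False) == ("ReduceProd-dim", "%x", 0, False)
```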

Contributor

test_onnx_export_script_module_overload_fail and test_onnx_export_script_module_unsupported_optional are character-for-character identical.

Author

Ah, lol. I wanted to test two different things and they ended up being the exact same code. I remember writing self.assertRaisesRegex twice for this PR with some time in between.

Will remove one of them.

test/test_jit.py Outdated
return x

def test_onnx_export_script_module_overload(self):
class ModuleToExport(torch.jit.ScriptModule):
Contributor

didn't you clobber the test above?

Author

The above test is verifying the if behavior. Did you mean this comment for the test below?

Contributor

No, I don't see how test_onnx_export_script_module_if actually tests anything anymore. It doesn't instantiate a ModelToExport or run it through onnx, etc.

Author

Ah, right, sorry. I see it now.

- name: cumprod(Tensor self, int64_t dim)
self: cumprod_backward(grad, self, dim)
- name: cumprod(Tensor self, int64_t dim, *, ScalarType? dtype)
self: cumprod_backward(grad.to(self.type()), self, dim)
Contributor

For all these functions:

  • does this generate an additional cast if we trace it for the case where a dtype isn't passed?

  • Also, can you use scalar_type() instead of type()?

Author

It doesn't do a cast if a dtype has not been passed to the original forward call:

inline Tensor Tensor::toType(const Type & t, bool non_blocking) const {
  if (type() == t)
    return *this;
  return t.copy(*this, non_blocking);
}

Contributor

Sorry, by generate an additional cast I meant "trace an additional cast".

Author

Sorry for the noob questions. Do you mean JIT tracing the backwards path? In what circumstance will this happen? I've failed to create a test scenario so far.

Though, I assume it won't be traced if the sample input doesn't execute the copy path, which it won't.

>>> torch.tensor(1.0, requires_grad=True).max().to(torch.double).grad_fn
<CopyBackwards object at 0x7f695dd3f898>
>>> torch.tensor(1.0, requires_grad=True).max().to(torch.float).grad_fn
<MaxBackward1 object at 0x7f695dd3f860>

Contributor

@facebook-github-bot facebook-github-bot left a comment

@tugrulates has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Author

@gchanan This should be ready to go after fixing the conflicts. Let me know if there are more changes you'd like to see.

Author

Giving up on this since.

@tugrulates tugrulates closed this Jan 27, 2019