Deprecates current torch.full integral type inference, adds torch.full complex type inference #34709
Conversation
| "Set the optional dtype or out arguments to suppress this warning. " | ||
| ); | ||
| } else if (fill_value.isComplex()) { | ||
| auto scalar_type = (get_default_dtype() == ScalarType::Double) ? |
you can just use the get_default_complex_dtype once this PR: https://github.com/pytorch/pytorch/pull/34093/files#diff-2f560eadec29b69291ea551b2eea94a4R20 is landed
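For reference, a minimal Python-level sketch of the selection this branch performs (the helper name is illustrative, not part of the PR):

```python
import torch

# Complex inference tracks the default floating dtype: complex128 when the
# default is float64 (Double), complex64 otherwise.
def complex_dtype_for_default():
    return torch.complex128 if torch.get_default_dtype() == torch.float64 else torch.complex64
```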
    Returns a tensor of size :attr:`size` filled with :attr:`fill_value`.

    .. warning::
        In PyTorch 1.5 integral `fill_value`s will produce a warning if `out`
do you think this is too conservative for this change? I feel like this API isn't used that much and maybe we should speed things up. What is the fbcode usage?
@gchanan Mike and I talked this over a bit. I can see the argument for skipping the warning and going straight to breaking if the usage is low, though personally I'm still not sure I'd do it so close to the 1.5 cut.
Is skipping the warning the "speed things up" you were thinking of? We were trying to channel your worldview but you know, imperfect vessels 😁
There are 199 hits searching torch.full in FBCode, which is low compared to other functions (like max) we've looked at. Many of the uses actually specify the dtype, too, and for those that don't, an easy fix is simply to add the dtype argument.
Stepping back: let's talk about concrete deprecation options. Seems like there are three deprecation actions we can take: (1) warn, like this PR does, (2) error, like this PR proposes we do in 1.6, (3) change, like this PR proposes we do in 1.7.
Warning (1) is desirable because it gives people an opportunity to change without their network breaking. Erroring (2) is critical for preventing silent failures. And changing (3) is the entire point. If we refuse to accept the possibility of silently breaking a network, then I think we have to do (2) and (3). @bhosmer and I constructed the following scenario: we implement the change, and a user gets a LongTensor where they expected a FloatTensor. They divide the LongTensor by another LongTensor; this performs integral, not true, division, so they get a LongTensor result with different values than the FloatTensor they had previously. This LongTensor is then added to another FloatTensor, producing a FloatTensor (as expected), but with different values than before. This is a silent failure (sketched below).
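An illustrative sketch of that scenario, assuming integer tensors still divide integrally under torch.div (the 1.5-era behavior; the values are arbitrary):

```python
import torch

# After the change, an unannotated integral fill produces a LongTensor.
a = torch.full((2,), 4)     # LongTensor([4, 4]); previously FloatTensor([4., 4.])
b = torch.tensor([3, 3])
q = torch.div(a, b)         # integral division gives LongTensor([1, 1]);
                            # the old float path gave [1.3333, 1.3333]
out = q + torch.zeros(2)    # a FloatTensor, as expected, but its values
                            # now silently differ from before
```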
If we accept that argument, the only speedup to consider is whether we warn or error in 1.5. I lean warn in 1.5 because (A) we're close to the snap and (B) we haven't formalized our guidance for when we don't need to warn. We also plan, for example, to have torch.div warn about performing integer division in 1.5, error in 1.6, and change behavior in 1.7.
Lastly, we can try to have our cake and eat it, too. That is, we could implement warn as our baseline, then look at a PR that causes the behavior to error and changes the necessary call sites. If that PR comes together in time and we think it's OK to skip a warning release we could land it.
@bhosmer @mruberry : good discussion.
I'm not sure I would call warning -> changed behavior a silent change of behavior, because it's literally not silent :). But it's obviously not ideal either; I'm just concerned that 6 months is a long time over which to be making these changes. I can only speak for myself, but I don't use torch.full because I don't want to have to think about broken type inference, and 6 months is a long time to wait.
I also buy the argument that we are too close to the release branch cut to be confident in making it an error.
I'd propose we "try to have our cake and eat it, too":
- land this with just some ambiguity around the timing, i.e. "torch.full is deprecated and its behavior will change in a later release"
- depending on priorities, we try to upgrade the warning to an error in 1.5. Honestly, 200 occurrences is more than I was expecting; it would be nice to have the number that don't specify a dtype.
Sound good?
I updated the documentation and warning to be vaguer about future plans.
one thing this makes me realize is -- full_like makes questionable sense. Do you expect it to use the "like" dtype or the dtype inferred from the fill value (or some precedence)?
    torch.full(size, 1, names=('a', 'b'), dtype=torch.float)

    # Tests complex inference
    self.assertTrue(torch.full(size, (1 + 1j)).dtype == torch.complex64)
if you want to do this test now, you should probably test the different default dtypes.
oh, you do that below? But why not with the annotation?
There's not a great existing annotation for this since I'm setting torch.half as the default scalar type and the early part of the test is dtype agnostic. I can split the test and use the dtype annotation, however, which will make it less repetitive.
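A minimal sketch of that split, assuming torch.set_default_dtype and plain asserts (the actual tests in test_torch.py may be structured differently):

```python
import torch

def test_full_complex_inference_float_default():
    torch.set_default_dtype(torch.float32)
    assert torch.full((2, 2), 1 + 1j).dtype == torch.complex64

def test_full_complex_inference_double_default():
    torch.set_default_dtype(torch.float64)
    try:
        assert torch.full((2, 2), 1 + 1j).dtype == torch.complex128
    finally:
        torch.set_default_dtype(torch.float32)  # restore the default
```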
Nice work! Adding a few takeaways from our offline review inline...
@gchanan per offline discussion w/Mike, he says numpy's …
💊 CircleCI build failures summary and remediations

As of commit 73329d5 (more details on the Dr. CI page): ✅ None of the build failures appear to be your fault 💚

❄️ 1 failure tentatively classified as flaky; reruns have not been triggered to confirm.
@bhosmer see my question above about `full_like`.
just to follow up on the precedence issue with full_like: it takes its dtype from the input tensor, so the fill_value can't possibly change the dtype. Which appears to be the right behavior, but not really part of the precedence.
Ah, makes sense.
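A quick illustration of that precedence (straightforward torch.full_like usage; the tensors are arbitrary):

```python
import torch

base = torch.ones(2, dtype=torch.float32)
t = torch.full_like(base, 5)     # integral fill_value, but...
assert t.dtype == torch.float32  # ...the "like" tensor's dtype wins
```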
| "Deprecation warning: In PyTorch 1.6 torch.full with a bool or ", | ||
| "integral fill_value will require a dtype or out argument. ", | ||
| "In PyTorch 1.7, when `out` and `dtype` are not set a bool fill_value ", | ||
| "will return a tensor of torch.bool dtype, ", |
I think we should give separate messages depending on whether the user called the out variant or not.
The reasoning is basically, I don't think the out variants are a particularly good idea -- because these are factory functions, they specify properties of the memory of the tensor, i.e. the memory format. What does it mean for the memory format to be one thing, but the out argument to be another? Best if we limit the usage of the out= variants.
The error message for setting the out= kwarg improperly is in python_torch_functions.py. I added a test for the behavior.
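A hedged sketch of what such a test might look like (the exact error type and message are assumptions; the real test lives in test_torch.py):

```python
import torch

# A complex fill_value cannot be written into a float out= tensor.
out = torch.empty((1, 2), dtype=torch.float)
try:
    torch.full((1, 2), 1 + 1j, out=out)
except RuntimeError as e:  # assumed error type
    print("raised as expected:", e)
```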
gchanan left a comment
this looks pretty close to good to go. I think the action items are:
- separate out the warnings for the out= vs not case.
- (optionally) make the docs/warning timing more ambiguous, since we aren't sure yet.
- make sure we have a test for full_like precedence inference.
Force-pushed d328d1e to 18b74fd.
facebook-github-bot left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Related: #34210, in the part: does …
partially addresses #27171
From offline discussion: this itself isn't BC-breaking, but we will break BC sometime in the next release. I'm removing the label to help better keep track of this.
…l complex type inference (pytorch#34709)

Pull Request resolved: pytorch#34709
Differential Revision: D20509387
Pulled By: mruberry
fbshipit-source-id: 129593ba06a1662032bbbf8056975eaa59baf933
UPDATE
BC-breaking release note info:
In a future PyTorch release, torch.full will infer its dtype from its fill value when the optional dtype and out parameters are unspecified. For example, torch.full(size, 1) will return a tensor of torch.long dtype, unlike today where it returns a tensor of torch.float dtype.
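For forward compatibility, callers can make the result dtype explicit today; a minimal example:

```python
import torch

# Unaffected by the inference change in any release:
t = torch.full((2, 3), 1, dtype=torch.float)
assert t.dtype == torch.float32
```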
// ------- (PR info below) ------- //
Per title.
Currently torch.full will always (attempt to) produce a float tensor. This is inconsistent with NumPy in (at least) two cases:

- When integral fill values (including bool) are given
- When complex fill values are given

For example:

```
np.full((1, 2), 1).dtype        : dtype('int64')
np.full((1, 2), (1 + 1j)).dtype : dtype('complex128')
```

Whereas in PyTorch

```
torch.full((1, 2), 1).dtype        : torch.float32
torch.full((1, 2), (1 + 1j)).dtype : RuntimeError: value cannot be converted to type float without overflow: (1,1)
```
This PR begins the process of deprecating our current behavior of returning float tensors (by default) when given integer fill values by warning the user that integer fill values will require explicitly specifying the dtype or out kwargs in 1.6, and in 1.7 the behavior will change to return a LongTensor by default (BoolTensor for bool values). The intermediate 1.6 release is to prevent changing the behavior silently and unexpectedly.
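A sketch of the call-site experience under this PR's 1.5 behavior (uses the standard warnings module; whether a warning actually fires depends on the installed version):

```python
import warnings
import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.full((1, 2), 1)                    # integral fill, no dtype/out: warns
    torch.full((1, 2), 1, dtype=torch.long)  # explicit dtype: no warning
```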
The PR also implements inference for complex types. So that with it:

```
torch.full((1, 2), (1 + 1j)).dtype : torch.complex64
```
The complex type inference returns a ComplexFloat tensor when given a complex fill value (and no dtype or out kwarg is specified), unless the default dtype is Double, in which case a ComplexDouble tensor is returned.
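Concretely, a small demonstration of that rule using torch.set_default_dtype:

```python
import torch

torch.set_default_dtype(torch.float64)
print(torch.full((1, 2), 1 + 1j).dtype)  # torch.complex128 (ComplexDouble)

torch.set_default_dtype(torch.float32)
print(torch.full((1, 2), 1 + 1j).dtype)  # torch.complex64 (ComplexFloat)
```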
A test for these behaviors is added to test_torch.py.
Implementation note:
This PR required customizing full's dispatch because currently in eager codegen the TensorOptions object passed to functions improperly sets has_dtype() to true, even if the user did not explicitly provide a dtype. torch.arange already worked around this issue with its own custom implementation. The JIT, however, does pass a properly constructed TensorOptions object.
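A hypothetical Python-level rendering of the rule the custom dispatch implements (infer_full_dtype is an illustrative name, not a function in the codebase):

```python
import torch

def infer_full_dtype(fill_value, dtype=None):
    if dtype is not None:                # an explicit dtype is always respected
        return dtype
    if isinstance(fill_value, complex):  # complex inference added by this PR
        if torch.get_default_dtype() == torch.float64:
            return torch.complex128
        return torch.complex64
    return torch.get_default_dtype()     # deprecated float default; integral
                                         # fill values warn on this path
```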
Future Work:
This PR does not extend torch.full's complex type inference to ONNX. This seems unlikely to come up and will be a clear error if it does. When integer type inference is added to torch.full, however, then porting the behavior to ONNX may be warranted. torch.arange ported its complex type promotion logic to ONNX, for example.
Additionally, this PR mostly leaves existing call sites in PyTorch that would trigger this warning intact. This is to be more minimal (since the PR is BC breaking). I will submit a separate PR fixing PyTorch's call sites.