Deprecates current torch.full integral type inference, adds torch.full complex type inference#34709

Closed
mruberry wants to merge 10 commits into master from full_type_inference

Conversation


@mruberry mruberry commented Mar 13, 2020

UPDATE

BC-breaking release note info:

In a future PyTorch release, torch.full will infer its dtype from its fill value when the optional dtype and out parameters are unspecified. For example, torch.full(size, 1) will return a tensor of torch.long dtype, unlike today where it returns a tensor of torch.float dtype.
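
To illustrate (a minimal sketch; the post-change results appear as comments since they apply only once the deprecation completes):

```python
import torch

# Current behavior: the fill value's type is ignored and the default
# dtype (torch.float32) is used.
print(torch.full((1, 2), 1).dtype)        # torch.float32

# Future behavior: the dtype is inferred from the fill value.
# torch.full((1, 2), 1).dtype     -> torch.int64 (torch.long)
# torch.full((1, 2), True).dtype  -> torch.bool

# Passing dtype explicitly behaves the same before and after the change:
print(torch.full((1, 2), 1, dtype=torch.long).dtype)  # torch.int64
```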

// ------- (PR info below) ------- //

Per title.

Currently torch.full will always (attempt to) produce a float tensor. This is inconsistent with NumPy in (at least) two cases:

  • When integral fill values (including bool) are given
  • When complex fill values are given

For example:

```
np.full((1, 2), 1).dtype
: dtype('int64')

np.full((1, 2), (1 + 1j)).dtype
: dtype('complex128')
```

Whereas in PyTorch:

```
torch.full((1, 2), 1).dtype
: torch.float32

torch.full((1, 2), (1 + 1j)).dtype
: RuntimeError: value cannot be converted to type float without overflow: (1,1)
```

This PR begins the process of deprecating our current behavior of returning float tensors (by default) when given integer fill values by warning the user that integer fill values will require explicitly specifying the dtype or out kwargs in 1.6, and in 1.7 the behavior will change to return a LongTensor by default (BoolTensor for bool values). The intermediate 1.6 release is to prevent changing the behavior silently and unexpectedly.
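
Concretely, specifying the dtype (or out) explicitly keeps code warning-free through the transition; a small sketch:

```python
import torch

size = (1, 2)

# Keeps today's float behavior and suppresses the deprecation warning:
print(torch.full(size, 1, dtype=torch.float).dtype)    # torch.float32

# Opts in to the future inferred dtypes explicitly:
print(torch.full(size, 1, dtype=torch.long).dtype)     # torch.int64
print(torch.full(size, True, dtype=torch.bool).dtype)  # torch.bool
```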

The PR also implements dtype inference for complex fill values, so that with it:

```
torch.full((1, 2), (1 + 1j)).dtype
: torch.complex64
```

The complex type inference returns a ComplexFloat tensor when given a complex fill value (and no dtype or out kwarg is specified), unless the default dtype is Double, in which case a ComplexDouble tensor is returned.
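
A small usage sketch of that rule (assuming only torch.set_default_dtype to flip the default):

```python
import torch

# With the usual float default dtype, complex fill values infer ComplexFloat:
print(torch.full((1, 2), 1 + 1j).dtype)   # torch.complex64

# With a double default dtype, they infer ComplexDouble:
torch.set_default_dtype(torch.float64)
print(torch.full((1, 2), 1 + 1j).dtype)   # torch.complex128
torch.set_default_dtype(torch.float32)    # restore the usual default
```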

A test for these behaviors is added to test_torch.py.

Implementation note:

This PR required customizing full's dispatch because currently in eager codegen the TensorOptions object passed to functions improperly sets has_dtype() to true, even if the user did not explicitly provide a dtype. torch.arange already worked around this issue with its own custom implementation. The JIT, however, does pass a properly constructed TensorOptions object.
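
As a loose Python analogy of the dispatch problem (purely illustrative; the real fix lives in the C++ codegen, and full_sketch is a hypothetical name):

```python
from typing import Optional

import torch

def full_sketch(size, fill_value, dtype: Optional[torch.dtype] = None):
    # The crux: "no dtype given" must be distinguishable from an explicit
    # dtype. The eager codegen's TensorOptions lost that distinction
    # (has_dtype() was always true), hence full's custom dispatch.
    if dtype is None:
        if isinstance(fill_value, complex):
            dtype = (torch.complex128
                     if torch.get_default_dtype() == torch.float64
                     else torch.complex64)
        else:
            dtype = torch.get_default_dtype()  # current (deprecated) behavior
    return torch.empty(size, dtype=dtype).fill_(fill_value)
```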

Future Work:

This PR does not extend torch.full's complex type inference to ONNX. This seems unlikely to come up and will be a clear error if it does. When integer type inference is added to torch.full, however, then porting the behavior to ONNX may be warranted. torch.arange ported its complex type promotion logic to ONNX, for example.

Additionally, this PR mostly leaves existing call sites in PyTorch that would trigger this warning intact. This is to be more minimal (since the PR is BC breaking). I will submit a separate PR fixing PyTorch's call sites.

@mruberry mruberry added the module: bc-breaking, module: complex, and module: numpy labels on Mar 13, 2020
"Set the optional dtype or out arguments to suppress this warning. "
);
} else if (fill_value.isComplex()) {
auto scalar_type = (get_default_dtype() == ScalarType::Double) ?
Contributor

You can just use `get_default_complex_dtype` once this PR is landed: https://github.com/pytorch/pytorch/pull/34093/files#diff-2f560eadec29b69291ea551b2eea94a4R20

Comment thread: aten/src/ATen/native/TensorFactories.cpp (outdated)
Comment thread: tools/autograd/gen_python_functions.py (outdated)
Comment thread: aten/src/ATen/native/TensorFactories.cpp
Comment thread: aten/src/ATen/native/TensorFactories.cpp
Comment thread: torch/_torch_docs.py (outdated)
```
Returns a tensor of size :attr:`size` filled with :attr:`fill_value`.

.. warning::
    In PyTorch 1.5 integral `fill_value`s will produce a warning if `out`
```
Contributor

do you think this is too conservative for this change? I feel like this API isn't used that much and maybe we should speed things up. What is the fbcode usage?


@gchanan Mike and I talked this over a bit. I can see the argument for skipping the warning and going straight to breaking if the usage is low, though personally I'm still not sure I'd do it so close to the 1.5 cut.

Is skipping the warning the "speed things up" you were thinking of? We were trying to channel your worldview but you know, imperfect vessels 😁

Collaborator Author

There are 199 hits searching torch.full in FBCode, which is low compared to other functions (like max) we've looked at. Many of the uses actually specify the dtype, too, and for those that don't, the easy fix is simply to add an explicit dtype argument.

Stepping back: let's talk about concrete deprecation options. Seems like there are three deprecation actions we can take: (1) warn, like this PR does, (2) error, like this PR proposes we do in 1.6, (3) change, like this PR proposes we do in 1.7.

Warning (1) is desirable because it gives people an opportunity to change without their network breaking. Erroring (2) is critical for preventing silent failures. And changing (3) is the entire point. If we refuse to accept the possibility of silently breaking a network, then I think we have to do (2) and (3). @bhosmer and I constructed the following scenario: we implement the change, a user gets a LongTensor expecting a FloatTensor, they divide the LongTensor by another LongTensor, this performs integral, not true, division, and they get a LongTensor result with different values than the FloatTensor they had previously. This LongTensor is then added to another FloatTensor, producing a FloatTensor (as expected), but with different values than before. This is a silent failure.
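
A sketch of that scenario under the integer-division semantics of the time (the hand-built long tensors stand in for what torch.full would return after the change):

```python
import torch

a = torch.tensor([6, 6])   # stand-in for a post-change torch.full LongTensor
b = torch.tensor([4, 4])

q = a / b                  # 1.5-era integer division: tensor([1, 1])
# With the old FloatTensors this was tensor([1.5000, 1.5000]).

out = q + torch.tensor([0.25, 0.25])
# out is a FloatTensor as expected, but holds tensor([1.2500, 1.2500])
# instead of tensor([1.7500, 1.7500]) -- a silent failure.
```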

If we accept that argument, the only speedup to consider is whether we warn or error in 1.5. I lean warn in 1.5 because (A) we're close to the snap and (B) we haven't formalized our guidance for when we don't need to warn. We also plan, for example, to warn when torch.div performs integer division in 1.5, make it error in 1.6, and change torch.div's behavior in 1.7.

Lastly, we can try to have our cake and eat it, too. That is, we could implement warn as our baseline, then look at a PR that causes the behavior to error and changes the necessary call sites. If that PR comes together in time and we think it's OK to skip a warning release we could land it.

@gchanan gchanan Mar 17, 2020

@bhosmer @mruberry : good discussion.

I'm not sure I would call warning -> changed behavior a silent change of behavior, because it's literally not silent :). But it's obviously not ideal either, I'm just concerned that 6 months is a long time over which to be making these changes. I can only speak for myself, but I don't use torch.full because I don't want to have to think about broken type inference and 6 months is a long time to wait.

I also buy the argument that we are too close to the release branch cut to be confident in making it an error.

I'd propose we "try to have our cake and eat it, too":

  1. land this with just some ambiguity around the timing, i.e. "torch.full is deprecated and its behavior will change in a later release"
  2. depending on priorities, we try to upgrade the warning to an error in 1.5. Honestly, 200 occurrences is more than I was expecting; it would be nice to have the number that don't specify a dtype.

Sound good?

Collaborator Author

I updated the documentation and warning to be vaguer about future plans.

Comment thread: tools/pyi/gen_pyi.py

gchanan commented Mar 13, 2020

one thing this makes me realize is -- full_like makes questionable sense. Do you expect it to use the "like" dtype or the dtype of the inference value (or some precedence)?

Comment thread: test/test_torch.py (outdated)

```python
torch.full(size, 1, names=('a', 'b'), dtype=torch.float)

# Tests complex inference
self.assertTrue(torch.full(size, (1 + 1j)).dtype == torch.complex64)
```
Contributor

if you want to do this test now, you should probably test the different default dtypes.

Contributor

oh, you do that below? But why not with the annotation?

Collaborator Author

There's not a great existing annotation for this since I'm setting torch.half as the default scalar type and the early part of the test is dtype agnostic. I can split the test and use the dtype annotation, however, which will make it less repetitive.


bhosmer commented Mar 14, 2020

Nice work! Adding a few takeaways from our offline review inline...

Comment thread: tools/autograd/templates/python_torch_functions.cpp
Comment thread: torch/_torch_docs.py (outdated)


bhosmer commented Mar 14, 2020

> one thing this makes me realize is -- full_like makes questionable sense. Do you expect it to use the "like" dtype or the dtype of the inference value (or some precedence)?

@gchanan per offline discussion w/Mike, he says numpy's full_like uses the precedence order fill_value < self < dtype. On a quick check, it looks like we do too. (So I think the upshot is: because fill_value is beaten by self in case of conflict, full_like doesn't have the dtype issue that this PR fixes for full: the relationship between self and dtype is already handled properly.)
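
A quick sketch of that precedence check (assuming NumPy/PyTorch behavior of the time):

```python
import numpy as np
import torch

t = torch.zeros(2, dtype=torch.float32)
print(torch.full_like(t, 2).dtype)                     # torch.float32: self beats fill_value
print(torch.full_like(t, 2, dtype=torch.int64).dtype)  # torch.int64: dtype beats self

a = np.zeros(2, dtype=np.float32)
print(np.full_like(a, 2).dtype)                        # float32: same precedence
print(np.full_like(a, 2, dtype=np.int64).dtype)        # int64
```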


dr-ci Bot commented Mar 15, 2020

💊 CircleCI build failures summary and remediations

As of commit 73329d5 (more details on the Dr. CI page):


None of the build failures appear to be your fault 💚


  • 1/2 tentatively recognized as flaky ❄️

  • 1/2 broken upstream at merge base a1eaaea since Mar 18

Please rebase on the viable/strict branch:

If your commit is newer than viable/strict, you can try basing on an older, stable commit:

```
git fetch https://github.com/pytorch/pytorch viable/strict
git rebase --onto FETCH_HEAD $(git merge-base origin/master HEAD)
```

If your commit is older than viable/strict:

```
git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD
```

Check out the recency history of this "viable master" tracking branch.


❄️ 1 tentatively flaky failure

1 failure tentatively classified as flaky but has not triggered reruns to confirm:

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_test (1/1)

Step: "Test" (full log | pattern match details) ❄️

```
Mar 18 12:33:20 RuntimeError: Process 0 terminated or timed out after 100.04592323303223 seconds
Mar 18 12:33:20 ======================================================================
Mar 18 12:33:20 ERROR [100.218s]: test_dist_optim (__main__.DistOptimizerTestWithSpawn)
Mar 18 12:33:20 ----------------------------------------------------------------------
Mar 18 12:33:20 Traceback (most recent call last):
Mar 18 12:33:20   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 175, in wrapper
Mar 18 12:33:20     self._join_processes(fn)
Mar 18 12:33:20   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 285, in _join_processes
Mar 18 12:33:20     self._check_return_codes(elapsed_time)
Mar 18 12:33:20   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 323, in _check_return_codes
Mar 18 12:33:20     raise RuntimeError('Process {} terminated or timed out after {} seconds'.format(i, elapsed_time))
Mar 18 12:33:20 RuntimeError: Process 0 terminated or timed out after 100.04592323303223 seconds
Mar 18 12:33:20
Mar 18 12:33:20 ----------------------------------------------------------------------
Mar 18 12:33:20 Ran 3 tests in 102.558s
Mar 18 12:33:20
Mar 18 12:33:20 FAILED (errors=1)
Mar 18 12:33:20
Mar 18 12:33:20 Generating XML reports...
Mar 18 12:33:21 Traceback (most recent call last):
Mar 18 12:33:21   File "test/run_test.py", line 674, in <module>
Mar 18 12:33:21     main()
```

🚧 1 upstream failure:

These were probably caused by upstream breakages:



@mruberry mruberry requested a review from gchanan March 15, 2020 21:06

gchanan commented Mar 17, 2020

@bhosmer see my question above about fill_value < self < dtype -- I think in our case, self is always specified, so the fill_value issue is moot.


gchanan commented Mar 17, 2020

just to follow up on the precedence issue with full_like: it's implemented as:

```cpp
Tensor full_like(
    const Tensor& self,
    Scalar fill_value,
    const TensorOptions& options,
    c10::optional<c10::MemoryFormat> optional_memory_format) {
  auto result = at::empty_like(self, options, optional_memory_format);
  return result.fill_(fill_value);
}
```

so the fill_value can't possibly change the dtype, which appears to be the right behavior, though it's not really a matter of precedence.
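
In rough Python terms (a hypothetical sketch, not the real binding):

```python
import torch

def full_like_sketch(self, fill_value, dtype=None):
    # empty_like fixes the dtype from self (or from an explicit dtype);
    # fill_ then merely casts fill_value into that dtype.
    return torch.empty_like(self, dtype=dtype).fill_(fill_value)
```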


bhosmer commented Mar 17, 2020

> @bhosmer see my question above about fill_value < self < dtype -- I think in our case, self is always specified, so the fill_value issue is moot.

Ah makes sense.

Comment thread: aten/src/ATen/native/TensorFactories.cpp (outdated)

```cpp
"Deprecation warning: In PyTorch 1.6 torch.full with a bool or ",
"integral fill_value will require a dtype or out argument. ",
"In PyTorch 1.7, when `out` and `dtype` are not set a bool fill_value ",
"will return a tensor of torch.bool dtype, ",
```
Contributor

I think we should give separate messages for if the user called the out variant or not.

The reasoning is basically, I don't think the out variants are a particularly good idea -- because these are factory functions, they specify properties of the memory of the tensor, i.e. the memory format. What does it mean for the memory format to be one thing, but the out argument to be another? Best if we limit the usage of the out= variants.

Collaborator Author

The error message for setting the out= kwarg improperly is in python_torch_functions.py. I added a test for the behavior.

@gchanan gchanan left a comment

this looks pretty close to good to go. I think the action items are:

  1. separate out the warnings for the out= vs not case.
  2. (optionally) make the docs/warning timing more ambiguous, since we aren't sure yet.
  3. make sure we have a test for full_like precedence inference.

@facebook-github-bot facebook-github-bot left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.



vadimkantorov commented Mar 18, 2020

Related: #34210, specifically: does torch.full support a PyTorch scalar fill_value, and does it produce a gradient w.r.t. fill_value?


@mruberry merged this pull request in 1afc584.

@mruberry (Collaborator Author)

Partially addresses #27171.

@mruberry mruberry deleted the full_type_inference branch March 29, 2020 07:12
@zou3519 zou3519 removed the module: bc-breaking label on Apr 9, 2020

zou3519 commented Apr 9, 2020

From offline discussion: this itself isn't BC-breaking, but we will break BC sometime in the next release. I'm removing the label to help better keep track of this.

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Deprecates current torch.full integral type inference, adds torch.full complex type inference (pytorch#34709)

Summary: see the PR description above.
Pull Request resolved: pytorch#34709

Differential Revision: D20509387

Pulled By: mruberry

fbshipit-source-id: 129593ba06a1662032bbbf8056975eaa59baf933

Labels

Merged · module: complex · module: deprecation · module: numpy
