torch.float_power out= and inplace variant errors on non-matching output dtype instead of casting #50213

@gchanan

Description

🐛 Bug

Operators usually check that the computation type can be safely cast to the output dtype and then cast the result (this matches NumPy behavior). float_power doesn't do this: it errors on any dtype mismatch instead of casting.
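A rough sketch of that safe-cast rule in NumPy terms (np.can_cast with "same_kind" casting stands in here for the internal check PyTorch's out= operators apply; the exact internal API isn't shown in this issue):

```python
import numpy as np

# float_power always computes in at least float64. Casting that result
# down to a float32 out tensor stays within the floating-point kind,
# so a "same_kind" check permits it:
print(np.can_cast(np.float64, np.float32, casting="same_kind"))  # True

# Casting a floating-point result into an integer out tensor crosses
# kinds and is rejected:
print(np.can_cast(np.float64, np.int32, casting="same_kind"))    # False
```

Under this rule, a float32 out tensor for float_power should be accepted and the float64 result cast down, rather than raising an error.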

>>> torch.float_power(torch.randn(2,3), 5., out=torch.randn(2,3))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: output type Float is not the desired output type Double
>>> torch.randn(2,3).float_power_(5.)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: self tensor type Float is not the desired type Double
>>> np.float_power(np.random.randn(2,3).astype(np.float32), 5., out=np.random.randn(2,3).astype(np.float32))
array([[-1.1256918e-01,  5.3110137e+00, -3.6077526e+01],
       [-1.2552988e-02, -3.9551861e-04,  2.7541189e+01]], dtype=float32)

this violates https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
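For comparison, a deterministic version of the NumPy repro above (using a constant input instead of randn so the output is predictable):

```python
import numpy as np

x = np.full((2, 3), 2.0, dtype=np.float32)
out = np.empty((2, 3), dtype=np.float32)

# NumPy computes float_power in float64 internally, then casts the
# result into the float32 out array under its default "same_kind"
# casting rule -- no error is raised.
np.float_power(x, 5.0, out=out)
print(out.dtype, out[0, 0])  # float32 32.0
```

The equivalent torch.float_power call with a float32 out tensor raises the RuntimeError shown above instead.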

Collecting environment information...
PyTorch version: 1.8.0.dev20201222
Is debug build: False
CUDA used to build PyTorch: 9.2
ROCM used to build PyTorch: N/A

OS: CentOS Linux 7 (Core) (x86_64)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-44)
Clang version: Could not collect
CMake version: version 3.14.0

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: Tesla M40
GPU 1: Tesla M40

Nvidia driver version: 418.126.02
cuDNN version: /usr/local/cuda-9.2/targets/x86_64-linux/lib/libcudnn.so.7.4.2
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip] numpy==1.18.1
[pip] pytorch-lightning==0.9.0
[pip] torch==1.8.0.dev20201222
[pip] torchaudio==0.8.0a0+1398187
[pip] torchvision==0.9.0.dev20201222
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               9.2                           0
[conda] mkl                       2020.1                      217
[conda] mkl-include               2020.1                      217
[conda] mkl-service               2.3.0            py37he904b0f_0
[conda] mkl_fft                   1.0.14           py37ha843d7b_0
[conda] mkl_random                1.1.0            py37hd6b4f25_0
[conda] numpy                     1.18.1           py37h4f9e942_0
[conda] numpy-base                1.18.1           py37hde5b4d6_1
[conda] pytorch                   1.8.0.dev20201222 py3.7_cuda9.2.148_cudnn7.6.3_0    pytorch-nightly
[conda] pytorch-lightning         0.9.0                     <pip>
[conda] torch                     1.8.0.dev20201130+cpu           <pip>
[conda] torchaudio                0.8.0.dev20201130           <pip>
[conda] torchaudio                0.8.0.dev20201222            py37    pytorch-nightly
[conda] torchvision               0.9.0.dev20201222       py37_cu92    pytorch-nightly

cc @nairbv @mruberry


Labels

module: type promotion (Related to semantics of type promotion), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
