torch.angle() grads are wrong #46144

@jonashaag

Description

🐛 Bug

I'm not sure it's the grads that are wrong, but I can say for sure that training with .angle() diverges, while using .atan2() converges.
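For reference, the analytic partials of atan2 are d/dx atan2(y, x) = -y/(x² + y²) and d/dy atan2(y, x) = x/(x² + y²); these are what a correct angle backward should produce for the real/imaginary parts. A quick stdlib-only sanity check of the formulas against central finite differences (the helper names here are mine, not from PyTorch):

```python
import math

def atan2_grads(y, x):
    # Analytic partials of atan2(y, x):
    #   d/dx atan2(y, x) = -y / (x^2 + y^2)
    #   d/dy atan2(y, x) =  x / (x^2 + y^2)
    r2 = x * x + y * y
    return -y / r2, x / r2

def finite_diff(f, v, eps=1e-6):
    # Central difference approximation of df/dv.
    return (f(v + eps) - f(v - eps)) / (2 * eps)

x, y = 5.0, 3.0  # the real/imag parts of the repro input 5 + 3j
gx, gy = atan2_grads(y, x)
assert abs(gx - finite_diff(lambda x_: math.atan2(y, x_), x)) < 1e-6
assert abs(gy - finite_diff(lambda y_: math.atan2(y_, x), y)) < 1e-6
```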

To Reproduce

import numpy as np
import torch


torch.manual_seed(42)


class ComplexLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.re_weights = torch.nn.Linear(1, 1)
        self.im_weights = torch.nn.Linear(1, 1)

    def forward(self, x):
        multiplied = torch.view_as_complex(torch.stack(
            [self.re_weights(x.real), self.im_weights(x.imag)], dim=-1))
        # NOTE: Change between 1 and 0 to select any of the cases.
        if 0:
            # Return magnitude
            if 1:
                # "My" definition
                return (multiplied.real ** 2 + multiplied.imag ** 2).sqrt()
            else:
                # PyTorch definition
                return multiplied.abs()
        elif 1:
            # Return angle
            if 1:  # <== change this to 0 then it diverges
                # "My" definition
                return multiplied.imag.atan2(multiplied.real)
            else:
                # PyTorch definition
                return multiplied.angle()
        else:
            # Return sum of magnitude + angle
            if 1:
                # "My" definition
                return (multiplied.real ** 2 + multiplied.imag ** 2).sqrt() + multiplied.imag.atan2(multiplied.real)
            else:
                # PyTorch definition
                return multiplied.abs() + multiplied.angle()


net = ComplexLinear()

x = torch.from_numpy(np.array([5 + 3j], dtype="complex64"))
y = torch.from_numpy(np.array([0.3], dtype="float32"))

for i in range(10000):
    res = net(x)
    loss = torch.nn.functional.mse_loss(y, res)
    net.zero_grad()
    loss.backward()
    #print(res, net.re_weights.weight.grad)
    print("\r", "step", i, res, end="")
    with torch.no_grad():
        for param in net.parameters():
            if param.requires_grad:
                param -= 1e-3 * param.grad
print()

Output for the angle target, "my" definition:

 step 9999 tensor([0.3000], grad_fn=<Atan2Backward>)

PyTorch definition:

 step 9999 tensor([-0.5408], grad_fn=<AngleBackward>)

The magnitude target seems OK; the sum target is broken as well.

Expected behavior

The behaviour of "my" definition and the PyTorch definition should be identical.

Environment

PyTorch version: 1.8.0.dev20201009+cu110
Is debug build: True
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.4 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 455.23.04
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip] numpy==1.18.5
[pip] pytorch-lightning==0.10.0
[pip] pytorch-ranger==0.1.1
[pip] torch==1.8.0.dev20201009+cu110
[pip] torch-optimizer==0.0.1a15
[pip] torch-stoi==0.1.1
[pip] torchaudio==0.7.0.dev20201009
[pip] torchvision==0.8.0.dev20201009+cu110
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               10.1.243             h6bb024c_0
[conda] libblas                   3.8.0                    16_mkl    conda-forge
[conda] libcblas                  3.8.0                    16_mkl    conda-forge
[conda] liblapack                 3.8.0                    16_mkl    conda-forge
[conda] mkl                       2020.1                      217
[conda] mkl-service               2.3.0            py37he904b0f_0
[conda] mkl_fft                   1.1.0            py37h23d657b_0
[conda] mkl_random                1.1.1            py37h0573a6f_0
[conda] numpy                     1.18.5           py37ha1c710e_0
[conda] numpy-base                1.18.5           py37hde5b4d6_0
[conda] pytorch-lightning         0.10.0                   pypi_0    pypi
[conda] pytorch-ranger            0.1.1                    pypi_0    pypi
[conda] torch                     1.8.0.dev20201009+cu110          pypi_0    pypi
[conda] torch-optimizer           0.0.1a15                 pypi_0    pypi
[conda] torch-stoi                0.1.1                    pypi_0    pypi
[conda] torchaudio                0.7.0.dev20201009          pypi_0    pypi
[conda] torchvision               0.8.0.dev20201009+cu110          pypi_0    pypi

cc @ezyang @gchanan @zou3519 @bdhirsh @ejguan @albanD @gqchen @pearu @nikitaved @anjali411 @dylanbespalko @mruberry

Metadata

Labels

complex_autograd, high priority, module: autograd (Related to torch.autograd, and the autograd engine in general), module: complex (Related to complex number support in PyTorch), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
