
Wrong gradient for torch.norm(x, p=float('inf')) when input tensor has non-unique max values #41779

Description

@yku12cn

🐛 Bug

Given the input tensor:

tensor([[ 9.,  2.,  9.],
        [-2., -3., -4.],
        [ 7.,  8., -9.]])

after calling torch.norm(x, p=float('inf')) and then backward(), x.grad is:

tensor([[ 1.,  0.,  1.],
        [-0., -0., -0.],
        [ 0.,  0., -1.]])

The correct gradient should instead split the contribution equally among the three entries tied for the maximum absolute value:

tensor([[ 0.3333,  0.0000,  0.3333],
        [-0.0000, -0.0000, -0.0000],
        [ 0.0000,  0.0000, -0.3333]])
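
For reference, here is a minimal sketch of that tie-splitting rule, using only documented torch ops (this snippet is mine, not part of the report):

import torch

x = torch.tensor([[9., 2., 9.],
                  [-2., -3., -4.],
                  [7., 8., -9.]])

# ||x||_inf = max |x_ij|; with ties, a natural subgradient splits the
# contribution evenly across all tied maximum-magnitude entries.
max_abs = x.abs().max()
ties = (x.abs() == max_abs).to(x.dtype)  # 1.0 where |x| is maximal, else 0.0
print(x.sign() * ties / ties.sum())      # each tied entry gets sign(x) / (#ties)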

To Reproduce

Run the following code:

import torch

with torch.enable_grad():
    a = torch.tensor([
        [9., 2., 9.],
        [-2., -3., -4.],
        [7., 8., -9.],
    ], requires_grad=True)
    b = torch.norm(a, p=float('inf'))
    b.backward()
    print(a.grad)

Expected behavior

tensor([[ 0.3333,  0.0000,  0.3333],
        [-0.0000, -0.0000, -0.0000],
        [ 0.0000,  0.0000, -0.3333]])

Environment

PyTorch version: 1.5.0
Is debug build: No
CUDA used to build PyTorch: 10.2

OS: Microsoft Windows 10 Pro
GCC version: Could not collect
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.5.0
[pip3] torchvision==0.6.0
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.2.89 h74a9793_0
[conda] mkl 2020.1 216
[conda] mkl-service 2.3.0 py37hb782905_0
[conda] mkl_fft 1.1.0 py37h45dec08_0
[conda] mkl_random 1.1.1 py37h47e9c7a_0
[conda] numpy 1.18.5 py37h6530119_0
[conda] numpy-base 1.18.5 py37hc3f5095_0
[conda] pytorch 1.5.0 py3.7_cuda102_cudnn7_0 pytorch
[conda] torchvision 0.6.0 py37_cu102 pytorch

Additional context

Mathematically speaking, the gradient of the l_p norm should converge to the gradient of the l_∞ norm as p grows.
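
In closed form, the l_p gradient and its p → ∞ limit are (a standard derivation, not part of the original report):

\frac{\partial \lVert x \rVert_p}{\partial x_i}
    = \operatorname{sign}(x_i) \left( \frac{\lvert x_i \rvert}{\lVert x \rVert_p} \right)^{p-1}
    \xrightarrow{\,p \to \infty\,}
    \begin{cases}
        \operatorname{sign}(x_i) / k & \text{if } \lvert x_i \rvert = \lVert x \rVert_\infty \\
        0 & \text{otherwise,}
    \end{cases}
    \qquad k = \#\{\, j : \lvert x_j \rvert = \lVert x \rVert_\infty \,\}.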
The following code checks this numerically with a large finite p:

with torch.enable_grad():
    a = torch.tensor([
        [9., 2., 9.],
        [-2., -3., -4.],
        [7., 8., -9.],
    ], requires_grad=True)
    b = torch.norm(a, p=30)
    b.backward()
    print(a.grad)

With this rather large p, the computed gradient is:

tensor([[ 3.4248e-01,  3.9037e-20,  3.4248e-01],
        [-3.9037e-20, -4.9903e-15, -2.0958e-11],
        [ 2.3413e-04,  1.1252e-02, -3.4248e-01]])

which is much closer to

tensor([[ 0.3333,  0.0000,  0.3333],
        [-0.0000, -0.0000, -0.0000],
        [ 0.0000,  0.0000, -0.3333]])

rather than

tensor([[ 1.,  0.,  1.],
        [-0., -0., -0.],
        [ 0.,  0., -1.]])

In other words, ∂torch.norm(x, p)/∂x does not converge to ∂torch.norm(x, float('inf'))/∂x as p grows when the input tensor has non-unique max values.
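Until this is fixed, a possible user-side workaround (my own sketch, not an official PyTorch API) is a custom autograd.Function that implements the tie-splitting subgradient directly:

import torch

class InfNorm(torch.autograd.Function):
    # Workaround sketch: inf-norm whose backward splits the gradient
    # evenly across all entries tied for the maximum absolute value.

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.abs().max()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        ties = (x.abs() == x.abs().max()).to(x.dtype)
        return grad_output * x.sign() * ties / ties.sum()

a = torch.tensor([
    [9., 2., 9.],
    [-2., -3., -4.],
    [7., 8., -9.],
], requires_grad=True)
InfNorm.apply(a).backward()
print(a.grad)  # matches the expected tie-splitting gradient above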

cc @ezyang @ssnl @albanD @zou3519 @gqchen

Labels

module: autograd - Related to torch.autograd, and the autograd engine in general
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
