This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[BUG] The wrong gradient of Batch Norm when grad_req = add #18499

@wkcn

Description

Hi there, we found that the current implementation of the batch norm layer does not support grad_req = add. If grad_req is set to add, the gradient of the input data is not accumulated. In addition, the gradients of gamma and beta are mistakenly not assigned any value.
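For reference, `grad_req` controls how a backward pass writes into a gradient buffer: `'write'` overwrites it, while `'add'` adds the new gradient to what is already stored. The pure-Python sketch below mimics that rule on a plain list standing in for a gradient buffer; `apply_grad_req` is a hypothetical helper for illustration, not an MXNet API, and the per-pass gradient values are made up. After three backward passes with `'add'`, the buffer should hold three times the single-pass gradient, which is the behavior the batch norm layer fails to show here.

```python
def apply_grad_req(grad_buffer, new_grad, req):
    """Hypothetical helper mimicking grad_req semantics on a plain list."""
    if req == "write":
        # Overwrite the buffer with the freshly computed gradient.
        return list(new_grad)
    if req == "add":
        # Accumulate: the buffer keeps the running sum over backward passes.
        return [g + n for g, n in zip(grad_buffer, new_grad)]
    if req == "null":
        # No gradient is stored; the buffer is left untouched.
        return grad_buffer
    raise ValueError(f"unknown grad_req: {req!r}")


single_pass_grad = [2.0, -1.0, 0.5]  # made-up per-pass gradient

write_buf = [0.0, 0.0, 0.0]
add_buf = [0.0, 0.0, 0.0]
for _ in range(3):
    write_buf = apply_grad_req(write_buf, single_pass_grad, "write")
    add_buf = apply_grad_req(add_buf, single_pass_grad, "add")

print(write_buf)  # [2.0, -1.0, 0.5]
print(add_buf)    # [6.0, -3.0, 1.5]
```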

To Reproduce

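import mxnet as mx
from mxnet.gluon import nn

N = 1
C = 3
H = W = 2
block = nn.BatchNorm()
block.collect_params().initialize()
# Accumulate the gradients of gamma and beta across backward passes.
block.collect_params().setattr('grad_req', 'add')

x = mx.nd.arange(N*C*H*W).reshape((N, C, H, W))
# Accumulate the gradient of the input data as well.
x.attach_grad(grad_req='add')
for i in range(3):
    with mx.autograd.record():
        y = block(x)
        loss = (y * y).sum()
    loss.backward()
# After 3 identical passes, each gradient should be 3x its single-pass value.
print(x.grad, block.gamma.grad(), block.beta.grad())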

It produces the following output:
mxnet-2.0.0b20200421 installed by pip

[[[[-1.8979003e-05 -6.3974167e-06]
   [ 6.3974167e-06  1.8979003e-05]]

  [[-1.8979003e-05 -6.3974167e-06]
   [ 6.3974167e-06  1.8979003e-05]]

  [[-1.8979003e-05 -6.3974167e-06]
   [ 6.3974167e-06  1.8979003e-05]]]]
<NDArray 1x3x2x2 @cpu(0)> 
[7.999936 7.999936 7.999936]
<NDArray 3 @cpu(0)> 
[0. 0. 0.]
<NDArray 3 @cpu(0)>

MXNet 1.6 installed by pip --pre

[[[[-1.9192250e-05 -6.3974167e-06]
   [ 6.3974167e-06  1.9192250e-05]]

  [[-1.9192250e-05 -6.3974167e-06]
   [ 6.3974167e-06  1.9192250e-05]]

  [[-1.9192250e-05 -6.3974167e-06]
   [ 6.3974167e-06  1.9192250e-05]]]]
<NDArray 1x3x2x2 @cpu(0)> 
[0. 0. 0.]
<NDArray 3 @cpu(0)> 
[0. 0. 0.]
<NDArray 3 @cpu(0)>

The correct result should be:

[[[[-5.8216e-05, -1.9352e-05],
   [ 1.9352e-05,  5.8216e-05]],

  [[-5.8216e-05, -1.9352e-05],
   [ 1.9352e-05,  5.8216e-05]],

  [[-5.8216e-05, -1.9352e-05],
   [ 1.9352e-05,  5.8216e-05]]]]

[23.9998, 23.9998, 23.9998]

[0., 0., 0.]

These values are the gradients of the input data, gamma, and beta, respectively. The gradients produced by both versions are wrong.
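As an independent check of the expected values, the NumPy sketch below recomputes the training-mode batch norm gradients for the loss `sum(y * y)` used in the script and accumulates them over three passes, as `grad_req='add'` should. It assumes Gluon's `nn.BatchNorm` defaults (per-channel statistics over N, H and W with the biased variance, `epsilon=1e-5`, gamma initialized to ones and beta to zeros). `batchnorm_train_grads` is a hypothetical helper written for this comparison; it is a textbook reference, not MXNet's kernel. Per pass, the closed forms are dL/dgamma = 2·m·var/(var + eps) and dL/dbeta = 0, with m = 4 elements per channel and var = 1.25, so three accumulated passes give about 23.9998 for each gamma entry.

```python
import numpy as np


def batchnorm_train_grads(x, gamma, beta, eps=1e-5):
    """Hypothetical reference: gradients of sum(y**2) for training-mode batch norm.

    Statistics are per channel (axis 1) over N, H, W, using the biased variance.
    """
    axes = (0, 2, 3)
    m = x.shape[0] * x.shape[2] * x.shape[3]   # elements per channel
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)      # biased (ddof=0)
    inv_std = 1.0 / np.sqrt(var + eps)
    x_hat = (x - mean) * inv_std
    g = gamma.reshape(1, -1, 1, 1)
    y = g * x_hat + beta.reshape(1, -1, 1, 1)

    dy = 2.0 * y                               # d(sum(y**2)) / dy
    dgamma = (dy * x_hat).sum(axis=axes)
    dbeta = dy.sum(axis=axes)
    dx_hat = dy * g
    # Standard batch norm backward through the batch mean and variance.
    dx = (inv_std / m) * (
        m * dx_hat
        - dx_hat.sum(axis=axes, keepdims=True)
        - x_hat * (dx_hat * x_hat).sum(axis=axes, keepdims=True)
    )
    return dx, dgamma, dbeta


N, C, H, W = 1, 3, 2, 2
x = np.arange(N * C * H * W, dtype=np.float64).reshape(N, C, H, W)
gamma = np.ones(C)
beta = np.zeros(C)

# Emulate grad_req='add': sum the gradients of three identical passes.
acc_dx = np.zeros_like(x)
acc_dgamma = np.zeros(C)
acc_dbeta = np.zeros(C)
for _ in range(3):
    dx, dgamma, dbeta = batchnorm_train_grads(x, gamma, beta)
    acc_dx += dx
    acc_dgamma += dgamma
    acc_dbeta += dbeta

print(acc_dx[0, 0])
print(acc_dgamma)  # each entry is about 23.9998
print(acc_dbeta)   # zero up to rounding
```

The input gradient here is proportional to eps/(var + eps), which comes out of a near-cancellation, so float32 results can differ noticeably from this float64 reference; in every case the accumulated value should be three times the single-pass value.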

Environment

mxnet-2.0.0b20200421 installed by pip
I could not run the latest version of MXNet 2.0 (mxnet-2.0.0b20200516) on my laptop because libopenblas.so.0 could not be found :(

----------Python Info----------
Version      : 3.8.3
Compiler     : GCC 10.1.0
Build        : ('default', 'May 17 2020 18:15:42')
Arch         : ('64bit', 'ELF')
------------Pip Info-----------
Version      : 20.0.2
Directory    : /usr/lib/python3.8/site-packages/pip
----------MXNet Info-----------
Version      : 2.0.0
Directory    : /usr/lib/python3.8/site-packages/mxnet
Hashtag not found. Not installed from pre-built package.
----------System Info----------
Platform     : Linux-5.6.15-arch1-1-x86_64-with-glibc2.2.5
system       : Linux
node         : MiraiT
release      : 5.6.15-arch1-1
version      : #1 SMP PREEMPT Wed, 27 May 2020 23:42:26 +0000
----------Hardware Info----------
machine      : x86_64
processor    : 
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
