This repository was archived by the owner on Nov 17, 2023. It is now read-only.
[BUG] The wrong gradient of Batch Norm when grad_req = add #18499
Closed
Description
Hi there, we found that the current implementation of the batch norm layer does not support grad_req = add. If grad_req is set to add, the gradient of the input data is not accumulated. In addition, the gradients of gamma and beta are mistakenly not assigned any value.
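For reference, grad_req = add means that each backward pass adds its gradient onto the stored gradient instead of overwriting it. The pure-Python sketch below models the three request modes. assign_grad is a hypothetical helper, not an MXNet API. It only illustrates the expected semantics and does not show how the BatchNorm operator implements them.

```python
# Hypothetical model of gradient request modes; this is NOT MXNet code,
# only an illustration of what grad_req='write'/'add'/'null' should do.


def assign_grad(buffer, new_grad, grad_req):
    """Store new_grad into buffer in place, following grad_req."""
    if grad_req not in ("write", "add", "null"):
        raise ValueError(f"unknown grad_req: {grad_req!r}")
    if grad_req == "null":
        return buffer  # no gradient is stored
    for i, g in enumerate(new_grad):
        if grad_req == "write":
            buffer[i] = g  # overwrite the previous gradient
        else:
            buffer[i] += g  # 'add': accumulate onto the previous gradient
    return buffer


# Three backward passes that each produce the same gradient,
# similar to the gamma gradient (about 8 per channel) in this report.
step_grad = [8.0, 8.0, 8.0]
write_buf = [0.0] * 3
add_buf = [0.0] * 3
for _ in range(3):
    assign_grad(write_buf, step_grad, "write")
    assign_grad(add_buf, step_grad, "add")

print(write_buf)  # [8.0, 8.0, 8.0]
print(add_buf)    # [24.0, 24.0, 24.0]
```

Under this model, the gamma gradient in the reproduction should reach about three times its single-pass value after three iterations.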
To Reproduce
import mxnet as mx
from mxnet.gluon import nn
N = 1
C = 3
H = W = 2
block = nn.BatchNorm()
block.collect_params().initialize()
# Accumulate (grad_req='add') the gradients of gamma and beta
block.collect_params().setattr('grad_req', 'add')
x = mx.nd.arange(N*C*H*W).reshape((N, C, H, W))
x.attach_grad(grad_req='add')  # accumulate the input gradient as well
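# Three identical forward/backward passes: every gradient requested with
# grad_req='add' should end up 3x its single-pass value
for i in range(3):
    with mx.autograd.record():
        y = block(x)
        loss = (y * y).sum()
    loss.backward()
print(x.grad, block.gamma.grad(), block.beta.grad())

It outputs the following message: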
mxnet-2.0.0b20200421 (installed by pip):
[[[[-1.8979003e-05 -6.3974167e-06]
[ 6.3974167e-06 1.8979003e-05]]
[[-1.8979003e-05 -6.3974167e-06]
[ 6.3974167e-06 1.8979003e-05]]
[[-1.8979003e-05 -6.3974167e-06]
[ 6.3974167e-06 1.8979003e-05]]]]
<NDArray 1x3x2x2 @cpu(0)>
[7.999936 7.999936 7.999936]
<NDArray 3 @cpu(0)>
[0. 0. 0.]
<NDArray 3 @cpu(0)>
MXNet 1.6 (installed by pip --pre):
[[[[-1.9192250e-05 -6.3974167e-06]
[ 6.3974167e-06 1.9192250e-05]]
[[-1.9192250e-05 -6.3974167e-06]
[ 6.3974167e-06 1.9192250e-05]]
[[-1.9192250e-05 -6.3974167e-06]
[ 6.3974167e-06 1.9192250e-05]]]]
<NDArray 1x3x2x2 @cpu(0)>
[0. 0. 0.]
<NDArray 3 @cpu(0)>
[0. 0. 0.]
<NDArray 3 @cpu(0)>
The correct result should be:
[[[[-5.8216e-05, -1.9352e-05],
[ 1.9352e-05, 5.8216e-05]],
[[-5.8216e-05, -1.9352e-05],
[ 1.9352e-05, 5.8216e-05]],
[[-5.8216e-05, -1.9352e-05],
[ 1.9352e-05, 5.8216e-05]]]]
[23.9998, 23.9998, 23.9998]
[0., 0., 0.]
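These three values are the gradients of the input data, gamma, and beta, respectively. The gradients reported above are wrong: the input gradient is not accumulated over the three iterations, and the gamma gradient is either not accumulated (MXNet 2.0) or left at zero (MXNet 1.6).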
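As a cross-check of the expected values, the sketch below recomputes the accumulated gradients in float64 using pure Python, without MXNet. It assumes Gluon's nn.BatchNorm defaults (epsilon=1e-5, gamma initialized to ones, beta to zeros). It also assumes training-mode normalization with batch statistics, as under autograd.record(), and grad_req = add on the input and on both parameters. batchnorm_grads is a hypothetical helper written only for this check.

```python
import math

# Assumptions (mirroring Gluon's nn.BatchNorm defaults): epsilon=1e-5,
# gamma initialized to ones, beta to zeros, and training-mode
# normalization with batch statistics (as under autograd.record()).
EPS = 1e-5
N, C, H, W = 1, 3, 2, 2
ITERATIONS = 3


def batchnorm_grads(channel, gamma=1.0, beta=0.0, eps=EPS):
    """Hypothetical helper: training-mode BatchNorm forward + backward for
    one channel with loss = sum(y * y). Returns (dx, dgamma, dbeta)."""
    m = len(channel)
    mean = sum(channel) / m
    var = sum((v - mean) ** 2 for v in channel) / m  # biased batch variance
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(v - mean) * inv_std for v in channel]
    y = [gamma * xh + beta for xh in x_hat]
    dy = [2.0 * yi for yi in y]  # d(sum(y^2)) / dy
    dgamma = sum(d * xh for d, xh in zip(dy, x_hat))
    dbeta = sum(dy)
    # Standard BatchNorm input gradient:
    # dx = gamma / std * (dy - mean(dy) - x_hat * mean(dy * x_hat))
    mean_dy = dbeta / m
    mean_dy_xhat = dgamma / m
    dx = [gamma * inv_std * (d - mean_dy - xh * mean_dy_xhat)
          for d, xh in zip(dy, x_hat)]
    return dx, dgamma, dbeta


# x = arange(N*C*H*W) in NCHW layout; with N == 1, channel c occupies the
# contiguous slice [c*H*W, (c+1)*H*W).
x = [float(i) for i in range(N * C * H * W)]
hw = H * W
x_grad = [0.0] * len(x)
gamma_grad = [0.0] * C
beta_grad = [0.0] * C

# grad_req='add' everywhere: each iteration adds onto the stored gradients.
for _ in range(ITERATIONS):
    for c in range(C):
        dx, dgamma, dbeta = batchnorm_grads(x[c * hw:(c + 1) * hw])
        for j, g in enumerate(dx):
            x_grad[c * hw + j] += g
        gamma_grad[c] += dgamma
        beta_grad[c] += dbeta

print([f"{g:.4e}" for g in x_grad[:hw]])  # channel 0: about +/-5.76e-05, +/-1.92e-05
print([round(g, 4) for g in gamma_grad])  # each about 23.9998
print([abs(round(b, 4)) for b in beta_grad])  # each 0.0
```

In float64, the gamma gradient matches the 23.9998 above and beta stays at 0. The first input-gradient element comes out near -5.76e-05. Under this loss, the exact input gradient per pass is 2 * (x - mean) * eps / (var + eps)^2. Because it is proportional to eps and results from subtracting nearly equal quantities, float32 computations of it (including the figures above) can differ in the second significant digit.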
Environment
mxnet-2.0.0b20200421 installed by pip
I could not run the latest MXNet 2.0 build (mxnet-2.0.0b20200516) on my laptop because libopenblas.so.0 was not found. :(
----------Python Info----------
Version : 3.8.3
Compiler : GCC 10.1.0
Build : ('default', 'May 17 2020 18:15:42')
Arch : ('64bit', 'ELF')
------------Pip Info-----------
Version : 20.0.2
Directory : /usr/lib/python3.8/site-packages/pip
----------MXNet Info-----------
Version : 2.0.0
Directory : /usr/lib/python3.8/site-packages/mxnet
Hashtag not found. Not installed from pre-built package.
----------System Info----------
Platform : Linux-5.6.15-arch1-1-x86_64-with-glibc2.2.5
system : Linux
node : MiraiT
release : 5.6.15-arch1-1
version : #1 SMP PREEMPT Wed, 27 May 2020 23:42:26 +0000
----------Hardware Info----------
machine : x86_64
processor :
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian