Issue description
I use torch.nn.BatchNorm1d in my code. When the batch size of the input exceeds a threshold, BN gets very slow. The problem can be reproduced on my machine with the code (test_bn.py) provided below.
For example, in the output below I test it with inputs of size (5x20000, 256), (6x20000, 256), (7x20000, 256) and (8x20000, 256). The running time jumps by an order of magnitude between the multipliers 6 and 7 (see the bisection sketch after the timings).
~$ CUDA_VISIBLE_DEVICES=5 python test_bn.py 5
forward: 0.007s
backward: 0.019s
total: 0.026s
~$ CUDA_VISIBLE_DEVICES=5 python test_bn.py 6
forward: 0.007s
backward: 0.022s
total: 0.030s
~$ CUDA_VISIBLE_DEVICES=5 python test_bn.py 7
forward: 0.135s
backward: 0.102s
total: 0.237s
~$ CUDA_VISIBLE_DEVICES=5 python test_bn.py 8
forward: 0.155s
backward: 0.117s
total: 0.271s
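Since the jump is so sharp, the exact threshold might be informative. Below is a minimal sketch (my own, not part of the original script; the bench helper and the "more than 2x baseline counts as slow" heuristic are assumptions) that binary-searches for the first slow batch size, assuming the transition is a single step:

# bisect_threshold.py -- locate the batch size where the slowdown starts
import time
import torch

num_features = 256
bn = torch.nn.BatchNorm1d(num_features).cuda()

def bench(n):
    # average forward+backward time over 10 iterations for batch size n
    x = torch.randn(n, num_features, device='cuda', requires_grad=True)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(10):
        bn(x).mean().backward()
    torch.cuda.synchronize()
    return (time.time() - start) / 10

lo, hi = 6 * 20000, 7 * 20000  # fast vs. slow, from the timings above
baseline = bench(lo)
while hi - lo > 1:
    mid = (lo + hi) // 2
    if bench(mid) > 2 * baseline:  # heuristic: >2x baseline means "slow"
        hi = mid
    else:
        lo = mid
print('slowdown starts at batch size %d' % hi)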
Code example
# test_bn.py
import sys
import time
import torch

bs = int(sys.argv[1])
num_features = 256
shape = torch.Size((bs * 20000, num_features))
input = torch.randn(shape, device='cuda', requires_grad=True)
bn = torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
bn.cuda()
N = 30
forward_time = 0.0
backward_time = 0.0
for i in range(N):
    torch.cuda.synchronize()
    start = time.time()
    output = bn(input)
    output = output.mean()
    torch.cuda.synchronize()
    forward_time += time.time() - start
    torch.cuda.synchronize()
    start = time.time()
    output.backward()
    torch.cuda.synchronize()
    backward_time += time.time() - start
print('forward: %.3fs' % (forward_time / N))
print('backward: %.3fs' % (backward_time / N))
print('total: %.3fs' % ((forward_time + backward_time) / N))
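I haven't verified this, but my guess is that above some batch size PyTorch stops dispatching BatchNorm to cuDNN and falls back to a slower native kernel. One way to test that guess is to re-run the fast and slow sizes with cuDNN disabled via torch.backends.cudnn.enabled; if the fallback theory is right, the fast case should slow down to roughly match the slow one. A minimal sketch (the bench helper and iteration count are my own choices):

# cudnn_check.py -- compare timings with cuDNN on vs. off
import time
import torch

num_features = 256
bn = torch.nn.BatchNorm1d(num_features).cuda()

def bench(n):
    # average forward+backward time over 10 iterations for batch size n
    x = torch.randn(n, num_features, device='cuda', requires_grad=True)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(10):
        bn(x).mean().backward()
    torch.cuda.synchronize()
    return (time.time() - start) / 10

for enabled in (True, False):
    torch.backends.cudnn.enabled = enabled
    print('cudnn enabled=%s: bs=6x20000 %.3fs, bs=7x20000 %.3fs'
          % (enabled, bench(6 * 20000), bench(7 * 20000)))
torch.backends.cudnn.enabled = True  # restore the default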
System Info
Collecting environment information...
PyTorch version: 1.0.0.dev20180921
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.9) 5.4.0 20160609
CMake version: version 3.6.3
Python version: 3.5
Is CUDA available: Yes
CUDA runtime version: 8.0.44
GPU models and configuration:
GPU 0: TITAN X (Pascal)
GPU 1: TITAN X (Pascal)
GPU 2: TITAN X (Pascal)
GPU 3: TITAN X (Pascal)
GPU 4: TITAN X (Pascal)
GPU 5: TITAN X (Pascal)
GPU 6: TITAN X (Pascal)
GPU 7: TITAN X (Pascal)
Nvidia driver version: 384.111
cuDNN version: Probably one of the following:
/usr/local/MATLAB/R2016a/bin/glnxa64/libcudnn.so.7.0.64
/usr/local/bak-cuda-9.0/lib64/libcudnn.so
/usr/local/bak-cuda-9.0/lib64/libcudnn.so.7
/usr/local/bak-cuda-9.0/lib64/libcudnn.so.7.0.5
/usr/local/bak-cuda-9.0/lib64/libcudnn_static.a
/usr/local/bak-cuda-9.0/targets/x86_64-linux/lib/libcudnn.so
/usr/local/bak-cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7
/usr/local/bak-cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7.0.5
/usr/local/bak-cuda-9.0/targets/x86_64-linux/lib/libcudnn_static.a
/usr/local/cuda-8.0/lib64/libcudnn.so
/usr/local/cuda-8.0/lib64/libcudnn.so.7
/usr/local/cuda-8.0/lib64/libcudnn.so.7.0.5
/usr/local/cuda-8.0/lib64/libcudnn_static.a
/usr/local/lib/libcudnn.so.5.1.5
/usr/local/lib/libcudnn_static.a
/usr/local/lib/python2.7/dist-packages/torch/lib/libcudnn-900fef33.so.7.0.5
Versions of relevant libraries:
[pip] msgpack-numpy (0.4.3.1)
[pip] numpy (1.15.0)
[pip] torch (1.0.0.dev20180921)
[pip] torch-cluster (1.1.5)
[pip] torch-scatter (1.0.4)
[pip] torch-sparse (0.2.0)
[pip] torch-spline-conv (1.0.4)
[pip] torchfile (0.1.0)
[pip] torchnet (0.0.2)
[conda] cuda80 1.0 0 soumith
[conda] magma-cuda80 2.3.0 1 pytorch
[conda] nccl2 1.0 0 pytorch
[conda] pytorch-nightly 1.0.0.dev20180921 py3.5_cuda9.0.176_cudnn7.1.2_0 pytorch
[conda] torch 0.5.0a0+c8b246a
[conda] torch-cluster 1.1.5
[conda] torch-scatter 1.0.4
[conda] torch-sparse 0.2.0
[conda] torch-spline-conv 1.0.4
[conda] torchfile 0.1.0
[conda] torchnet 0.0.2