
Move batch_norm to ATen/native, speed up#12368

Closed
t-vi wants to merge 27 commits into pytorch:master from t-vi:native_batch_norm

Conversation

@t-vi
Collaborator

@t-vi t-vi commented Oct 5, 2018

Needless to say, I would happily separate the TensorAccessor fixes into a separate PR, as they're fixes and unrelated.

- Speed up the case of pytorch#12006 in the forward.
- The backward still isn't as fast as one might hope (factor 2-3 in
  the pytorch#12006 case).
- More extensive benchmarking shows not so great performance compared
  to CuDNN for cases with many channels, e.g.
  bs=8-128 / c=1024 / f=1024.
- We change the meaning of save_var to actually hold the variance for native batch
  norm. This appears to somewhat improve the numerical stability
  (i.e. it makes TestNN.test_batch_norm_cudnn_half pass).
@t-vi
Collaborator Author

t-vi commented Oct 5, 2018

It's not the prettiest code, but I used this to arrive at the conclusion that for large numbers of features, it's still slow:

import torch
import timeit
import gc
import numpy

batch_sizes = [8,16,32,64,128]
channels    = [2,32,256,1024]
features    = [(16,),
               (32,),
               (64,),
               (128,),
               (256,),
               (1024,),
               (10240,),
               (102400,),
               (32,32),
               (64,64),
               (128,128),
               (32,32,32),
               (64,64,64)]

def run_bn(input, running_mean, running_var, weight, bias, training, backward):
    out = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, training, 0.1, 1e-5)
    if backward:
        grads = torch.autograd.grad(out, [input, weight, bias], torch.ones_like(out))
    torch.cuda.synchronize()

for bs in batch_sizes:
    for c in channels:
        for f in features:
            shape = (bs, c)+f
            size = numpy.prod(shape)
            if size < 100_000_000:
                running_mean = torch.randn(c, device='cuda')
                running_var = torch.randn(c, device='cuda').exp()
                weight = torch.randn(c, device='cuda', requires_grad=True)
                bias = torch.randn(c, device='cuda', requires_grad=True)
                input = torch.randn(shape, device='cuda', requires_grad=True)
                gc.collect()
                for training in [True, False]:
                    run_bn(input, running_mean, running_var, weight, bias, training, training)
                    torch.cuda.synchronize()
                    res1 = timeit.timeit('run_bn(input, running_mean, running_var, weight, bias, training, training)', number=100, globals=globals())
                    with torch.backends.cudnn.flags(enabled=False):
                        res2 = timeit.timeit('run_bn(input, running_mean, running_var, weight, bias, training, training)', number=100, globals=globals())
                    print('{} | {} | {} | {} | {:.2f} | {:.2f} | {:.1f}'.format(bs, c, f, training, res1, res2, res1/res2))

@t-vi
Collaborator Author

t-vi commented Oct 5, 2018

Seems that the last rebase was less harmless than it looked...

@ezyang
Contributor

ezyang commented Oct 5, 2018

I know what the problem is. You need to add (aten, native_batch_norm) to aten/src/ATen/core/aten_interned_strings.h (and maybe others, but definitely this one).

CC @bwasti, this isn't going to be very good dev experience.

Thank you, @ezyang for the hint! I was completely lost.
Comment thread aten/src/ATen/native/Normalization.cpp Outdated
auto running_mean_a = conditional_accessor_1d<scalar_t>(running_mean);
auto running_var_a = conditional_accessor_1d<scalar_t>(running_var);

#pragma omp parallel for

This comment was marked as off-topic.

Comment thread aten/src/ATen/native/Normalization.cpp Outdated
auto running_var_a = conditional_accessor_1d<scalar_t>(running_var);

int64_t f;
#pragma omp parallel for

This comment was marked as off-topic.

- use parallel_for instead of using OMP myself.
  Thank you @vishwakftw for the review suggestion!
- The double backward used the backward's gradient mask instead
  of its own.
@t-vi
Collaborator Author

t-vi commented Oct 6, 2018

So it turns out that for fp16, even when storing the variance in save_var, it doesn't work. Previously, this didn't affect the forward because we kept mean and var in accscalar_t variables within the kernel. Now that we have separated it into two kernels for performance reasons, we see that keeping the intermediate statistics in fp16 isn't enough.
My conclusion is that we would want save_mean/save_var to be fp32 for fp16 batch norm. (This is part of what is done in CuDNN, but only for the stats that aren't handed to the user.)
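The precision issue described above can be demonstrated without any CUDA at all. This is a hypothetical numpy sketch, not the PR's kernel: reducing ~1e5 values with an fp16 accumulator overflows long before an fp32 ("accscalar_t") accumulator loses any meaningful precision.

```python
import numpy as np

# 100k values of 1.0: the true mean is exactly 1.0
x = np.full(100_000, 1.0, dtype=np.float16)

# accumulate the sum in fp16: the running sum exceeds fp16 max (65504)
mean_fp16 = x.mean(dtype=np.float16)  # inf -- the fp16 sum overflowed
# accumulate in fp32, analogous to keeping the stats in accscalar_t
mean_fp32 = x.mean(dtype=np.float32)  # 1.0

print(mean_fp16, mean_fp32)
```

The same effect, in milder form (rounding rather than overflow), is why fp16 save_mean/save_var are not good enough once the statistics leave the kernel.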

@t-vi
Collaborator Author

t-vi commented Oct 8, 2018

I'm not entirely sure there isn't one, but I don't see the immediate connection to the failing TestScript.test_weak_script_function.

I think the correctness is there, though there are still performance improvements to be had, in particular for the backward. What I'm not 100% certain about is whether removing thnn_batch_norm has compatibility implications.

@ssnl ssnl self-assigned this Oct 10, 2018
@ssnl
Collaborator

ssnl commented Oct 10, 2018

Thanks a lot for doing this! I'm self-assigning so I remember to take a look. In the meantime, could you revert the submodule change?

Collaborator

@ssnl ssnl left a comment


I know I promised I'd take a look today, but it's really late, I just got to my room, and I don't think I can read kernels now. I will look tomorrow. In the meantime, if possible, could you briefly summarize what changed between the THNN version and the current version to make the kernels faster? Thanks!

Comment thread aten/src/ATen/native/Normalization.cpp Outdated
const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double eps) {

using accscalar_t = at::acc_type<scalar_t, false>;
Tensor output = at::native::empty_like(input);

This comment was marked as off-topic.

Comment thread aten/src/ATen/native/Normalization.cpp Outdated
Tensor save_invstd;
const int64_t zero = 0;
if (train) {
save_mean = at::native::empty({n_input}, input.options());

This comment was marked as off-topic.

@t-vi
Collaborator Author

t-vi commented Oct 12, 2018

No worries, thanks for looking!
The kernel logic itself didn't change much, I mostly switched from DeviceTensor to PackedTensorAccessors. However, I changed the splitting of the kernel.
The kernel bits that changed:

  • I turned the old "inference" kernel into batch_norm_transform_input_kernel, which deals both with
    running_mean/running_var (train=False) and save_mean/save_invstd (train=True).
    The former needs to calculate invstd from running_var, while the latter provides it directly.
    So now this is the only kernel computing out = w * (inp - mean) * invstd + b.
  • I removed the part now done by batch_norm_transform_input_kernel from the old training forward kernel
    and renamed it batch_norm_collect_statistics_kernel. It now only computes save_mean / save_var and updates running_mean / running_var as needed.
  • This splitting of the training forward allows a favourable thread/block parametrisation in the
    transform_input call, which used to be the main bottleneck in the forward.

save_mean/save_var have been changed to accscalar_t because we need the precision (previously, this all lived in a single kernel that also kept them as accscalar_t, and no-one cared about the poor backward). One might ask whether it would be desirable to (optionally) support accscalar_t weights for half precision (as CuDNN does), but I haven't done this yet.

Finally, I also changed how the statistics gathering is parallelised: I swap dimensions 0 and 2 so that the larger one ends up as dimension 2 when constructing the packed_accessor, because the parallelisation is based on dimension 2 only.
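The kernel split described above can be sketched in numpy (shapes and function names are illustrative, not the PR's CUDA code): one pass collects per-channel statistics, a second applies out = w * (inp - mean) * invstd + b; in eval mode only the transform runs, with invstd derived from running_var.

```python
import numpy as np

def collect_statistics(x, eps):
    # x has shape (N, C, F); reduce over batch and feature dims per channel
    mean = x.mean(axis=(0, 2))
    var = x.var(axis=(0, 2))          # biased variance, as batch norm uses
    return mean, 1.0 / np.sqrt(var + eps)

def transform_input(x, mean, invstd, weight, bias):
    # the single "transform" kernel: out = w * (inp - mean) * invstd + b
    shape = (1, x.shape[1], 1)        # broadcast per-channel parameters
    return (weight.reshape(shape) * (x - mean.reshape(shape))
            * invstd.reshape(shape) + bias.reshape(shape))

x = np.random.default_rng(0).normal(size=(8, 4, 16))
mean, invstd = collect_statistics(x, eps=1e-5)
out = transform_input(x, mean, invstd, np.ones(4), np.zeros(4))
# with unit weight and zero bias, each channel is normalised
print(out.mean(axis=(0, 2)))   # ~0 per channel
```

The point of the split is that the two steps want very different launch geometries: the reduction wants one block per channel, while the pointwise transform can use whatever layout saturates memory bandwidth.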

@t-vi
Collaborator Author

t-vi commented Oct 14, 2018

Just a heads up: I've discussed kernels with the great folks working on apex sync batchnorm and implemented a backward kernel that is closer to what they use. If you wanted, I could drop in the updated backward kernel (which makes similar changes as above, but in addition increases the accuracy of the computation). Depending on your preference, I could add it in this PR, but by default, I'll keep it for a later PR (right now the backward kernel is mostly the old THCUNN one).

@t-vi
Collaborator Author

t-vi commented Oct 14, 2018

ROCm doesn't seem to like my use of sqrt:

/opt/rocm/hcc/bin/ld.lld: error: relocation R_AMDGPU_REL32_LO cannot be used against symbol sqrtf; recompile with -fPIC
>>> defined in /tmp/tmp.PsfYz7ywhH/caffe2_hip_generated_Normalization.cu.kernel.bc-gfx900.isabin
>>> referenced by /tmp/tmp.PsfYz7ywhH/caffe2_hip_generated_Normalization.cu.kernel.bc-gfx900.isabin

Somehow I feel that I shouldn't fundamentally change the compilation flags here without consulting someone. Any ideas?

Collaborator

@ssnl ssnl left a comment


Kernels make a lot of sense! I'm a bit rusty on the warp reduction stuff, so I didn't look at that code in detail. I'm interested in seeing the benchmarking numbers. Could you post them?

Another thing that would speed this up a bit further is Welford's algorithm, but that is out of scope for this PR.
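Welford's online algorithm, the further optimisation mentioned above, computes mean and variance in a single pass without the catastrophic cancellation of the naive sum/sum-of-squares approach. A plain-Python sketch (not kernel code):

```python
def welford(xs):
    """One-pass mean and (biased) variance via Welford's update."""
    mean, m2, n = 0.0, 0.0, 0
    for x in xs:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)   # note: uses the *updated* mean
    return mean, m2 / n            # population variance, as batch norm uses
```

In a parallel kernel, each thread would keep its own (n, mean, m2) triple, and triples are then merged pairwise with the corresponding combine formula; that is roughly how warp-level implementations of this are usually structured.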

}

// sum over NumThreads within a warp
sum = warpSum(sum);

This comment was marked as off-topic.

strides[dim -1] = 1;
}
}
// evil trick to get adjusted 2d tensors to have large dimension last

This comment was marked as off-topic.


template <typename scalar_t, typename accscalar_t>
__global__ void batch_norm_backward_kernel(
const PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits> input,

This comment was marked as off-topic.

} else {
dim3 blocks_red(input.size(1));
dim3 threads_red(getNumThreads(input.size(2)));
batch_norm_collect_statistics_kernel<scalar_t, accscalar_t> <<<blocks_red, threads_red, 0, stream>>>

This comment was marked as off-topic.

@ssnl
Collaborator

ssnl commented Oct 14, 2018

Also, I assume the CPU kernels didn't change much, right?

@t-vi
Collaborator Author

t-vi commented Oct 15, 2018

The CPU kernels are quite literally adapted from THNN.
I'll try to get a cleanish before/after benchmark (so far, I've mostly compared to cudnn).

Regarding further optimization:

  • The Welford-style forward statistics is the obvious thing I haven't done. (I didn't change much kernel code beyond what was immediately required.)
  • The backward could get similar surgery (splitting into "reduction" and "pointwise" kernels), which for extreme shapes (32x10x40_000) can have a rather large (10-100x) impact. Based on my discussions with @mcarilli and @jjsjann123, my main focus for the backward kernel was raising the numerical precision a bit for fp32 (and we also hope to share kernels between sync bn and this one).
    (I have an implementation of the kernel as an external module that looks correctish at https://gist.github.com/t-vi/82a46dc87eceae303a4f805147f82310 , but I haven't benchmarked it yet.)
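The "reduction + pointwise" backward split mentioned above can be sketched in numpy. These are the standard training-mode batch norm gradients over assumed (N, C, F) shapes, not the PR's actual kernels:

```python
import numpy as np

def bn_backward(grad_out, x, mean, invstd, weight):
    n = x.shape[0] * x.shape[2]
    shape = (1, x.shape[1], 1)
    xhat = (x - mean.reshape(shape)) * invstd.reshape(shape)
    # "reduction" kernel: two per-channel sums over batch and feature dims
    sum_dy = grad_out.sum(axis=(0, 2))
    sum_dy_xhat = (grad_out * xhat).sum(axis=(0, 2))
    # "pointwise" kernel: assemble grad_input from the reduced quantities
    grad_input = (weight * invstd).reshape(shape) * (
        grad_out
        - (sum_dy.reshape(shape) + xhat * sum_dy_xhat.reshape(shape)) / n
    )
    return grad_input, sum_dy_xhat, sum_dy  # d_input, d_weight, d_bias
```

The reduction pass is the part whose launch geometry matters for the extreme shapes; once the two per-channel sums exist, the pointwise pass is embarrassingly parallel.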

@t-vi
Collaborator Author

t-vi commented Oct 23, 2018

With 32-bit indexing, anything from the "real-world" benchmark taking more than 1 msec seems to be no slower in native:

bs | channels | features | train | cudnn | native | slowness | reference | slowness | evaluation
32 | 512 | 784 | e-fb | 1.0846 | 1.0853 | 1 | 1.0851 | 1 |
64 | 512 | 784 | e-fb | 2.1476 | 2.1521 | 1 | 2.1661 | 0.99 |
16 | 64 | 12544 | e-fb | 1.0831 | 1.082 | 1 | 1.1056 | 0.98 |
16 | 256 | 3136 | e-fb | 1.0748 | 1.0739 | 1 | 1.0993 | 0.98 |
64 | 1024 | 196 | e-fb | 1.1715 | 1.1701 | 1 | 1.1907 | 0.98 |
32 | 256 | 3136 | e-fb | 2.1151 | 2.1143 | 1 | 2.177 | 0.97 |

At last!

auto bs = input_cont.size(0);
auto features = input_cont.size(2);
auto input = input_cont.packed_accessor<scalar_t, 3, RestrictPtrTraits>();
AT_CHECK(cuda::detail::canUse32BitIndexMath(input_reshaped), "Input is too large for batch_norm");

This comment was marked as off-topic.

@ssnl
Collaborator

ssnl commented Oct 23, 2018

Very nice speed-up! :)

AT_HOST_DEVICE int64_t size(int64_t i) const { return sizes_[i]; }

// if index_t is not int64_t, we want to have an int64_t constructor
template <typename source_index_t, class = typename std::enable_if<std::is_same<source_index_t, int64_t>::value>::type>

This comment was marked as off-topic.

AT_HOST_DEVICE int64_t size(int64_t i) const { return sizes_[i]; }

// if index_t is not int64_t, we want to have an int64_t constructor
template <typename source_index_t, class = typename std::enable_if<std::is_same<source_index_t, int64_t>::value>::type>

This comment was marked as off-topic.

const Tensor& save_mean, const Tensor& save_invstd, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward", [&] {
if (cuda::detail::canUse32BitIndexMath(self)) {
return batch_norm_backward_cuda_template<scalar_t, int32_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask);

This comment was marked as off-topic.

const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double epsilon) {
return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm", [&] {
if (cuda::detail::canUse32BitIndexMath(self)) {
return batch_norm_cuda_template<scalar_t, int32_t>(self, weight, bias, running_mean, running_var, train, momentum, epsilon);

This comment was marked as off-topic.

Comment thread aten/src/ATen/native/cuda/Normalization.cu
Contributor

@facebook-github-bot facebook-github-bot left a comment


SsnL has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor

@facebook-github-bot facebook-github-bot left a comment


SsnL is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 26, 2018
Summary:
- Speed up the case of #12006 in the forward
- The backward still isn't as fast as one might hope (factor 2-3 in the #12006 case).
- More extensive benchmarking shows not so great performance compared
  to CuDNN for cases with many channels, e.g. bs=8-128 / c=1024 / f=1024.
- We change the meaning of save_mean and save_invstd (aka save_var) to accscalar to
  maintain reasonable precision.

Needless to say that I would happily separate the TensorAccessor fixes in a separate PR, as they're fixes and unrelated.
Pull Request resolved: pytorch/pytorch#12368

Differential Revision: D10559696

Pulled By: SsnL

fbshipit-source-id: f0d0d1e0912e17b15b8fb7a2c03d0fe757598419
@li-roy li-roy mentioned this pull request Oct 26, 2018
facebook-github-bot pushed a commit that referenced this pull request Oct 26, 2018
Summary:
Fixes colliding changes in #12766 and #12368
Pull Request resolved: #13171

Differential Revision: D12109430

Pulled By: li-roy

fbshipit-source-id: f068c7df227d920aa3840762e892ce6e9c109237
houseroad added a commit to houseroad/pytorch that referenced this pull request Oct 26, 2018
facebook-github-bot pushed a commit that referenced this pull request Oct 27, 2018
Summary:
Revert #12368 since it's causing onnx related test cases failing.
#12368

SsnL The controller you requested could not be found.
Pull Request resolved: #13191

Reviewed By: BIT-silence

Differential Revision: D12810778

Pulled By: houseroad

fbshipit-source-id: 1c373b92628580097cffcd237dccc5b3d8697577
zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 27, 2018 (a mirror of the revert above)
@t-vi t-vi mentioned this pull request Oct 29, 2018
facebook-github-bot pushed a commit that referenced this pull request Nov 7, 2018
Summary:
- Move batch norm from TH(CU)NN to native
- Speedups in many cases (e.g. #12006) for CUDA due to new block/grid layout and Welford-type mean/variance calculations (the latter for training mode)
- It splits the forward kernel in two pieces and reuses the evaluation kernel for the transformation.
- We change the meaning of save_mean and save_invstd (aka save_var) to accscalar to maintain reasonable precision.

Compared to the ill-fated #12368
- I changed the CPU kernel to not call `.sum()` from within parallel for. This seemed to have caused the breakage (NaN-results) in TestModels.test_dcgan_netG (thank you houseroad for the repro, errors in assessment of the fix are my own)
- I updated the Half->Float upcasting in tensors to go through `t.type().scalarType()` instead of `t.dtype()`.
- I have merged master
Pull Request resolved: #13263

Differential Revision: D12946254

Pulled By: SsnL

fbshipit-source-id: 3bb717ee250fbccaf10afe73722996aa4713d10d
laurentdupin pushed commits to laurentdupin/pytorch that referenced this pull request Apr 24, 2026 (verbatim mirrors of the four commit messages above)
6 participants