Move batch_norm to ATen/native, speed up #12368
Conversation
- Speed up the case of pytorch#12006 in the forward.
- The backward still isn't as fast as one might hope (factor 2-3 in the pytorch#12006 case).
- More extensive benchmarking shows not-so-great performance compared to cuDNN for cases with many channels, e.g. bs=8-128 / c=1024 / f=1024.
- We change the meaning of save_var to save_invstd for native batch norm. This appears to somewhat improve the numerical stability (i.e. it makes TestNN.test_batch_norm_cudnn_half pass).
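For context, the semantics being moved to native can be sketched in plain NumPy. This is my own simplified model of training-mode `torch.nn.functional.batch_norm` (per-channel batch statistics for normalization, unbiased variance for the running stats), not code from the PR:

```python
import numpy as np

def batch_norm_train(x, running_mean, running_var, weight, bias,
                     momentum=0.1, eps=1e-5):
    # x has shape (N, C, *); statistics are per channel, reduced over
    # every dimension except dim 1.
    axes = (0,) + tuple(range(2, x.ndim))
    mean = x.mean(axis=axes)              # biased batch mean
    var = x.var(axis=axes)                # biased batch variance
    invstd = 1.0 / np.sqrt(var + eps)     # what "save_invstd" would store

    # Running stats are updated with the unbiased variance estimate.
    n = x.size // x.shape[1]
    running_mean[:] = (1 - momentum) * running_mean + momentum * mean
    running_var[:] = (1 - momentum) * running_var + momentum * var * n / (n - 1)

    # Broadcast the per-channel stats back over (N, C, *).
    shape = (1, -1) + (1,) * (x.ndim - 2)
    x_hat = (x - mean.reshape(shape)) * invstd.reshape(shape)
    return x_hat * weight.reshape(shape) + bias.reshape(shape)
```

The kernel work in this PR is essentially about computing `mean` and `invstd` quickly and precisely on CUDA; the normalization step itself is an elementwise transform.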
It's not the prettiest code, but I used this to arrive at the conclusion that for large numbers of features, it's still slow:

```python
import torch
import timeit
import gc
import numpy

batch_sizes = [8, 16, 32, 64, 128]
channels = [2, 32, 256, 1024]
features = [(16,), (32,), (64,), (128,), (256,), (1024,), (10240,), (102400,),
            (32, 32), (64, 64), (128, 128), (32, 32, 32), (64, 64, 64)]

def run_bn(input, running_mean, running_var, weight, bias, training, backward):
    out = torch.nn.functional.batch_norm(input, running_mean, running_var,
                                         weight, bias, training, 0.1, 1e-5)
    if backward:
        grads = torch.autograd.grad(out, [input, weight, bias], torch.ones_like(out))
    torch.cuda.synchronize()

for bs in batch_sizes:
    for c in channels:
        for f in features:
            shape = (bs, c) + f
            size = numpy.prod(shape)
            if size < 100_000_000:
                running_mean = torch.randn(c, device='cuda')
                running_var = torch.randn(c, device='cuda').exp()
                weight = torch.randn(c, device='cuda', requires_grad=True)
                bias = torch.randn(c, device='cuda', requires_grad=True)
                input = torch.randn(shape, device='cuda', requires_grad=True)
                gc.collect()
                for training in [True, False]:
                    run_bn(input, running_mean, running_var, weight, bias, training, training)
                    torch.cuda.synchronize()
                    res1 = timeit.timeit('run_bn(input, running_mean, running_var, weight, bias, training, training)',
                                         number=100, globals=globals())
                    with torch.backends.cudnn.flags(enabled=False):
                        res2 = timeit.timeit('run_bn(input, running_mean, running_var, weight, bias, training, training)',
                                             number=100, globals=globals())
                    print('{} | {} | {} | {} | {:.2f} | {:.2f} | {:.1f}'.format(
                        bs, c, f, training, res1, res2, res1 / res2))
```
Seems that the last rebase was less harmless than it looked...
I know what the problem is. You need to add … CC @bwasti, this isn't going to be a very good dev experience.
Thank you, @ezyang, for the hint! I was completely lost.
- Use `parallel_for` instead of using OMP myself. Thank you @vishwakftw for the review suggestion!
- Fix the double backward, which used the backward's gradient mask instead of its own.
So it turns out that for fp16, even when storing the variance in save_var, it doesn't work. Previously this didn't affect the forward because we kept mean and var in accscalar_t variables within the kernel. Now that we have separated it into two kernels for performance reasons, we see that keeping them in fp16 isn't enough.
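A toy illustration of the precision issue (mine, not from the PR): a long reduction accumulated in fp16 stalls once the running sum outgrows fp16's ~11-bit mantissa, while an fp32 accumulator, in the spirit of accscalar_t, stays close to the true value:

```python
import numpy as np

vals = np.full(20000, 0.1, dtype=np.float16)  # true sum is about 2000

# Accumulate in fp16: near 256 the spacing between representable fp16
# values grows past what adding 0.1 can bridge, so additions round away.
acc16 = np.float16(0.0)
for v in vals:
    acc16 = np.float16(acc16 + v)

# Accumulate the same fp16 inputs in an fp32 accumulator.
acc32 = np.float32(0.0)
for v in vals:
    acc32 += np.float32(v)

print(float(acc16), float(acc32))  # the fp16 sum stalls around 256
```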
Thanks, @ngimel, for the advice.
I'm not entirely sure there isn't one, but I don't see an immediate connection to the failing TestScript.test_weak_script_function. I think the correctness is there, though there are still performance improvements to be had, in particular for the backward. What I'm not 100% certain about is whether removing thnn_batch_norm has backward-compatibility implications.
Thanks a lot for doing this! I'm self-assigning so I remember to take a look. In the meantime, could you revert the submodule change?
Oops, thank you, Simon, for pointing this out.
ssnl
left a comment
I know I promised that I'd take a look today, but it's really late, I just got to my room, and I don't think I can read kernels now. I will look tomorrow. In the meantime, if possible, could you briefly summarize what changed between the THNN version and the current version that makes the kernels faster? Thanks!
No worries, thanks for looking!
Finally, I also changed how the statistics gathering is parallelised: I swap dimensions 0 and 2 so that the larger of the two ends up as dimension 2 when constructing the packed accessor, because the parallelisation is done over dimension 2 only.
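A rough model of that layout decision (my own sketch, assuming the input is viewed as (dim0, channel, dim2)): the per-channel statistics are identical under a swap of dims 0 and 2, so one can always put the larger extent in the position the kernel parallelises over:

```python
import numpy as np

def channel_stats(x):
    """Per-channel mean/var of a (d0, C, d2) view, reduced over dims 0 and 2."""
    # The hypothetical kernel distributes work along dim 2 only, so swap
    # dims 0 and 2 if that makes dim 2 the larger one; the statistics are
    # unchanged by the swap.
    if x.shape[2] < x.shape[0]:
        x = x.transpose(2, 1, 0)
    return x.mean(axis=(0, 2)), x.var(axis=(0, 2))
```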
Just a heads up: I've discussed kernels with the great folks working on Apex sync batchnorm and implemented a backward kernel that is closer to what they use. If you want, I could drop in the updated backward kernel (which makes similar changes as above, but in addition increases the accuracy of the computation). Depending on your preference I could add it in this PR, but by default I'll keep it for a later PR (right now the backward kernel is mostly the old THCUNN one).
Thank you, Simon, for the hint!
ROCm doesn't seem to like my use of sqrt. Somehow I feel that I shouldn't fundamentally change the compilation flags here without consulting someone. Any ideas?
ssnl
left a comment
The kernels make a lot of sense! I'm a bit rusty on the warp-reduction stuff, so I didn't look at that code in detail. I'm interested in seeing the benchmarking numbers. Could you post them?
Another thing that would speed this up a bit further is Welford's algorithm, but it is out of scope for this PR.
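For reference, Welford's algorithm computes mean and variance in a single pass without the catastrophic cancellation of the naive sum / sum-of-squares approach. A minimal serial sketch (the kernel version would additionally need a parallel merge step):

```python
def welford(xs):
    """Single-pass (biased) mean and variance via Welford's updates."""
    mean, m2, n = 0.0, 0.0, 0
    for x in xs:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)  # note: uses the already-updated mean
    var = m2 / n if n else 0.0
    return mean, var
```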
Also, I assume the CPU kernels didn't change much, right?
The CPU kernels are quite literally adapted from THNN. Regarding further optimization:

With 32-bit indexing, anything from the "real-world" benchmark taking more than 1 ms seems to be no slower in native:

At last!
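As a rough model of the 32-bit-indexing criterion (my own approximation of what a check like `canUse32BitIndexMath` enforces, not the actual ATen implementation): every linear offset into the tensor must fit in a signed 32-bit integer, which for a contiguous tensor reduces to a bound on the element count:

```python
INT32_MAX = 2**31 - 1

def fits_32bit_index(shape):
    """Conservative check: the number of elements (and hence the largest
    linear offset for a contiguous tensor) must fit in int32."""
    n = 1
    for s in shape:
        n *= s
    return n <= INT32_MAX
```

Tensors failing this check fall back to the 64-bit-index code path, which is why the benchmark distinguishes the two regimes.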
Very nice speed-up! :)
facebook-github-bot
left a comment
SsnL has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
SsnL is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
- Speed up the case of #12006 in the forward
- The backward still isn't as fast as one might hope (factor 2-3 in the #12006 case).
- More extensive benchmarking shows not-so-great performance compared to cuDNN for cases with many channels, e.g. bs=8-128 / c=1024 / f=1024.
- We change the meaning of save_mean and save_invstd (aka save_var) to accscalar to maintain reasonable precision.

Needless to say, I would happily separate the TensorAccessor fixes into a separate PR, as they're fixes and unrelated.

Pull Request resolved: pytorch/pytorch#12368
Differential Revision: D10559696
Pulled By: SsnL
fbshipit-source-id: f0d0d1e0912e17b15b8fb7a2c03d0fe757598419
This reverts commit dc211c7.
Summary: Revert #12368 since it's causing ONNX-related test cases to fail. #12368 SsnL

Pull Request resolved: #13191
Reviewed By: BIT-silence
Differential Revision: D12810778
Pulled By: houseroad
fbshipit-source-id: 1c373b92628580097cffcd237dccc5b3d8697577
Summary:
- Move batch norm from TH(CU)NN to native
- Speedups in many cases (e.g. #12006) for CUDA due to the new block/grid layout and Welford-type mean/variance calculations (the latter for training mode)
- It splits the forward kernel in two pieces and reuses the evaluation kernel for the transformation.
- We change the meaning of save_mean and save_invstd (aka save_var) to accscalar to maintain reasonable precision.

Compared to the ill-fated #12368:
- I changed the CPU kernel to not call `.sum()` from within parallel for. This seemed to have caused the breakage (NaN results) in TestModels.test_dcgan_netG (thank you houseroad for the repro; errors in assessment of the fix are my own)
- I updated the Half->Float upcasting in tensors to go through `t.type().scalarType()` instead of `t.dtype()`.
- I have merged master

Pull Request resolved: #13263
Differential Revision: D12946254
Pulled By: SsnL
fbshipit-source-id: 3bb717ee250fbccaf10afe73722996aa4713d10d
Summary: Fixes colliding changes in pytorch#12766 and pytorch#12368

Pull Request resolved: pytorch#13171
Differential Revision: D12109430
Pulled By: li-roy
fbshipit-source-id: f068c7df227d920aa3840762e892ce6e9c109237