Optimize LayerNorm on CUDA #27633

@xiaomengy

Description

🚀 Feature

Improve the performance of the LayerNorm operator on CUDA.

Motivation

Currently the LayerNorm CUDA implementation reshapes the input and runs BatchNorm to compute the moments of the input, then uses addcmul for the affine transform. This implementation is inefficient, especially in the backward pass. One profiling result for a layer_norm backward run is shown below. We can see that AddcmulBackward and mul take far more time than native_batch_norm_backward, so implementing a FusedLayerNorm can improve LayerNorm performance on CUDA a lot.

-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name                                 Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     CUDA total %     CUDA total       CUDA time avg    Number of Calls  
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
AddcmulBackward                      0.62%            135.160us        97.47%           21.328ms         21.328ms         33.33%           24.015ms         24.015ms         1                
mul                                  96.85%           21.193ms         96.85%           21.193ms         5.298ms          33.21%           23.927ms         5.982ms          4                
NativeBatchNormBackward              0.12%            26.088us         0.40%            87.338us         87.338us         11.04%           7.954ms          7.954ms          1                
native_batch_norm_backward           0.28%            61.250us         0.28%            61.250us         61.250us         11.03%           7.947ms          7.947ms          1                
torch::autograd::AccumulateGrad      0.23%            50.516us         0.77%            168.433us        56.144us         3.90%            2.807ms          935.628us        3                
clone                                0.54%            117.917us        0.54%            117.917us        39.306us         3.87%            2.787ms          928.842us        3                
sum                                  0.58%            127.655us        0.58%            127.655us        63.827us         3.49%            2.517ms          1.259ms          2                
ViewBackward                         0.18%            39.172us         0.46%            101.337us        50.669us         0.05%            34.141us         17.070us         2                
torch::autograd::GraphRoot           0.14%            29.563us         0.14%            29.563us         29.563us         0.03%            21.952us         21.952us         1                
reshape                              0.19%            42.126us         0.28%            62.165us         31.082us         0.03%            20.605us         10.303us         2                
as_strided                           0.09%            20.039us         0.09%            20.039us         10.020us         0.01%            7.137us          3.568us          2                
view                                 0.18%            39.994us         0.18%            39.994us         19.997us         0.01%            6.977us          3.488us          2                
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
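
For context, a profile like the one above can be collected with PyTorch's autograd profiler. The snippet below is a minimal sketch; the tensor shapes are assumptions for illustration, not the configuration used to produce the table.

```python
# Illustrative repro of a layer_norm backward profile; shapes are assumptions.
import torch
import torch.nn.functional as F

x = torch.randn(64, 128, 1024, device="cuda", requires_grad=True)
w = torch.randn(1024, device="cuda", requires_grad=True)
b = torch.randn(1024, device="cuda", requires_grad=True)

y = F.layer_norm(x, (1024,), w, b)
grad = torch.randn_like(y)

# Make sure pending work is done, then profile only the backward pass.
torch.cuda.synchronize()
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    y.backward(grad)
print(prof.key_averages().table(sort_by="cuda_time_total"))
```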

Pitch

Implement a fused LayerNorm CUDA kernel that computes the moments and applies the normalization and affine transform in a single pass, instead of composing reshape, batch_norm, and addcmul.
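
As a reference for what the fused kernel needs to compute, here is a minimal sketch in plain PyTorch (not the proposed kernel itself): mean and variance are reduced over the normalized dimensions of each row, and the normalization and affine transform are applied directly, with no reshape/batch_norm/addcmul round trip.

```python
import torch
import torch.nn.functional as F

def layer_norm_ref(x, normalized_shape, weight, bias, eps=1e-5):
    # Reduce over the trailing `normalized_shape` dimensions, then normalize
    # and apply the affine transform directly; this is the per-row work a
    # fused kernel performs in a single pass over the input.
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    mean = x.mean(dims, keepdim=True)
    var = ((x - mean) ** 2).mean(dims, keepdim=True)  # biased variance
    return (x - mean) * torch.rsqrt(var + eps) * weight + bias

# Sanity check against the built-in operator.
x = torch.randn(8, 16, 1024)
w, b = torch.ones(1024), torch.zeros(1024)
torch.testing.assert_allclose(
    layer_norm_ref(x, (1024,), w, b),
    F.layer_norm(x, (1024,), w, b),
)
```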

Alternatives

In NVIDIA's apex library there is a FusedLayerNorm operator which performs much better than the current unfused implementation. However, apex.FusedLayerNorm relies on some global variables, so it cannot be used directly with JIT.
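
For comparison, apex's operator is used as a drop-in replacement for torch.nn.LayerNorm; a minimal sketch, assuming apex is installed:

```python
# Sketch assuming NVIDIA apex is installed; FusedLayerNorm is a drop-in
# replacement for torch.nn.LayerNorm.
import torch
from apex.normalization import FusedLayerNorm

ln = FusedLayerNorm(1024).cuda()
out = ln(torch.randn(8, 16, 1024, device="cuda"))
```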

#26201 tried to port apex.FusedLayerNorm to replace the current LayerNorm.

Additional context

cc @ngimel @dzhulgakov @zheng-xq
