
Optimize GroupNorm in PyTorch #28201

@xiaomengy


🚀 Feature

Improve the performance of the GroupNorm operator.

Motivation

Similar to #27633, the current GroupNorm implementation reshapes the input and runs BatchNorm on it to compute the per-group moments, then uses addcmul for the affine transform. This implementation is inefficient on both CPU and CUDA.
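For reference, the decomposition described above can be written with public PyTorch calls roughly as follows. This is a sketch that mirrors the described steps (contiguous(), a view that turns every (sample, group) pair into one BatchNorm channel, BatchNorm in training mode to obtain the moments, then addcmul for the affine part). It is not the actual ATen source: the exact view shape (1, N * num_groups, -1) is an assumption, and the helper name group_norm_via_batch_norm is hypothetical.

```python
import torch
import torch.nn.functional as F


def group_norm_via_batch_norm(x, num_groups, weight=None, bias=None, eps=1e-5):
    """Hypothetical helper: GroupNorm expressed as BatchNorm on a reshaped view plus addcmul."""
    n, c = x.shape[0], x.shape[1]
    # Treat every (sample, group) pair as one BatchNorm "channel": shape (1, N*G, rest).
    x_reshaped = x.contiguous().view(1, n * num_groups, -1)
    # training=True makes batch_norm normalize with the mean/variance of each channel,
    # i.e. the moments of each group; no running stats or BatchNorm affine params are used.
    out = F.batch_norm(x_reshaped, None, None, training=True, momentum=0.0, eps=eps)
    out = out.view(x.shape)
    if weight is None and bias is None:
        return out
    # Per-channel affine transform bias + out * weight, broadcast over N and spatial dims.
    affine_shape = [1, c] + [1] * (x.dim() - 2)
    ones = torch.ones(affine_shape, dtype=x.dtype, device=x.device)
    zeros = torch.zeros(affine_shape, dtype=x.dtype, device=x.device)
    w = weight.view(affine_shape) if weight is not None else ones
    b = bias.view(affine_shape) if bias is not None else zeros
    return torch.addcmul(b, out, w)
```

Every element is read by BatchNorm and then again by addcmul, and autograd has to differentiate both ops separately, which is the overhead the profiles below show.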

The CPU benchmark for input shape = [128, 256, 28, 28], num_groups = 32 is shown below.

GroupNorm forward: 210.97913278000487ms
--------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name                        Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     CUDA total %     CUDA total       CUDA time avg    Number of Calls  
--------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
native_batch_norm           68.00%           135.920s         68.00%           135.920s         135.920ms        NaN              0.000us          0.000us          1000             
addcmul                     31.94%           63.841s          31.94%           63.841s          63.841ms         NaN              0.000us          0.000us          1000             
view                        0.03%            69.572ms         0.03%            69.572ms         17.393us         NaN              0.000us          0.000us          4000             
group_norm                  0.01%            29.095ms         100.00%          199.870s         199.870ms        NaN              0.000us          0.000us          1000             
_batch_norm_impl_index      0.00%            5.179ms          68.01%           135.925s         135.925ms        NaN              0.000us          0.000us          1000             
batch_norm                  0.00%            4.594ms          68.01%           135.929s         135.929ms        NaN              0.000us          0.000us          1000             
contiguous                  0.00%            1.815ms          0.00%            1.815ms          1.815us          NaN              0.000us          0.000us          1000             
--------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Self CPU time total: 199.870s
CUDA time total: 0.000us

GroupNorm backward: 498.55486265799846ms
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name                                 Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     CUDA total %     CUDA total       CUDA time avg    Number of Calls  
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
mul                                  46.02%           224.071s         46.02%           224.071s         56.018ms         NaN              0.000us          0.000us          4000             
native_batch_norm_backward           43.86%           213.559s         43.86%           213.559s         213.559ms        NaN              0.000us          0.000us          1000             
sum                                  3.91%            19.040s          3.91%            19.040s          9.520ms          NaN              0.000us          0.000us          2000             
add_                                 3.28%            15.951s          3.28%            15.951s          5.317ms          NaN              0.000us          0.000us          3000             
AddcmulBackward                      1.95%            9.501s           47.97%           233.571s         233.571ms        NaN              0.000us          0.000us          1000             
torch::autograd::AccumulateGrad      0.97%            4.737s           4.25%            20.687s          6.896ms          NaN              0.000us          0.000us          3000             
as_strided                           0.00%            20.049ms         0.00%            20.049ms         5.012us          NaN              0.000us          0.000us          4000             
NativeBatchNormBackward              0.00%            13.951ms         43.86%           213.573s         213.573ms        NaN              0.000us          0.000us          1000             
reshape                              0.00%            13.034ms         0.01%            33.083ms         8.271us          NaN              0.000us          0.000us          4000             
ViewBackward                         0.00%            9.419ms          0.01%            42.502ms         10.626us         NaN              0.000us          0.000us          4000             
torch::autograd::GraphRoot           0.00%            1.189ms          0.00%            1.189ms          1.189us          NaN              0.000us          0.000us          1000             
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Self CPU time total: 486.916s
CUDA time total: 0.000us

The CUDA benchmark for input shape = [256, 512, 56, 56], num_groups = 32 is shown below.

GroupNorm forward: 11.333400868010358ms
--------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name                        Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     CUDA total %     CUDA total       CUDA time avg    Number of Calls  
--------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
group_norm                  23.06%           56.891ms         100.00%          246.741ms        246.741us        31.20%           11.369s          11.369ms         1000             
batch_norm                  8.59%            21.194ms         35.82%           88.382ms         88.382us         18.89%           6.884s           6.884ms          1000             
_batch_norm_impl_index      6.50%            16.045ms         27.23%           67.189ms         67.189us         18.86%           6.872s           6.872ms          1000             
native_batch_norm           20.73%           51.143ms         20.73%           51.143ms         51.143us         18.83%           6.861s           6.861ms          1000             
addcmul                     15.17%           37.434ms         15.17%           37.434ms         37.434us         12.13%           4.419s           4.419ms          1000             
view                        21.76%           53.695ms         21.76%           53.695ms         13.424us         0.06%            21.595ms         5.399us          4000             
contiguous                  4.19%            10.339ms         4.19%            10.339ms         10.339us         0.03%            9.650ms          9.650us          1000             
--------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Self CPU time total: 246.741ms
CUDA time total: 36.436s

GroupNorm backward: 42.1425356430118ms
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name                                 Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     CUDA total %     CUDA total       CUDA time avg    Number of Calls  
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
AddcmulBackward                      10.08%           69.875ms         36.22%           251.046ms        251.046us        24.64%           19.492s          19.492ms         1000             
mul                                  26.14%           181.171ms        26.14%           181.171ms        45.293us         24.60%           19.460s          4.865ms          4000             
NativeBatchNormBackward              3.44%            23.832ms         9.82%            68.072ms         68.072us         14.23%           11.261s          11.261ms         1000             
native_batch_norm_backward           6.38%            44.240ms         6.38%            44.240ms         44.240us         14.23%           11.255s          11.255ms         1000             
torch::autograd::AccumulateGrad      7.14%            49.495ms         16.52%           114.485ms        38.162us         8.02%            6.343s           2.114ms          3000             
add_                                 9.38%            64.990ms         9.38%            64.990ms         21.663us         8.00%            6.326s           2.109ms          3000             
sum                                  11.71%           81.163ms         11.71%           81.163ms         40.581us         6.15%            4.863s           2.431ms          2000             
ViewBackward                         9.72%            67.398ms         23.92%           165.801ms        41.450us         0.07%            57.930ms         14.482us         4000             
reshape                              8.35%            57.903ms         14.20%           98.403ms         24.601us         0.04%            35.134ms         8.783us          4000             
as_strided                           5.84%            40.500ms         5.84%            40.500ms         10.125us         0.01%            11.111ms         2.778us          4000             
torch::autograd::GraphRoot           1.81%            12.526ms         1.81%            12.526ms         12.526us         0.01%            8.698ms          8.698us          1000             
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Self CPU time total: 693.093ms
CUDA time total: 79.113s
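The tables above have the layout of PyTorch's key-averaged profiler output. The issue does not include the script that produced them, so the harness below is only a hedged sketch of one way to collect comparable numbers with the torch.profiler API: it times forward and backward separately, then profiles each, reusing one retained graph so that backward runs without re-executing forward (consistent with the backward tables listing no forward ops). The helper name profile_group_norm, the timing method, and the default iteration count are assumptions.

```python
import time

import torch
import torch.nn as nn


def profile_group_norm(shape, num_groups, iters=10, device="cpu"):
    """Hypothetical harness: average GroupNorm forward/backward time plus op-level profiles."""
    use_cuda = device != "cpu"
    x = torch.randn(*shape, device=device, requires_grad=True)
    gn = nn.GroupNorm(num_groups, shape[1]).to(device)
    grad_out = torch.randn(*shape, device=device)

    activities = [torch.profiler.ProfilerActivity.CPU]
    if use_cuda:
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    def sync():
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        if use_cuda:
            torch.cuda.synchronize()

    def avg_ms(fn):
        sync()
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        sync()
        return (time.perf_counter() - start) * 1000.0 / iters

    def forward():
        gn(x)

    # Build the graph once and keep it alive so only backward ops are timed and profiled.
    y = gn(x)

    def backward():
        y.backward(grad_out, retain_graph=True)

    results = {}
    for name, fn in (("forward", forward), ("backward", backward)):
        ms = avg_ms(fn)  # timed without the profiler to avoid its overhead
        with torch.profiler.profile(activities=activities) as prof:
            for _ in range(iters):
                fn()
            sync()
        table = prof.key_averages().table(sort_by="self_cpu_time_total")
        results[name] = (ms, table)
    return results
```

For example, profile_group_norm((128, 256, 28, 28), 32, iters=100) would run the first configuration on CPU, and passing device="cuda" would run on a GPU; printing results["forward"][1] shows the per-op table. Gradients accumulate across backward iterations, which does not matter for timing.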

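For both the CPU and GPU versions of GroupNorm, going through BatchNorm and addcmul is slow, and the backward pass is hit hardest: on CPU, the mul calls inside AddcmulBackward and native_batch_norm_backward together take about 90% of the backward time. On the CPU side there is an additional problem for inference: BatchNorm at inference time reduces to an affine transform that can be fused with the preceding Conv, whereas GroupNorm still has to compute per-group statistics. Implementing GroupNorm on top of BatchNorm therefore makes it very slow for CPU inference.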

Pitch

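Implement an optimized version of GroupNorm that fuses everything together: computing the per-group moments, normalizing, and applying the affine transform in dedicated forward and backward kernels instead of composing BatchNorm and addcmul.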
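One way to read "fuses everything together" is sketched below in NumPy, as a plain reference for the math rather than a kernel. The forward pass computes the mean and variance of each (sample, group), then folds mean, rstd, gamma and beta into one per-(sample, channel) scale and shift so the output costs a single multiply-add per element. The backward pass uses the closed-form GroupNorm gradient instead of chaining BatchNorm backward with AddcmulBackward. The function names and the (N, C, *spatial) layout with C divisible by num_groups are assumptions.

```python
import numpy as np


def group_norm_forward(x, gamma, beta, num_groups, eps=1e-5):
    """Hypothetical reference: GroupNorm forward with a folded per-(N, C) scale and shift."""
    n, c = x.shape[:2]
    d = c // num_groups                          # channels per group
    bshape = (n, c) + (1,) * (x.ndim - 2)        # broadcast (N, C) over spatial dims
    xg = x.reshape(n, num_groups, -1)
    mean = xg.mean(axis=2)                       # (N, G)
    rstd = 1.0 / np.sqrt(xg.var(axis=2) + eps)   # (N, G), biased variance
    # Expand group statistics to channels: channel k belongs to group k // d.
    mean_c = np.repeat(mean, d, axis=1)
    rstd_c = np.repeat(rstd, d, axis=1)
    # (x - mean) * rstd * gamma + beta == x * scale + shift
    scale = rstd_c * gamma
    shift = beta - mean_c * scale
    y = x * scale.reshape(bshape) + shift.reshape(bshape)
    return y, mean, rstd


def group_norm_backward(dy, x, gamma, mean, rstd, num_groups):
    """Hypothetical reference: closed-form GroupNorm gradients w.r.t. x, gamma and beta."""
    n, c = x.shape[:2]
    d = c // num_groups
    bshape = (n, c) + (1,) * (x.ndim - 2)
    reduce_axes = (0,) + tuple(range(2, x.ndim))  # sum over N and spatial dims
    mean_c = np.repeat(mean, d, axis=1).reshape(bshape)
    rstd_c = np.repeat(rstd, d, axis=1).reshape(bshape)
    x_hat = (x - mean_c) * rstd_c
    dgamma = (dy * x_hat).sum(axis=reduce_axes)
    dbeta = dy.sum(axis=reduce_axes)
    # Gradient w.r.t. x_hat, regrouped as (N, G, d * spatial).
    g = (dy * gamma.reshape((1, c) + (1,) * (x.ndim - 2))).reshape(n, num_groups, -1)
    xh = x_hat.reshape(n, num_groups, -1)
    # dx = rstd * (g - mean(g) - x_hat * mean(g * x_hat)), means taken within each group.
    dx = rstd[:, :, None] * (
        g - g.mean(axis=2, keepdims=True) - xh * (g * xh).mean(axis=2, keepdims=True)
    )
    return dx.reshape(x.shape), dgamma, dbeta
```

A fused kernel would presumably compute the two per-group sums needed by the backward pass (of g and of g * x_hat) in one pass over the data, avoiding the separate mul, sum and add_ calls visible in the backward profiles.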

Alternatives

Additional context

cc @dzhulgakov @ngimel @ppwwyyxx
