Commit 8bf9e99
[pytorch][cuda] Some speedup on depth wise convolution 2D forward (#125362)
This PR does a few things:
- Adds a generic implementation for `conv_depthwise2d` for non-standard filter sizes. This implementation is faster because it does no edge-condition checks inside the innermost loops; the loop boundaries are computed ahead of the loops instead.
- Hints nvcc to minimize register usage so that we can squeeze out more memory bandwidth.
- Adds filter size 5 as a common size for which the template implementation is used, which improves unrolling and generates more efficient code.
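As a rough illustration of the filter-size templating mentioned in the list above, the following host-side C++ sketch dispatches common filter widths to compile-time instantiations and everything else to a runtime-sized fallback. The names `row_dot` and `row_dot_dispatch` are hypothetical and are not taken from the PyTorch source; the real kernel is a CUDA kernel and differs in detail.

```cpp
// Hypothetical helper: dot product of one filter row with one input row.
// KW > 0 fixes the filter width at compile time, so the trip count is a
// constant and the compiler is free to fully unroll the loop.
// KW == 0 means "generic": the width is only known at run time.
template <int KW>
float row_dot(const float* in, const float* w, int runtime_kw) {
  const int kw = KW > 0 ? KW : runtime_kw;
  float acc = 0.0f;
  for (int i = 0; i < kw; ++i) {
    acc += in[i] * w[i];
  }
  return acc;
}

// Route common widths to specialized instantiations and all other
// widths to the generic one. Which sizes count as "common" is an
// assumption of this sketch.
float row_dot_dispatch(const float* in, const float* w, int kw) {
  switch (kw) {
    case 5:  return row_dot<5>(in, w, kw);
    case 3:  return row_dot<3>(in, w, kw);
    case 1:  return row_dot<1>(in, w, kw);
    default: return row_dot<0>(in, w, kw);
  }
}
```

The results are the same on every path; only the amount of loop overhead the compiler can remove changes.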
The implementation doesn't completely fix the issue described in #18631. For that, the kernel needs to be rewritten following the suggestions in the issue discussion. This PR keeps the same tensor access order as before and only removes overhead instructions from the inner loops to get the speedup.
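The idea of computing loop boundaries ahead of the inner loops can be sketched on the CPU as follows. This is a minimal single-channel example assuming stride 1, dilation 1, and zero padding; the function names are hypothetical and the code is not the CUDA kernel from this PR. Both functions visit the same taps in the same order, so they produce identical results.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Baseline: every filter tap checks whether it falls inside the input.
// Assumes shapes for which the output height and width are positive.
std::vector<float> conv_checked(const std::vector<float>& in, int H, int W,
                                const std::vector<float>& w, int kH, int kW,
                                int padH, int padW) {
  const int oH = H + 2 * padH - kH + 1;
  const int oW = W + 2 * padW - kW + 1;
  std::vector<float> out(static_cast<std::size_t>(oH) * oW, 0.0f);
  for (int oy = 0; oy < oH; ++oy) {
    for (int ox = 0; ox < oW; ++ox) {
      float acc = 0.0f;
      for (int ky = 0; ky < kH; ++ky) {
        for (int kx = 0; kx < kW; ++kx) {
          const int iy = oy - padH + ky;
          const int ix = ox - padW + kx;
          if (iy >= 0 && iy < H && ix >= 0 && ix < W) {  // per-tap check
            acc += w[ky * kW + kx] * in[iy * W + ix];
          }
        }
      }
      out[oy * oW + ox] = acc;
    }
  }
  return out;
}

// Same computation, but the valid tap range is clamped once per output
// element, so the innermost loops contain no bounds checks.
std::vector<float> conv_prebounded(const std::vector<float>& in, int H, int W,
                                   const std::vector<float>& w, int kH, int kW,
                                   int padH, int padW) {
  const int oH = H + 2 * padH - kH + 1;
  const int oW = W + 2 * padW - kW + 1;
  std::vector<float> out(static_cast<std::size_t>(oH) * oW, 0.0f);
  for (int oy = 0; oy < oH; ++oy) {
    for (int ox = 0; ox < oW; ++ox) {
      const int y0 = oy - padH;  // input row touched by tap ky == 0
      const int x0 = ox - padW;  // input column touched by tap kx == 0
      const int kyBegin = std::max(0, -y0);
      const int kyEnd = std::min(kH, H - y0);
      const int kxBegin = std::max(0, -x0);
      const int kxEnd = std::min(kW, W - x0);
      float acc = 0.0f;
      for (int ky = kyBegin; ky < kyEnd; ++ky) {
        for (int kx = kxBegin; kx < kxEnd; ++kx) {
          acc += w[ky * kW + kx] * in[(y0 + ky) * W + (x0 + kx)];
        }
      }
      out[oy * oW + ox] = acc;
    }
  }
  return out;
}
```

For example, with a 3x3 input of ones, a 3x3 filter of ones, and padding 1, a corner output element sums 4 taps and the center element sums all 9, in both versions.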
Before:
```
conv2d-performance:
         B      C      iH      iW    kH    kW  native (cpu)  conv2d (cuda)  conv2d-fp16 (cuda)
0      8.0   64.0  1024.0  1008.0   5.0   5.0    149.052643      24.982176            3.236192
1      8.0   64.0  1008.0  1008.0   5.0   5.0    150.810333      24.643536            3.237760
2      4.0   48.0   720.0   539.0   6.0   1.0     15.747776       2.636320            1.788672
3      4.0  120.0   379.0   283.0   6.0   1.0     12.234080       1.791712            1.231360
4      4.0   32.0   713.0   532.0   6.0   1.0     10.362272       1.731584            1.170544
5      4.0    3.0   712.0   542.0  31.0  31.0     24.965248       3.406304            4.165440
6      4.0  120.0   379.0   288.0   1.0   6.0     10.772512       1.215616            0.939936
7   1024.0  384.0     1.0   928.0   1.0   3.0     60.051582       7.594256            2.861344
8      4.0   24.0   687.0   512.0   6.0   1.0     10.231536       1.196704            0.818432
9     96.0   96.0   112.0   112.0   5.0   5.0     21.025631       5.110096            0.715520
10    96.0   80.0    56.0    56.0   5.0   5.0      9.730064       1.016080            0.207424
11    64.0  128.0    64.0    84.0   3.0   3.0     18.759552       0.616736            0.200832
12    16.0  960.0     7.0     7.0   5.0   5.0      0.274880       0.020288            0.014688
13    16.0   64.0   112.0   112.0   3.0   3.0      6.425696       0.189088            0.053728
```
After:
```
         B      C      iH      iW    kH    kW  native (cpu)  conv2d (cuda)  conv2d-fp16 (cuda)
0      8.0   64.0  1024.0  1008.0   5.0   5.0    122.534370      12.915648            3.269936
1      8.0   64.0  1008.0  1008.0   5.0   5.0    126.026978      12.826848            3.236608
2      4.0   48.0   720.0   539.0   6.0   1.0     14.488160       1.803424            1.794368
3      4.0  120.0   379.0   283.0   6.0   1.0     11.556304       1.251200            1.240736
4      4.0   32.0   713.0   532.0   6.0   1.0      9.737841       1.186240            1.174128
5      4.0    3.0   712.0   542.0  31.0  31.0     19.394785       2.017056            2.310368
6      4.0  120.0   379.0   288.0   1.0   6.0      9.586752       0.828736            0.843712
7   1024.0  384.0     1.0   928.0   1.0   3.0     48.939903       5.529312            2.860768
8      4.0   24.0   687.0   512.0   6.0   1.0     13.474000       0.831920            0.825280
9     96.0   96.0   112.0   112.0   5.0   5.0     15.439168       2.611616            0.724864
10    96.0   80.0    56.0    56.0   5.0   5.0      5.991968       0.520352            0.207456
11    64.0  128.0    64.0    84.0   3.0   3.0      9.381472       0.609680            0.202832
12    16.0  960.0     7.0     7.0   5.0   5.0      0.265504       0.015680            0.014496
13    16.0   64.0   112.0   112.0   3.0   3.0      2.384832       0.187168            0.053280
```
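To compare the two tables, one option is to compute the per-row ratio of the "before" time to the "after" time. The sketch below does this for three rows of the `conv2d (cuda)` column, with the values copied from the tables above. The `Row` struct and helper names are hypothetical.

```cpp
#include <vector>

// One benchmark row: filter size plus the conv2d (cuda) timings
// before and after the change, in the units reported by the benchmark.
struct Row {
  int kH;
  int kW;
  double before;
  double after;
};

// Speedup factor: values above 1.0 mean the "after" run was faster.
inline double speedup(const Row& r) { return r.before / r.after; }

// Rows 0, 5 and 11 of the tables (5x5, 31x31 and 3x3 filters).
inline std::vector<Row> selected_rows() {
  return {
      {5, 5, 24.982176, 12.915648},   // row 0
      {31, 31, 3.406304, 2.017056},   // row 5
      {3, 3, 0.616736, 0.609680},     // row 11
  };
}
```

For these rows the ratios come out to roughly 1.93x for 5x5, 1.69x for 31x31, and about 1.01x for 3x3. That is consistent with the description above, where the changes target filter size 5 and non-standard sizes.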
Pull Request resolved: #125362
Approved by: https://github.com/ezyang
1 parent 1370f3a commit 8bf9e99
1 file changed
Lines changed: 120 additions & 3 deletions