Skip to content

Commit 8bf9e99

Browse files
valentinandrei authored and pytorchmergebot committed
[pytorch][cuda] Some speedup on depth wise convolution 2D forward (#125362)
This PR does a few things:

- Adds a generic implementation for `conv_depthwise2d` when the filter size is non-standard. This implementation works faster because it doesn't do edge-condition checks inside the innermost loops. We avoid the checks by calculating the boundaries ahead of the loop.
- Hints to nvcc to minimize the register usage so that we squeeze more memory bandwidth.
- Adds filter size 5 as a common size where we can use the template implementation to improve unrolling and generate more efficient code.

The implementation doesn't completely fix the issue described in #18631. For that we need to rewrite the kernel using the suggestions described in the issue chat. This PR uses the same order of accessing the tensor as before but just removes overhead instructions in the inner loops to get the speedup.

Before:
```
conv2d-performance:
       B      C      iH      iW    kH    kW  native (cpu)  conv2d (cuda)  conv2d-fp16 (cuda)
0    8.0   64.0  1024.0  1008.0   5.0   5.0    149.052643      24.982176            3.236192
1    8.0   64.0  1008.0  1008.0   5.0   5.0    150.810333      24.643536            3.237760
2    4.0   48.0   720.0   539.0   6.0   1.0     15.747776       2.636320            1.788672
3    4.0  120.0   379.0   283.0   6.0   1.0     12.234080       1.791712            1.231360
4    4.0   32.0   713.0   532.0   6.0   1.0     10.362272       1.731584            1.170544
5    4.0    3.0   712.0   542.0  31.0  31.0     24.965248       3.406304            4.165440
6    4.0  120.0   379.0   288.0   1.0   6.0     10.772512       1.215616            0.939936
7 1024.0  384.0     1.0   928.0   1.0   3.0     60.051582       7.594256            2.861344
8    4.0   24.0   687.0   512.0   6.0   1.0     10.231536       1.196704            0.818432
9   96.0   96.0   112.0   112.0   5.0   5.0     21.025631       5.110096            0.715520
10  96.0   80.0    56.0    56.0   5.0   5.0      9.730064       1.016080            0.207424
11  64.0  128.0    64.0    84.0   3.0   3.0     18.759552       0.616736            0.200832
12  16.0  960.0     7.0     7.0   5.0   5.0      0.274880       0.020288            0.014688
13  16.0   64.0   112.0   112.0   3.0   3.0      6.425696       0.189088            0.053728
```

After:
```
       B      C      iH      iW    kH    kW  native (cpu)  conv2d (cuda)  conv2d-fp16 (cuda)
0    8.0   64.0  1024.0  1008.0   5.0   5.0    122.534370      12.915648            3.269936
1    8.0   64.0  1008.0  1008.0   5.0   5.0    126.026978      12.826848            3.236608
2    4.0   48.0   720.0   539.0   6.0   1.0     14.488160       1.803424            1.794368
3    4.0  120.0   379.0   283.0   6.0   1.0     11.556304       1.251200            1.240736
4    4.0   32.0   713.0   532.0   6.0   1.0      9.737841       1.186240            1.174128
5    4.0    3.0   712.0   542.0  31.0  31.0     19.394785       2.017056            2.310368
6    4.0  120.0   379.0   288.0   1.0   6.0      9.586752       0.828736            0.843712
7 1024.0  384.0     1.0   928.0   1.0   3.0     48.939903       5.529312            2.860768
8    4.0   24.0   687.0   512.0   6.0   1.0     13.474000       0.831920            0.825280
9   96.0   96.0   112.0   112.0   5.0   5.0     15.439168       2.611616            0.724864
10  96.0   80.0    56.0    56.0   5.0   5.0      5.991968       0.520352            0.207456
11  64.0  128.0    64.0    84.0   3.0   3.0      9.381472       0.609680            0.202832
12  16.0  960.0     7.0     7.0   5.0   5.0      0.265504       0.015680            0.014496
13  16.0   64.0   112.0   112.0   3.0   3.0      2.384832       0.187168            0.053280
```

Pull Request resolved: #125362 Approved by: https://github.com/ezyang
1 parent 1370f3a commit 8bf9e99

1 file changed

Lines changed: 120 additions & 3 deletions

File tree

aten/src/ATen/native/cuda/DepthwiseConv2d.cu

Lines changed: 120 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,120 @@ PackedTensorAccessor32<scalar_t, ndim, PtrTraits> dummy_packed_accessor32() {
2929
return {nullptr, zeros.data(), zeros.data()};
3030
}
3131

32+
// Generic depthwise conv2d forward kernel, used when the filter size does not
// match one of the fixed-size template instantiations. Each linearIndex
// visited by CUDA_KERNEL_LOOP_TYPE produces one output element (NCHW layout).
// Instead of bounds-checking every tap inside the inner loops, the range of
// in-bounds kernel taps [kHmin, kHmax) x [kWmin, kWmax) is computed once per
// output element, which removes the per-fetch edge-condition branches.
template <typename scalar_t, typename index_t>
__global__ void
#if !defined(USE_ROCM)
// Cap registers per thread to keep occupancy (and thus memory bandwidth) high.
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
#endif
conv_depthwise2d_forward_kernel_generic(
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> input,
    PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> output,
    const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> weight,
    const PackedTensorAccessor32<const scalar_t, 1, DefaultPtrTraits> bias,
    bool biasEnabled,
    index_t totalElements,
    const int outputChannels,
    const int depthwiseMultiplier,
    const int inputWidth, const int inputHeight,
    const int outputWidth, const int outputHeight,
    const int kernelWidth, const int kernelHeight,
    const int strideWidth, const int strideHeight,
    const int padWidth, const int padHeight,
    const int dilationWidth, const int dilationHeight) {
  // Accumulate in a wider type than scalar_t (e.g. float for half inputs).
  using acc_t = at::acc_type<scalar_t, true>;

  CUDA_KERNEL_LOOP_TYPE(linearIndex, totalElements, index_t) {
    // Decompose linearIndex into (n, c, h, w). Equivalent to:
    //   n = linearIndex / (outputChannels * outputHeight * outputWidth)
    //   c = (linearIndex / (outputHeight * outputWidth)) % outputChannels
    //   h = (linearIndex / outputWidth) % outputHeight
    //   w = linearIndex % outputWidth
    // but with each modulo replaced by a divide plus multiply-subtract.
    int quotient = linearIndex / outputWidth;
    const int w = linearIndex - quotient * outputWidth;
    int next = quotient / outputHeight;
    const int h = quotient - next * outputHeight;
    quotient = next;
    next = quotient / outputChannels;
    const int c = quotient - next * outputChannels;
    const int n = next;

    // Map the output channel back to the input channel it was produced from.
    int inputChannel = c;
    int inputChannels = outputChannels;
    if (depthwiseMultiplier != 1) {
      inputChannel /= depthwiseMultiplier;
      inputChannels /= depthwiseMultiplier;
    }

    const int weightOffset = c * kernelHeight * kernelWidth;

    // By precisely computing the filtering boundaries up front, we avoid
    // repeating the edge-condition checks for every fetched item. When the
    // input element is L1-resident, those extra branches and comparisons cost
    // about as many cycles as the data fetch itself, so hoisting them out of
    // the loops gives a significant performance boost.
    int kHmin = 0, kHmax = kernelHeight;
    int kWmin = 0, kWmax = kernelWidth;

    // Top edge: smallest kH with h_in_start + kH * dilationHeight >= 0.
    const int h_in_start = -padHeight + h * strideHeight;
    if (h_in_start < 0) {
      kHmin = (-h_in_start) / dilationHeight;
      if ((-h_in_start) % dilationHeight > 0) {
        ++kHmin;
      }
    }

    // Bottom edge: how far the last tap overshoots row inputHeight - 1.
    const int h_overshoot =
        h_in_start + (kernelHeight - 1) * dilationHeight - inputHeight + 1;
    if (h_overshoot >= 0) {
      kHmax = kernelHeight - h_overshoot / dilationHeight;
      if (h_overshoot % dilationHeight > 0) {
        --kHmax;
      }
    }

    // Left edge: smallest kW with w_in_start + kW * dilationWidth >= 0.
    const int w_in_start = -padWidth + w * strideWidth;
    if (w_in_start < 0) {
      kWmin = (-w_in_start) / dilationWidth;
      if ((-w_in_start) % dilationWidth > 0) {
        ++kWmin;
      }
    }

    // Right edge: how far the last tap overshoots column inputWidth - 1.
    const int w_overshoot =
        w_in_start + (kernelWidth - 1) * dilationWidth - inputWidth + 1;
    if (w_overshoot >= 0) {
      kWmax = kernelWidth - w_overshoot / dilationWidth;
      if (w_overshoot % dilationWidth > 0) {
        --kWmax;
      }
    }

    // Seed the accumulator with the bias (if enabled), then sum only the
    // in-bounds weight * input products.
    acc_t value = biasEnabled ? static_cast<acc_t>(bias.data()[c]) : acc_t(0);
    const index_t offset0 =
        (n * inputChannels + inputChannel) * inputHeight * inputWidth;

    for (int kH = kHmin; kH < kHmax; ++kH) {
      const int h_in = h_in_start + kH * dilationHeight;
      for (int kW = kWmin; kW < kWmax; ++kW) {
        const int w_in = w_in_start + kW * dilationWidth;
        const index_t offset = offset0 + h_in * inputWidth + w_in;
        value += static_cast<acc_t>(weight.data()[weightOffset + kH * kernelWidth + kW]) *
                 static_cast<acc_t>(input.data()[offset]);
      }
    }
    output.data()[linearIndex] = static_cast<scalar_t>(value);
  }
}
32139

33140
template <int kSize, typename scalar_t, typename index_t>
34-
__global__ void conv_depthwise2d_forward_kernel(
141+
__global__ void
142+
#if !defined(USE_ROCM)
143+
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
144+
#endif
145+
conv_depthwise2d_forward_kernel(
35146
const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> input,
36147
PackedTensorAccessor32<scalar_t, 4, DefaultPtrTraits> output,
37148
const PackedTensorAccessor32<const scalar_t, 4, DefaultPtrTraits> weight,
@@ -315,7 +426,13 @@ void conv_depthwise2d_forward_out(
315426
const auto bias_a = has_bias ?
316427
bias.packed_accessor32<const scalar_t, 1>() :
317428
dummy_packed_accessor32<const scalar_t, 1>();
318-
if (kW == 3 && kH == 3) {
429+
if (kW == 5 && kH == 5) {
430+
conv_depthwise2d_forward_kernel<5> <<<grid, block, 0, stream>>>(
431+
input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier,
432+
width, height, outputWidth, outputHeight,
433+
kW, kH, dW, dH, padW, padH, dilationW, dilationH);
434+
C10_CUDA_KERNEL_LAUNCH_CHECK();
435+
} else if (kW == 3 && kH == 3) {
319436
conv_depthwise2d_forward_kernel<3> <<<grid, block, 0, stream>>>(
320437
input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier,
321438
width, height, outputWidth, outputHeight,
@@ -328,7 +445,7 @@ void conv_depthwise2d_forward_out(
328445
kW, kH, dW, dH, padW, padH, dilationW, dilationH);
329446
C10_CUDA_KERNEL_LAUNCH_CHECK();
330447
} else {
331-
conv_depthwise2d_forward_kernel<0> <<<grid, block, 0, stream>>>(
448+
conv_depthwise2d_forward_kernel_generic<<<grid, block, 0, stream>>>(
332449
input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier,
333450
width, height, outputWidth, outputHeight,
334451
kW, kH, dW, dH, padW, padH, dilationW, dilationH);

0 commit comments

Comments
 (0)