
Commit 45473ff

zasdfgbnm authored and facebook-github-bot committed
Refactor cudnn convolution (#49109)
Summary: The cuDNN v7 API has been deprecated, so we need to migrate to the cuDNN v8 API. Since the v8 API does not exist on cuDNN 7, both APIs will have to coexist for a long time. This is step 0 of adding the cuDNN v8 API: there is no real code change in this PR, it only copy-pastes existing code.

The original `Conv.cpp` is split into `ConvPlaceholders.cpp`, `ConvShared.cpp`, `ConvShared.h`, `Conv_v7.cpp`, and `Conv_v8.cpp`. Currently `Conv_v8.cpp` is empty and will be filled in the future.

`ConvPlaceholders.cpp` contains placeholder implementations of the cuDNN convolution operators for builds where cuDNN is not enabled; these operators only raise errors and do no real computation. This file also contains the deprecated operators, which are implemented in terms of the current operators.

`ConvShared.cpp` and `ConvShared.h` contain the code shared by the v7 and v8 APIs: the definitions of the `ConvolutionParams` and `ConvolutionArgs` structs, as well as the ATen-exposed API such as `cudnn_convolution` and intermediates such as `cudnn_convolution_forward`. These exposed functions call raw APIs such as `raw_cudnn_convolution_forward_out` in `Conv_v7.cpp` or `Conv_v8.cpp` for the real implementation.

`Conv_v7.cpp` and `Conv_v8.cpp` contain the implementations of the raw APIs, which differ between v7 and v8.

Pull Request resolved: #49109
Reviewed By: H-Huang
Differential Revision: D25463783
Pulled By: ezyang
fbshipit-source-id: 1c80de8e5d94d97a61e45687f6193e8ff5481e3e
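To make the layering concrete, here is a minimal sketch (not code from this PR) of how an exposed function in `ConvShared.cpp` can route to the version-specific raw API. `conv_output_size` is the shape helper from `ConvUtils.h`; argument checking, memory-format, and contiguity handling are elided, and the function name is hypothetical:

namespace at { namespace native {

// Hypothetical simplified dispatch: the shared layer allocates the output
// and forwards to the raw out-variant, whose real implementation lives in
// Conv_v7.cpp (and eventually Conv_v8.cpp), or in the placeholder file
// when cuDNN is disabled.
Tensor cudnn_convolution_sketch(
    const Tensor& input, const Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) {
  // Compute the output shape from the conv hyperparameters and allocate it.
  auto output = at::empty(
      conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation),
      input.options());
  // The only version-specific call: v7 and v8 each provide this symbol.
  raw_cudnn_convolution_forward_out(
      output, input, weight, padding, stride, dilation, groups,
      benchmark, deterministic, allow_tf32);
  return output;
}

}} // namespace at::native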
1 parent d5c4a80 commit 45473ff

5 files changed: 746 additions & 614 deletions

`ConvPlaceholders.cpp` (new file): 147 additions & 0 deletions
@@ -0,0 +1,147 @@
#include <ATen/cuda/CUDAConfig.h>  // for the definition of AT_CUDNN_ENABLED
#include <ATen/ATen.h>
#include <ATen/native/ConvUtils.h>

namespace at { namespace native {

// ---------------------------------------------------------------------
//
// Placeholder operators
//
// ---------------------------------------------------------------------

#if !AT_CUDNN_ENABLED()

// See Note [ATen preprocessor philosophy]

at::Tensor cudnn_convolution(
    const at::Tensor& input, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) {
  AT_ERROR("cudnn_convolution: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_backward_input(
    IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  AT_ERROR("cudnn_convolution_backward_input: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_backward_weight(
    IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  AT_ERROR("cudnn_convolution_backward_weight: ATen not compiled with cuDNN support");
}

std::tuple<at::Tensor,at::Tensor> cudnn_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32, std::array<bool,2> output_mask) {
  AT_ERROR("cudnn_convolution_backward: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_transpose(
    const at::Tensor& input, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) {
  AT_ERROR("cudnn_convolution_transpose: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_transpose_backward_input(
    const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) {
  AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support");
}

at::Tensor cudnn_convolution_transpose_backward_weight(
    IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  AT_ERROR("cudnn_convolution_transpose_backward_weight: ATen not compiled with cuDNN support");
}

std::tuple<at::Tensor,at::Tensor> cudnn_convolution_transpose_backward(
    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32, std::array<bool,2> output_mask) {
  AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support");
}

void raw_cudnn_convolution_forward_out(
    const Tensor& output, const Tensor& input, const Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  AT_ERROR("raw_cudnn_convolution_forward_out: ATen not compiled with cuDNN support");
}

void raw_cudnn_convolution_backward_input_out(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  AT_ERROR("raw_cudnn_convolution_backward_input_out: ATen not compiled with cuDNN support");
}

void raw_cudnn_convolution_backward_weight_out(
    const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32) {
  AT_ERROR("raw_cudnn_convolution_backward_weight_out: ATen not compiled with cuDNN support");
}

#endif  // AT_CUDNN_ENABLED

// ---------------------------------------------------------------------
//
// Deprecated operators
//
// ---------------------------------------------------------------------

// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future
Tensor cudnn_convolution_deprecated(
    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias /* optional */,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic) {
  auto output = at::cudnn_convolution(input, weight, padding, stride, dilation, groups, benchmark, deterministic);
  if (bias.defined()) {
    output = output + reshape_bias(input.dim(), bias);
  }
  return output;
}

// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future
Tensor cudnn_convolution_deprecated2(
    const Tensor& input_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  return at::cudnn_convolution(input_t, weight_t, padding, stride, dilation, groups, benchmark, deterministic, at::globalContext().allowTF32CuDNN());
}

// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future
Tensor cudnn_convolution_transpose_deprecated(
    const Tensor& input, const Tensor& weight, const Tensor& bias /* optional */,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  auto output = at::cudnn_convolution_transpose(input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic);
  if (bias.defined()) {
    output = output + reshape_bias(input.dim(), bias);
  }
  return output;
}

// TODO (@zasdfgbnm): this is here only for compatibility, remove this in the future
Tensor cudnn_convolution_transpose_deprecated2(
    const Tensor& input_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  return at::cudnn_convolution_transpose(input_t, weight_t, padding, output_padding, stride, dilation, groups, benchmark, deterministic, at::globalContext().allowTF32CuDNN());
}

}}  // namespace at::native
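The deprecated wrappers above handle the optional bias by broadcasting it over the convolution output. `reshape_bias` comes from the included `ATen/native/ConvUtils.h`; as a sketch of its likely shape logic (an assumption, not part of this diff), it views the 1-D bias as [1, C, 1, ..., 1] so that `output + bias` broadcasts over the batch and spatial dimensions:

// Sketch of the bias-broadcast helper used above (the real helper lives in
// ATen/native/ConvUtils.h): build a shape of all 1s with one entry per
// output dimension, keep the channel dimension, and reshape the bias.
static inline at::Tensor reshape_bias(int64_t dim, const at::Tensor& bias) {
  std::vector<int64_t> shape(dim, 1);  // one entry per output dimension
  shape[1] = -1;                       // channel dimension keeps bias.numel()
  return bias.reshape(shape);
}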
