Skip to content

Commit 28ff463

Browse files
committed
Hide convolution_same from jit IR
1 parent efa203b commit 28ff463

2 files changed

Lines changed: 20 additions & 17 deletions

File tree

aten/src/ATen/native/Convolution.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -644,43 +644,43 @@ static Tensor convolution_same(
644644
dilation, false, output_padding, groups);
645645
}
646646

647-
at::Tensor conv1d(
647+
Tensor _convolution_mode(
648648
const Tensor& input, const Tensor& weight, const Tensor& bias,
649649
IntArrayRef stride, std::string padding, IntArrayRef dilation,
650650
int64_t groups) {
651651
if (padding == "same") {
652-
return at::native::convolution_same(input, weight, bias, stride, dilation, groups);
652+
return at::native::convolution_same(
653+
input, weight, bias, stride, dilation, groups);
653654
} else if (padding == "valid") {
654655
const int64_t padding_[] = {0};
655-
return at::native::conv1d(input, weight, bias, stride, padding_, dilation, groups);
656+
return at::native::convolution(
657+
input, weight, bias, stride, padding_, dilation, false, padding_, groups);
656658
}
657659
TORCH_CHECK(false, "Invalid padding mode '", padding, "'");
658660
}
659661

662+
at::Tensor conv1d(
663+
const Tensor& input, const Tensor& weight, const Tensor& bias,
664+
IntArrayRef stride, std::string padding, IntArrayRef dilation,
665+
int64_t groups) {
666+
return at::_convolution_mode(
667+
input, weight, bias, stride, std::move(padding), dilation, groups);
668+
}
669+
660670
at::Tensor conv2d(
661671
const Tensor& input, const Tensor& weight, const Tensor& bias,
662672
IntArrayRef stride, std::string padding, IntArrayRef dilation,
663673
int64_t groups) {
664-
if (padding == "same") {
665-
return at::native::convolution_same(input, weight, bias, stride, dilation, groups);
666-
} else if (padding == "valid") {
667-
const int64_t padding_[] = {0, 0};
668-
return at::native::conv2d(input, weight, bias, stride, padding_, dilation, groups);
669-
}
670-
TORCH_CHECK(false, "Invalid padding mode '", padding, "'");
674+
return at::_convolution_mode(
675+
input, weight, bias, stride, std::move(padding), dilation, groups);
671676
}
672677

673678
at::Tensor conv3d(
674679
const Tensor& input, const Tensor& weight, const Tensor& bias,
675680
IntArrayRef stride, std::string padding, IntArrayRef dilation,
676681
int64_t groups) {
677-
if (padding == "same") {
678-
return at::native::convolution_same(input, weight, bias, stride, dilation, groups);
679-
} else if (padding == "valid") {
680-
const int64_t padding_[] = {0, 0, 0};
681-
return at::native::conv3d(input, weight, bias, stride, padding_, dilation, groups);
682-
}
683-
TORCH_CHECK(false, "Invalid padding mode '", padding, "'");
682+
return at::_convolution_mode(
683+
input, weight, bias, stride, std::move(padding), dilation, groups);
684684
}
685685

686686
at::Tensor conv_transpose1d(

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -938,6 +938,9 @@
938938
- func: _convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor
939939
use_c10_dispatcher: full
940940

941+
- func: _convolution_mode(Tensor input, Tensor weight, Tensor? bias, int[] stride, str padding, int[] dilation, int groups) -> Tensor
942+
use_c10_dispatcher: full
943+
941944
- func: _convolution_nogroup(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding) -> Tensor
942945
use_c10_dispatcher: full
943946

0 commit comments

Comments
 (0)