@@ -644,43 +644,43 @@ static Tensor convolution_same(
644644 dilation, false , output_padding, groups);
645645}
646646
647- at:: Tensor conv1d (
647+ Tensor _convolution_mode (
648648 const Tensor& input, const Tensor& weight, const Tensor& bias,
649649 IntArrayRef stride, std::string padding, IntArrayRef dilation,
650650 int64_t groups) {
651651 if (padding == " same" ) {
652- return at::native::convolution_same (input, weight, bias, stride, dilation, groups);
652+ return at::native::convolution_same (
653+ input, weight, bias, stride, dilation, groups);
653654 } else if (padding == " valid" ) {
654655 const int64_t padding_[] = {0 };
655- return at::native::conv1d (input, weight, bias, stride, padding_, dilation, groups);
656+ return at::native::convolution (
657+ input, weight, bias, stride, padding_, dilation, false , padding_, groups);
656658 }
657659 TORCH_CHECK (false , " Invalid padding mode '" , padding, " '" );
658660}
659661
662+ at::Tensor conv1d (
663+ const Tensor& input, const Tensor& weight, const Tensor& bias,
664+ IntArrayRef stride, std::string padding, IntArrayRef dilation,
665+ int64_t groups) {
666+ return at::_convolution_mode (
667+ input, weight, bias, stride, std::move (padding), dilation, groups);
668+ }
669+
660670at::Tensor conv2d (
661671 const Tensor& input, const Tensor& weight, const Tensor& bias,
662672 IntArrayRef stride, std::string padding, IntArrayRef dilation,
663673 int64_t groups) {
664- if (padding == " same" ) {
665- return at::native::convolution_same (input, weight, bias, stride, dilation, groups);
666- } else if (padding == " valid" ) {
667- const int64_t padding_[] = {0 , 0 };
668- return at::native::conv2d (input, weight, bias, stride, padding_, dilation, groups);
669- }
670- TORCH_CHECK (false , " Invalid padding mode '" , padding, " '" );
674+ return at::_convolution_mode (
675+ input, weight, bias, stride, std::move (padding), dilation, groups);
671676}
672677
673678at::Tensor conv3d (
674679 const Tensor& input, const Tensor& weight, const Tensor& bias,
675680 IntArrayRef stride, std::string padding, IntArrayRef dilation,
676681 int64_t groups) {
677- if (padding == " same" ) {
678- return at::native::convolution_same (input, weight, bias, stride, dilation, groups);
679- } else if (padding == " valid" ) {
680- const int64_t padding_[] = {0 , 0 , 0 };
681- return at::native::conv3d (input, weight, bias, stride, padding_, dilation, groups);
682- }
683- TORCH_CHECK (false , " Invalid padding mode '" , padding, " '" );
682+ return at::_convolution_mode (
683+ input, weight, bias, stride, std::move (padding), dilation, groups);
684684}
685685
686686at::Tensor conv_transpose1d (
0 commit comments