Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
5d695a4
Added same padding for conv*d
Chillee Jul 3, 2019
91765f0
Updated documentation
Chillee Jul 3, 2019
65c77f9
Added static padding calculation when possible
Chillee Jul 8, 2019
dd40227
Restructured code to support the JIT
Chillee Jul 8, 2019
65761de
Added tests
Chillee Jul 9, 2019
9474b15
Fix lint
Chillee Jul 9, 2019
c44d823
Started moving implementation to functional module
Chillee Jul 13, 2019
52c7af8
started implementing fmassa
Chillee Jul 19, 2019
cfce37c
Proof of concept
Chillee Jul 20, 2019
d6e3cd4
merged
Chillee Aug 8, 2019
b1fc70d
Added first attempt
Chillee Aug 9, 2019
609ecdf
Added things properly on the new overloading syntax
Chillee Aug 9, 2019
6a64502
Added python 2 support hopefully
Chillee Aug 9, 2019
a8020f7
added 1d and 3d conv too
Chillee Aug 9, 2019
2d8a873
Fixed stupid bug
Chillee Aug 9, 2019
2e45837
Maybe fixed CI issues. Can't remember what I was doing with this code...
Chillee Aug 16, 2019
20ef7e9
Fix padding_mode for all conv sizes
Chillee Aug 16, 2019
ad5ff2d
Fix formatting issues
Chillee Aug 16, 2019
925ee7c
Fixed circular padding issues
Chillee Aug 16, 2019
eb4e75a
Fixed issue in conv1d
Chillee Aug 17, 2019
3c6108f
Fix some misc issues
Chillee Aug 18, 2019
63454be
Fixed faulty type annotation
Chillee Aug 19, 2019
4aa12f2
Fixed broadcastinglist issue
Chillee Aug 20, 2019
5939c7c
Made split_padding have an underscore
Chillee Aug 20, 2019
6f7b271
Merge remote-tracking branch 'origin/master' into paddingsame
Chillee Aug 21, 2019
b9290e7
merged
Chillee Aug 21, 2019
d2bf119
Fixed python2 error
Chillee Aug 21, 2019
7841a4d
Fixed linting issues
Chillee Aug 21, 2019
5a84d5f
Fixed stupid issue with test
Chillee Aug 22, 2019
42ca33c
Fixed some issues
Chillee Aug 29, 2019
0c28ce4
Merge branch 'master' of github.com:pytorch/pytorch into paddingsame
Chillee Aug 30, 2019
731119a
Merge branch 'master' of github.com:pytorch/pytorch into paddingsame
Chillee Sep 4, 2019
3fb6287
Merge branch 'master' of github.com:pytorch/pytorch into paddingsame
Chillee Sep 9, 2019
d90e816
Merge branch 'master' of github.com:pytorch/pytorch into paddingsame
Chillee Sep 20, 2019
3ae61e0
rerun tests
Chillee Sep 20, 2019
aa119fd
Fixed tests
Chillee Sep 22, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions aten/src/ATen/native/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,10 +526,37 @@ at::Tensor _convolution(
int64_t dim = k - 2;

TORCH_CHECK(dim > 0, "weight should have at least three dimensions");

std::vector<int64_t> new_padding = padding_.vec();
// Asymmetric padding: 2 values per spatial dim (left, right). Pre-pad the
// input with the excess on the shorter side via constant_pad_nd, then fall
// through to symmetric padding using the smaller of each pair.
if (padding_.size() == size_t(dim) * 2) {
std::vector<int64_t> asymmetric;
bool is_uneven = false;
for (size_t i = 0; i < padding_.size(); i += 2) {
if (padding_[i] > padding_[i+1]) {
asymmetric.push_back(0);
asymmetric.push_back(padding_[i] - padding_[i+1]);
is_uneven = true;
} else if (padding_[i] < padding_[i+1]) {
asymmetric.push_back(padding_[i+1] - padding_[i]);
asymmetric.push_back(0);
is_uneven = true;
} else {
asymmetric.push_back(0);
asymmetric.push_back(0);
}
}
std::reverse(asymmetric.begin(), asymmetric.end());
if (is_uneven) {
input = at::constant_pad_nd(input, IntArrayRef{asymmetric});
}
new_padding.clear();
for (size_t i = 0; i < padding_.size(); i += 2) {
new_padding.push_back(std::min(padding_[i], padding_[i+1]));
}
}
ConvParams params;
params.stride = expand_param_if_needed(stride_, "stride", dim);
params.padding = expand_param_if_needed(padding_, "padding", dim);
params.padding = expand_param_if_needed(IntArrayRef{new_padding}, "padding", dim);
params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
params.transposed = transposed_;
params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
Expand Down
70 changes: 70 additions & 0 deletions test/common_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,26 @@ def fractional_max_pool3d_test(test_case):
input_size=(2, 4, 6),
cudnn=True,
),
dict(
fullname='Conv1d_padding_same_static',
constructor=lambda: nn.Conv1d(4, 4, 1, padding="same"),
input_size=(2, 4, 8),
),
dict(
fullname='Conv1d_padding_same_offset',
constructor=lambda: nn.Conv1d(4, 4, 2, padding="same"),
input_size=(2, 4, 8),
),
dict(
fullname='Conv1d_padding_same_dynamic',
constructor=lambda: nn.Conv1d(4, 4, 5, 3, padding="same"),
input_size=(2, 4, 10),
),
dict(
fullname='Conv1d_padding_same_all_params1',
constructor=lambda: nn.Conv1d(4, 4, 3, stride=2, dilation=3, padding="same"),
input_size=(2, 4, 20),
),
dict(
fullname='ConvTranspose1d',
constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)),
Expand Down Expand Up @@ -1375,6 +1395,31 @@ def fractional_max_pool3d_test(test_case):
constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
input_size=(2, 4, 5, 5),
),
dict(
fullname='Conv2d_padding_same_static',
constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding="same"),
input_size=(2, 4, 8, 8),
),
dict(
fullname='Conv2d_padding_same_offset',
constructor=lambda: nn.Conv2d(4, 4, (2, 2), padding="same"),
input_size=(2, 4, 8, 8),
),
dict(
fullname='Conv2d_padding_same_dynamic',
constructor=lambda: nn.Conv2d(4, 4, (5, 5), 3, padding="same"),
input_size=(2, 4, 10, 10),
),
dict(
fullname='Conv2d_padding_same_all_params1',
constructor=lambda: nn.Conv2d(4, 4, (3, 2), stride=(2, 3), dilation=(3, 5), padding="same"),
input_size=(2, 4, 20, 15),
),
dict(
fullname='Conv2d_padding_same_all_params2',
constructor=lambda: nn.Conv2d(4, 4, (3, 2), stride=(2, 3), dilation=(3, 5), padding="same"),
input_size=(2, 4, 20, 20),
),
dict(
module_name='MaxPool2d',
constructor_args=((3, 3), (2, 2), (1, 1)),
Expand Down Expand Up @@ -1576,6 +1621,31 @@ def fractional_max_pool3d_test(test_case):
constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
input_size=(2, 3, 5, 5, 5),
),
dict(
fullname='Conv3d_padding_same_static',
constructor=lambda: nn.Conv3d(4, 4, (3, 3, 3), padding="same"),
input_size=(2, 4, 8, 8, 8),
),
dict(
fullname='Conv3d_padding_same_offset',
constructor=lambda: nn.Conv3d(4, 4, (2, 2, 2), padding="same"),
input_size=(2, 4, 8, 8, 8),
),
dict(
fullname='Conv3d_padding_same_dynamic',
constructor=lambda: nn.Conv3d(4, 4, (5, 5, 5), 3, padding="same"),
input_size=(2, 4, 10, 10, 10),
),
dict(
fullname='Conv3d_padding_same_all_params1',
constructor=lambda: nn.Conv3d(1, 1, (3, 2, 1), stride=(2, 3, 1), dilation=(3, 5, 3), padding="same"),
input_size=(1, 1, 20, 15, 13),
),
dict(
fullname='Conv3d_padding_same_all_params2',
constructor=lambda: nn.Conv3d(1, 1, (3, 2, 3), stride=(2, 3, 1), dilation=(3, 5, 2), padding="same"),
input_size=(1, 1, 20, 20, 20),
),
dict(
module_name='ConvTranspose3d',
constructor_args=(2, 3, (2, 3, 2)),
Expand Down
7 changes: 4 additions & 3 deletions test/test_docs_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,16 @@ def test_torch(self):

# below are symbols mistakenly bound to torch.*, but should
# go to torch.nn.functional.* instead
'avg_pool1d', 'conv_transpose2d', 'conv_transpose1d', 'conv3d',
'relu_', 'pixel_shuffle', 'conv2d', 'selu_', 'celu_', 'threshold_',
'cosine_similarity', 'rrelu_', 'conv_transpose3d', 'conv1d', 'pdist',
'avg_pool1d', 'conv_transpose2d', 'conv_transpose1d',
'relu_', 'pixel_shuffle', 'selu_', 'celu_', 'threshold_',
'cosine_similarity', 'rrelu_', 'conv_transpose3d', 'pdist',
'adaptive_avg_pool1d', 'conv_tbc'
}
has_docstring = set(
a for a in dir(torch)
if getattr(torch, a).__doc__ and not a.startswith('_') and
'function' in type(getattr(torch, a)).__name__)
print(has_docstring & whitelist)
self.assertEqual(
has_docstring & whitelist, whitelist,
textwrap.dedent('''
Expand Down
17 changes: 17 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13986,6 +13986,22 @@ def test_uses():

self.checkScript(test_uses, ())

def test_padding_same_overload(self):
# Regression test: a module whose Conv1d uses the string overload
# padding="same" must still compile under TorchScript and run.
class W(torch.nn.Module):
def __init__(self):
super(W, self).__init__()
# Pool/unpool pair lives alongside the conv — presumably to
# exercise scripting of modules mixing tuple-returning calls
# with the new string-overloaded conv; TODO confirm intent.
self.pool = nn.MaxPool1d(2, stride=2, return_indices=True)
self.unpool = nn.MaxUnpool1d(2, stride=2)
self.layer = nn.Conv1d(10, 3, 3, padding="same")

def forward(self, x):
# Fixed input drives the pool/unpool round-trip; its result is
# discarded — only the conv output is returned.
input = torch.tensor([[[1., 2., 3., 4., 5., 6., 7., 8.]]])
output, indices = self.pool(input)
self.unpool(output, indices)
return self.layer(x)

# Success criterion is simply that scripting and the forward call on a
# (10, 10, 10) batch do not raise; no output value is asserted.
torch.jit.script(W())(torch.randn((10, 10, 10)))

def test_method_overloading(self):
class Over(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -14139,6 +14155,7 @@ def forward(self, x):
with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
a = torch.jit.script(W2())


def test_select_after_chunk(self):
def foo(x):
chunked = torch.chunk(x, 1)
Expand Down
12 changes: 11 additions & 1 deletion torch/jit/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import torch
from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
BroadcastingList3, BroadcastingList4, BroadcastingList6, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
is_optional, _qualified_name
from torch._C import TensorType, TupleType, FloatType, IntType, \
ListType, StringType, DictType, BoolType, OptionalType, ClassType
Expand Down Expand Up @@ -35,6 +35,11 @@ def __getattr__(self, name):
'List': List,
'Dict': Dict,
'Optional': Optional,
'BroadcastingList1': BroadcastingList1,
'BroadcastingList2': BroadcastingList2,
'BroadcastingList3': BroadcastingList3,
'BroadcastingList4': BroadcastingList4,
'BroadcastingList6': BroadcastingList6
}
class EvalEnv(object):
env = {
Expand All @@ -45,6 +50,11 @@ class EvalEnv(object):
'List': List,
'Dict': Dict,
'Optional': Optional,
'BroadcastingList1': BroadcastingList1,
'BroadcastingList2': BroadcastingList2,
'BroadcastingList3': BroadcastingList3,
'BroadcastingList4': BroadcastingList4,
'BroadcastingList6': BroadcastingList6
}

def __init__(self, rcb):
Expand Down
Loading