
Add padding='same' mode to conv{1,2,3}d#45667

Closed
peterbell10 wants to merge 19 commits into gh/peterbell10/17/base from gh/peterbell10/17/head

Conversation

Collaborator

@peterbell10 peterbell10 commented Oct 1, 2020

First part of #3867 (Pooling operators still to do)

This adds a `padding='same'` mode to the interface of `conv{n}d` and `nn.Conv{n}d`. It is intended to match the behaviour of TensorFlow. I couldn't find that behaviour explicitly documented, but through experimentation I found that TensorFlow produces an output of shape `ceil(len/stride)` and always adds any extra asymmetric padding onto the right side of the input.

Since the `native_functions.yaml` schema doesn't seem to support strings or enums, I've moved the function interface into Python, where it now dispatches between the numerically padded `conv{n}d` and the `_conv{n}d_same` variant. The leading underscore is there because I couldn't see any way to avoid exporting the function into the `torch` namespace.

A note on asymmetric padding: the total padding required can be odd when the kernel length is even and the dilation is odd. mkldnn has native support for asymmetric padding, so there is no overhead there, but for other backends I resort to padding the input tensor by 1 on the right-hand side to make the remaining padding symmetrical. In these cases, I use `TORCH_WARN_ONCE` to notify the user of the performance implications.

Stack from ghstack:

Differential Revision: D27170744
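
To make the rule above concrete, here is a minimal sketch (my illustration, not the implementation in this PR) of how 'same' padding can be emulated for `conv1d` in the stride=1 case by padding explicitly and putting any odd remainder on the right; `conv1d_same_sketch` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def conv1d_same_sketch(x, weight, bias=None, dilation=1):
    # Total padding so that the output length equals the input length
    # (the stride=1 case of ceil(len / stride)).
    kernel = weight.shape[-1]
    total = dilation * (kernel - 1)
    left = total // 2
    right = total - left  # any odd remainder goes on the right, as TensorFlow does
    x = F.pad(x, (left, right))
    return F.conv1d(x, weight, bias, dilation=dilation)

x = torch.randn(1, 3, 10)
w = torch.randn(4, 3, 4)   # even kernel length + odd dilation -> odd total padding
y = conv1d_same_sketch(x, w)
print(y.shape)             # torch.Size([1, 4, 10])
```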

@mrshenli mrshenli added module: convolution Problems related to convolutions (THNN, THCUNN, CuDNN), module: nn Related to torch.nn, and triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Oct 1, 2020
@mrshenli mrshenli requested a review from gchanan October 1, 2020 23:57
peterbell10 added a commit to peterbell10/pytorch that referenced this pull request Oct 5, 2020
ghstack-source-id: 7235eec
Pull Request resolved: pytorch#45667
@facebook-github-bot
Contributor

Hi @peterbell10!

Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but we do not have a signature on file.

In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@dr-ci

dr-ci bot commented Dec 3, 2020

💊 CI failures summary and remediations

As of commit 9486ef1 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_py3_clang7_onnx_ort_test2 (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Error: No such container:path: cac8e0abde32913e5079756d43ced1f93de2e3eb795bd7c5da321112ea25b9d8:/var/lib/jenkins/workspace/test/.coverage
Mar 17 16:08:01     * 319ab58e27          1 report,  total time   652.26s
Mar 17 16:08:01     * 790326d49b          1 report,  total time   718.46s
Mar 17 16:08:01     |
Mar 17 16:08:01     :
Mar 17 16:08:01 
Mar 17 16:08:01 Removed  (across    0 suites)     0 tests, totaling     0.00s
Mar 17 16:08:01 Modified (across    1 suite)      2 tests, totaling + 230.40s ±   2.82s
Mar 17 16:08:01 Added    (across    0 suites)     0 tests, totaling     0.00s
Retrieving test reports
Retrieving Python coverage report
Error: No such container:path: cac8e0abde32913e5079756d43ced1f93de2e3eb795bd7c5da321112ea25b9d8:/var/lib/jenkins/workspace/test/.coverage


Exited with code exit status 1



@peterbell10
Collaborator Author

peterbell10 commented Dec 4, 2020

The XLA test compile failure is real:

return torch::conv3d(inputs[0], inputs[1], inputs[2],
                      /*stride=*/{stride, stride, stride},
                      /*padding=*/{padding, padding, padding},
                      /*dilation=*/{dilation, dilation, dilation},
                      groups);

The padding argument could implicitly convert to either IntArrayRef or std::string, so the call is ambiguous. Calling with an actual array, or specifying IntArrayRef{padding, padding, padding} fixes it.

I don't think there's an easy fix for this. Using enum values in the C++ api (#24343) should avoid the ambiguity but I think that would require substantial codegen and/or JIT changes.

Contributor

@jbschlosser jbschlosser left a comment


Sorry for the delay. I finally took the time to dig through this today in extreme detail.

The logic is correct, support is correctly restricted to stride=1 for this PR, and test coverage on the Python side looks solid.

One request: since the padding mode handling logic is duplicated on the C++ module side, I think it'd be valuable to validate the various cases (valid padding, same padding + each padding mode, error case of same padding with stride > 1) with C++ module tests in test/cpp/api/modules.cpp.

peterbell10 added a commit to peterbell10/pytorch that referenced this pull request Mar 11, 2021
ghstack-source-id: ff195ad
Pull Request resolved: pytorch#45667
@peterbell10 peterbell10 requested a review from albanD as a code owner March 15, 2021 19:43
peterbell10 added a commit to peterbell10/pytorch that referenced this pull request Mar 15, 2021
ghstack-source-id: 4291813
Pull Request resolved: pytorch#45667
Contributor

@jbschlosser jbschlosser left a comment


LGTM! Could you please rebase to fix the merge conflicts?

@peterbell10
Collaborator Author

Fixed merge conflicts, CI failure is unrelated:

Mar 17 16:05:40 + pip install -q --user git+https://github.com/pytorch/vision.git@8a2dc6f22ac4389ccba8859aa1e1cb14f1ee53db
Mar 17 16:07:53   ERROR: Command errored out with exit status 128:
Mar 17 16:07:53    command: git fetch -q https://github.com/pytorch/vision.git 8a2dc6f22ac4389ccba8859aa1e1cb14f1ee53db
Mar 17 16:07:53        cwd: /tmp/pip-req-build-jueuvyvq
Mar 17 16:07:53   Complete output (1 lines):
Mar 17 16:07:53   fatal: unable to access 'https://github.com/pytorch/vision.git/': Failed to connect to github.com port 443: Connection timed out

@facebook-github-bot
Contributor

@jbschlosser merged this pull request in 04e0cbf.

@jbschlosser jbschlosser added this to the 1.9.0 milestone Mar 19, 2021
@facebook-github-bot facebook-github-bot deleted the gh/peterbell10/17/head branch March 22, 2021 14:16
if self.padding_mode != 'zeros':
    raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d')

assert isinstance(self.padding, tuple)
Collaborator


I think this assertion is wrong. The following code will crash because the padding is a list, not a tuple:

import torch

x = torch.randn(4, 8, 32, 32, device='cuda', dtype=torch.float)
net = torch.nn.ConvTranspose2d(8, 16, kernel_size=3, padding=[3, 3])
net = net.cuda()

y = net(x)
torch.cuda.synchronize()

error message:

Traceback (most recent call last):
  File "/home/xwang/Developer/pytorch-test/random-things/padding-tuple.py", line 7, in <module>
    y = net(x)
  File "/home/xwang/Developer/pytorch/torch/nn/modules/module.py", line 1015, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/xwang/Developer/pytorch/torch/nn/modules/conv.py", line 890, in forward
    assert isinstance(self.padding, tuple)
AssertionError

Can you please fix it? @peterbell10
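
(As a side note, a workaround I'd assume works until this is fixed: pass the padding as a tuple rather than a list, so that the `isinstance(self.padding, tuple)` check holds.)

```python
# Assumed workaround, not a fix: a tuple padding satisfies the assertion above.
net = torch.nn.ConvTranspose2d(8, 16, kernel_size=3, padding=(3, 3))
```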

Collaborator

@ngimel ngimel Mar 22, 2021


@xwang233 thanks for pointing this out! Can you please file a separate issue? Oh, nvm, you did already, thanks!


Labels

cla signed, Merged, module: convolution Problems related to convolutions (THNN, THCUNN, CuDNN), module: nn Related to torch.nn, open source, triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module


10 participants