
Add Support for transposed convolution with Padding Mode 'same' and 'valid' #154279

Open
Alvaro-Kothe wants to merge 9 commits into pytorch:main from Alvaro-Kothe:feat/conv-transpose-pad-same-cpp

Conversation

@Alvaro-Kothe
Contributor

@Alvaro-Kothe Alvaro-Kothe commented May 23, 2025

This pull request adds support for 'same' and 'valid' padding modes for transposed convolutions.

Implementation Details for 'same' padding:

  1. Compute the minimum required padding for the left and right sides.
  2. Perform the convolution using the computed padding.
  3. Adjust the output so that its size matches the input size.

Notes:

  • The padding calculation is based on the JAX implementation.

  • For asymmetric convolutions, I slice one extra value on the left side.

  • 'valid' padding is equivalent to padding=0.

  • 'same' padding is not supported when:

    • The stride is not equal to 1 (stride != 1): this case is also unsupported for regular convolutions, and I could not find any references clarifying how it should be implemented in this context.
    • output_padding != 0, since this seems to conflict with the intent of 'same' padding.
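The three steps above can be sketched in plain Python for the 1-D case with stride=1 and dilation=1. The helper names below are hypothetical and the code is an illustration, not the PR's C++ implementation:

```python
def conv_transpose1d_full(x, w):
    """Transposed 1-D convolution with stride=1, dilation=1, padding=0.

    Each input element scatters a scaled copy of the kernel into the
    output, so the output length is len(x) + len(w) - 1.
    """
    n, k = len(x), len(w)
    out = [0.0] * (n + k - 1)
    for i in range(n):
        for j in range(k):
            out[i + j] += x[i] * w[j]
    return out


def conv_transpose1d_same(x, w):
    """'same' padding: trim k - 1 values total so len(out) == len(x).

    When the total trim is odd (even kernel), the trim is asymmetric and
    the extra value is sliced from the left side, matching the note above.
    """
    total = len(w) - 1
    right = total // 2
    left = total - right  # one extra on the left when total is odd
    full = conv_transpose1d_full(x, w)
    return full[left:len(full) - right]
```

For example, `conv_transpose1d_same([1, 2, 3, 4], [1, 1, 1])` produces `[3, 6, 9, 7]`, preserving the input length of 4.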

Closes #80301; closes #3867

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168

@Alvaro-Kothe Alvaro-Kothe requested a review from albanD as a code owner May 23, 2025 21:12
@pytorch-bot

pytorch-bot bot commented May 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154279

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added module: cpu CPU specific problem (e.g., perf, algorithm) release notes: cpp release notes category labels May 23, 2025
@github-actions
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.



@shoumikhin
Contributor

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased feat/conv-transpose-pad-same-cpp onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout feat/conv-transpose-pad-same-cpp && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the feat/conv-transpose-pad-same-cpp branch from 42ce8c0 to 0bfa1b8 Compare May 24, 2025 04:10
@Alvaro-Kothe
Contributor Author

Could someone please trigger the CI?
Soft ping to: @Skylion007 @albanD @shoumikhin. Thank you in advance!

@colesbury colesbury added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label May 29, 2025
@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jul 28, 2025
@Alvaro-Kothe
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased feat/conv-transpose-pad-same-cpp onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout feat/conv-transpose-pad-same-cpp && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the feat/conv-transpose-pad-same-cpp branch from c71a0b5 to e1bdc0c Compare August 4, 2025 11:33
@Alvaro-Kothe Alvaro-Kothe force-pushed the feat/conv-transpose-pad-same-cpp branch from e1bdc0c to d6c4409 Compare August 22, 2025 12:03
@Alvaro-Kothe
Contributor Author

Hi @Skylion007, could you please take another look at this PR when you have a moment? Also, if possible, remove the Stale label.

@Alvaro-Kothe
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased feat/conv-transpose-pad-same-cpp onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout feat/conv-transpose-pad-same-cpp && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the feat/conv-transpose-pad-same-cpp branch from d6c4409 to 7e5a0e5 Compare September 10, 2025 20:00
@Alvaro-Kothe
Contributor Author

Can somebody remove the Stale label and trigger the CI?

Also, if there are any changes needed in this PR, please let me know.

Soft ping to: @Skylion007, @albanD, @shoumikhin and @colesbury. Thank you in advance!

@albanD albanD removed the Stale label Sep 12, 2025
@Alvaro-Kothe
Contributor Author

Alvaro-Kothe commented Sep 15, 2025

Thanks for removing the stale label and triggering the CI.

The build on the failing job was killed with SIGKILL

sccache: Compiler killed by signal 9

I think this is unrelated to the changes in this PR.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 9, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here


@Alvaro-Kothe
Contributor Author

I was investigating the errors, and I at least understand why they happen.

Traceback

It originates from the changes in common_methods_invocations.py.

For the failing test, the OpInfo contains this ref (which doesn't exist for regular convolutions):

    OpInfo('nn.functional.conv_transpose1d',
           # `ref` for this function is backward of
           # corresponding `conv*d`
           ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose1d),

The function defined in the ref dispatches to these gradient functions:

    grad_fn_map = {torch.nn.functional.conv_transpose1d: torch.nn.grad.conv1d_input,
                   torch.nn.functional.conv_transpose2d: torch.nn.grad.conv2d_input,
                   torch.nn.functional.conv_transpose3d: torch.nn.grad.conv3d_input}

These in turn call convolution_backward, which does not support padding as a string:

    return torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,
        _single(stride),
        _single(padding),
        _single(dilation),
        False,
        [0],
        groups,
        (True, False, False),
    )[0]

Why doesn't it happen for regular convolutions?

It doesn't happen for regular convolutions because their OpInfo entries don't define this ref:

    OpInfo('nn.functional.conv1d',
           aliases=('conv1d',),
           aten_name='conv1d',
           dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16),
           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,

The conv{1,2,3}d_input functions defined in torch/nn/grad.py are never called with string padding for regular convolutions anywhere in the codebase.

Next Steps

I see 3 main options:

  1. Conditionally skip this test and similar ones.
    • I couldn't find a way in the OpInfo class to skip a test based on the function arguments, so I would probably have to skip it in the test body.
  2. Remove the changes in common_methods_invocations so this test is not run.
  3. Create a new overload of convolution_backward that supports padding as a string.
    • I don't know how feasible this is for me, because it seems to be GPU-specific and I don't have easy access to a GPU.

cc @albanD @jbschlosser

@albanD
Collaborator

albanD commented Oct 10, 2025

Oh, that is very interesting!
Thanks for the details on the investigation!

Let's double check with @jbschlosser if option 1 is feasible.
Otherwise option 2 is ok.

@jbschlosser
Contributor

Skipping / xfailing a particular test only for a given set of samples (AKA option 1 above) is possible, albeit somewhat involved. I think an option 4 could be to update the transposed convolution ref to handle string padding, no?

@Alvaro-Kothe
Contributor Author

I think an option 4 could be to update the transposed convolution ref to handle string padding, no?

I will look into it. I still haven't looked at what convolution_backward does, but if I can simply do there what I did in the C++ code, it's definitely possible!

What I am thinking of doing in the conv_transpose_ref function is the following:

  • If it's "valid": replace padding with 0
  • If it's "same":
    1. compute the minimum padding;
    2. replace the padding argument with the computed one;
    3. adjust the output if the padding is asymmetric.

Is this a valid solution, or do I need to do something else?
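A minimal sketch of that plan, assuming a hypothetical helper (resolve_transpose_padding is not a real function in the PR or in PyTorch) that resolves string padding into integers before the gradient functions are called:

```python
def resolve_transpose_padding(padding, kernel_size, dilation):
    """Turn string padding into per-dim ints for a stride=1 transposed conv.

    Returns (pad, extra_left): `pad` is what would replace the padding
    argument, and extra_left[d] is 1 when one extra value must be sliced
    from the left side of dimension d (the asymmetric case).
    """
    ndim = len(kernel_size)
    if padding == "valid":
        # 'valid' is equivalent to padding=0
        return [0] * ndim, [0] * ndim
    if padding == "same":
        pad, extra_left = [], []
        for k, d in zip(kernel_size, dilation):
            total = d * (k - 1)  # total trim needed so output size == input size
            pad.append(total // 2)
            extra_left.append(total % 2)  # 1 when the split is asymmetric
        return pad, extra_left
    # already numeric: pass through unchanged
    return list(padding), [0] * ndim
```

For example, a dilated 'same' case like kernel_size=[4], dilation=[1] needs a total trim of 3, which resolves to padding=[1] plus one extra value sliced from the left.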

@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Oct 13, 2025
@Alvaro-Kothe
Contributor Author

Alvaro-Kothe commented Oct 13, 2025

Hi @jbschlosser, can you check the latest commit (c113a2d) and verify that it's what you had in mind?

@Alvaro-Kothe Alvaro-Kothe force-pushed the feat/conv-transpose-pad-same-cpp branch from 7438017 to c113a2d Compare December 7, 2025 00:39
@Alvaro-Kothe Alvaro-Kothe force-pushed the feat/conv-transpose-pad-same-cpp branch from c113a2d to c89127b Compare January 18, 2026 14:14
@Alvaro-Kothe
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased feat/conv-transpose-pad-same-cpp onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout feat/conv-transpose-pad-same-cpp && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the feat/conv-transpose-pad-same-cpp branch from c89127b to 42e74b0 Compare February 7, 2026 14:44
@Alvaro-Kothe
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased feat/conv-transpose-pad-same-cpp onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout feat/conv-transpose-pad-same-cpp && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the feat/conv-transpose-pad-same-cpp branch from 42e74b0 to 1abc48c Compare March 4, 2026 00:43

Labels

module: cpu CPU specific problem (e.g., perf, algorithm)
open source
release notes: cpp release notes category
triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Development

Successfully merging this pull request may close these issues.

Need "valid" and "same" padding mode for convTranspose2d
[Feature Request] Implement "same" padding for convolution operations?

8 participants