add type annotations to torch.nn.modules.pooling #49504

Closed
guilhermeleobas wants to merge 5 commits into pytorch:master from guilhermeleobas:pooling

Conversation

@guilhermeleobas (Collaborator)

closes gh-49503

@guilhermeleobas added the module: typing (Related to mypy type annotations) label Dec 16, 2020
@guilhermeleobas self-assigned this Dec 16, 2020
@facebook-github-bot (Contributor) commented Dec 16, 2020

💊 CI failures summary and remediations

As of commit 07fc24d (more details on the Dr. CI page):

Commit 07fc24d was recently pushed. Waiting for builds...

This comment was automatically generated by Dr. CI.

@walterddr (Contributor) left a comment

lgtm.
nit: seems like we can strip the definition in AvgPool1d as well (lines 519-521)

@rgommers (Collaborator) left a comment

Needs a small tweak: the ratios are floats.
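
For context, a minimal sketch of the int-vs-float distinction at stake, assuming the ratio aliases follow the _scalar_or_tuple_2_t pattern of torch.nn.common_types (the exact alias names used in the PR may differ):

from typing import Optional, Tuple, TypeVar, Union

T = TypeVar("T")
_scalar_or_tuple_2_t = Union[T, Tuple[T, T]]

_size_2_t = _scalar_or_tuple_2_t[int]     # kernel sizes and strides are ints
_ratio_2_t = _scalar_or_tuple_2_t[float]  # output ratios are floats, not ints

def fractional_pool_signature(kernel_size: _size_2_t,
                              output_ratio: Optional[_ratio_2_t] = None) -> None:
    """Hypothetical signature illustrating the split; not the real FractionalMaxPool2d API."""

A fractional pooling layer asked to emit, say, half-sized output needs output_ratio=0.5, which an int-based annotation would reject.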

Comment thread torch/nn/modules/pooling.py Outdated
Comment thread torch/nn/modules/pooling.py Outdated
@mruberry added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and module: nn (Related to torch.nn) labels Dec 28, 2020
@rgommers (Collaborator)

@walterddr would you be able to land this PR?

@facebook-github-bot (Contributor) left a comment

@walterddr has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@guilhermeleobas (Collaborator, Author)

Do not merge this PR until someone checks whether the annotations introduce any regressions. See:
#49564 (comment)

@guilhermeleobas guilhermeleobas marked this pull request as draft December 31, 2020 17:29
Comment thread test/test_jit.py
Comment on lines +951 to +955
# crashes pytorch
# m3 = Mod(nn.MaxPool3d(3, stride=2, return_indices=True),
# nn.MaxUnpool3d(3, stride=2))
# inp3d = torch.randn(20, 16, 51, 33, 15)
# self.checkModule(m3, (inp3d,))
@guilhermeleobas (Collaborator, Author)

This test crashes the PyTorch JIT:

terminate called after throwing an instance of 'c10::Error'
  what():  found an invalid max index 24225 (output volumes are of size 50x32x14
Exception raised from max_unpooling3d_forward_out_cpu_frame at ../aten/src/ATen/native/MaxUnpooling.cpp:196 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x68 (0x7fc38ea7da08 in /home/guilhermel/git/pytorch/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xfa (0x7fc38ea42804 in /home/guilhermel/git/pytorch/torch/lib/libc10.so)
frame #2: <unknown function> + 0x14d1fcd (0x7fc39c670fcd in /home/guilhermel/git/pytorch/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x86438 (0x7fc3af2b4438 in /home/guilhermel/.conda/envs/pytorch-cuda-dev/lib/libgomp.so.1)
frame #4: __kmp_invoke_microtask + 0x93 (0x7fc3af2c8c83 in /home/guilhermel/.conda/envs/pytorch-cuda-dev/lib/libgomp.so.1)
frame #5: <unknown function> + 0x3c747 (0x7fc3af26a747 in /home/guilhermel/.conda/envs/pytorch-cuda-dev/lib/libgomp.so.1)
frame #6: <unknown function> + 0x3b758 (0x7fc3af269758 in /home/guilhermel/.conda/envs/pytorch-cuda-dev/lib/libgomp.so.1)
frame #7: <unknown function> + 0x855aa (0x7fc3af2b35aa in /home/guilhermel/.conda/envs/pytorch-cuda-dev/lib/libgomp.so.1)
frame #8: <unknown function> + 0x9609 (0x7fc3cbc5e609 in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #9: clone + 0x43 (0x7fc3cbb85293 in /lib/x86_64-linux-gnu/libc.so.6)

Aborted (core dumped)
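
For reference, a standalone eager-mode sketch of the same module pair, assuming the failure also reproduces outside checkModule:

import torch
from torch import nn

# Same configuration as the commented-out test above.
pool = nn.MaxPool3d(3, stride=2, return_indices=True)
unpool = nn.MaxUnpool3d(3, stride=2)

inp3d = torch.randn(20, 16, 51, 33, 15)
out, indices = pool(inp3d)
result = unpool(out, indices)  # reported to abort with "found an invalid max index"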

@guilhermeleobas (Collaborator, Author)

I've fixed a few type annotations that were crashing the JIT and added tests.

@guilhermeleobas (Collaborator, Author) commented Jan 7, 2021

I guess part of this PR is blocked by #47888.

Issue #45904 summarizes the error I'm currently hitting when adding tests for nn.MaxUnpoolNd:

from torch import nn

class Mod(nn.Module):
    def __init__(self):
        super().__init__()
        # return_indices=True makes forward return (output, indices)
        self.pool = nn.MaxPool1d(2, stride=2, return_indices=True)
        self.unpool = nn.MaxUnpool1d(2, stride=2)

    def forward(self, input):
        output, indices = self.pool(input)
        return self.unpool(output, indices)

When return_indices=True, nn.MaxPool1d.forward returns a tuple of tensors rather than the single tensor the current annotation declares.

cc @rgommers, @walterddr

@rgommers (Collaborator) commented Jan 9, 2021

When return_indices=True, nn.MaxPool1d.forward returns a tuple of tensors rather than the single tensor the current annotation declares.

As noted in #45904 (comment), that requires adding a Literal[True] overload.
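
A minimal sketch of that overload pattern, written against a standalone function for clarity; the names and parameters here are illustrative, not the annotations that eventually landed:

# Requires Python 3.8+ for typing.Literal (use typing_extensions on older versions).
from typing import Literal, Tuple, Union, overload

import torch.nn.functional as F
from torch import Tensor

@overload
def max_pool1d(input: Tensor, kernel_size: int,
               return_indices: Literal[False] = ...) -> Tensor: ...

@overload
def max_pool1d(input: Tensor, kernel_size: int,
               return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ...

def max_pool1d(input: Tensor, kernel_size: int,
               return_indices: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    # The overloads above let a type checker infer Tensor vs. (Tensor, Tensor)
    # from the literal value of return_indices at the call site.
    return F.max_pool1d(input, kernel_size, return_indices=return_indices)

With this, output, indices = max_pool1d(x, 2, return_indices=True) type-checks, while the plain call still infers a single Tensor.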

Best to wait on updating this PR until gh-47888 has landed, I think.

Labels

cla signed, module: nn (Related to torch.nn), module: typing (Related to mypy type annotations), open source, Stale, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Development

Successfully merging this pull request may close these issues.

Enable torch.nn.modules.pooling typechecks during CI

6 participants