OpInfo for nn.functional.avg_pool2d #62455

Closed
krshrimali wants to merge 4 commits into pytorch:master from krshrimali:opinfo/nn/functional/avg_pool2d

Conversation

@krshrimali
Contributor

@facebook-github-bot
Contributor

facebook-github-bot commented Jul 30, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit a2a23d5 (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-scanned failure(s)

1 failure not recognized by patterns:

Job: CircleCI pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit
Step: pytorch android gradle custom build single architecture (for PR)

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

@krshrimali krshrimali marked this pull request as draft July 30, 2021 05:58
@mruberry
Collaborator

@zou3519 and @heitorschueroff would you review this?

Comment on lines +2298 to +2300
cases = (((1, 3, 9, 9), (3, 3), (1, 1), (1, ), True, False, 2),
((1, 1, 4, 4), (2, 2), (1, 1), (0, ), False, True, -2),
((1, 2, 6, 6), (4, 4), (2, 2), (2, ), True, True, 3))
Contributor

Let's add some more cases.

  • Could we have one that is just a single kernel size ((1, 3, 9, 9),)?
  • avg_pool2d advertises that stride, padding, and kernel_size can each be a single number, so we should test something like ((1, 3, 9, 9), 3, 1, 1)
  • Could we have a case where the two strides are not the same, and another where the two paddings are not the same?
  • We aren't testing divisor_override=None
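Taken together, the suggested additions might look like the sketch below. This is illustrative only, not the PR's actual code; the per-case tuple layout (input_shape, kernel_size, stride, padding, …) follows the cases quoted above, with the trailing elements omitted for brevity.

```python
# Illustrative extra cases covering the suggestions above (not the PR's code).
# Layout per case: (input_shape, kernel_size, stride, padding, ...).
extra_cases = (
    ((1, 3, 9, 9), 3, 1, 1),                 # int kernel_size, stride, padding
    ((1, 3, 9, 9), (3, 3), (1, 2), (1, 1)),  # the two strides differ
    ((1, 3, 9, 9), (3, 3), (1, 1), (0, 1)),  # the two paddings differ
)

# avg_pool2d expands a single number to both spatial dims, roughly:
def as_pair(x):
    # Hypothetical helper mirroring torch's internal _pair expansion.
    return (x, x) if isinstance(x, int) else tuple(x)
```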

Contributor Author

Thanks for taking a look @zou3519! I've addressed the suggestions in the latest commit. :)

def sample_inputs_avgpool2d(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = (((1, 3, 9, 9), (3, 3), (1, 1), (1, ), True, False, 2),
Contributor

Can you add a comment describing what each element represents, please?

Contributor Author

Done, thanks for the suggestion @heitorschueroff!

             ((1, 1, 4, 4), (2, 2), (1, 1), (0, ), False, True, -2),
             ((1, 2, 6, 6), (4, 4), (2, 2), (2, ), True, True, 3))

    return [SampleInput(make_arg(input_shape), args=(kernel_size, stride, padding, ceil_mode,
Contributor

Nit: can you do this with a regular for loop instead of a list comprehension, for readability?

Contributor

I really like list comprehensions. I think whether or not to do a list comprehension or do a for-loop should not be something we suggest in reviews because it's a matter of personal taste and there is no official PyTorch Style Guide.

Contributor Author

Done, thanks for the suggestion, @heitorschueroff!

Contributor Author

Sorry, I only just saw your comment, @zou3519! Since I also had to add a single case with no args except kernel size, I kept a generator() approach for now. Hope that's fine.
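The generator() approach described here might look roughly like the following sketch. This is a hedged reconstruction, not the PR's exact code; the dicts stand in for SampleInput objects, and the function name is illustrative.

```python
# Hedged sketch of a generator-based sample-inputs function. Each case is
# (input_shape, *pool_args), where pool_args may be empty for the
# kernel-size-only case -- the unpacking tolerates differing arity.
def sample_inputs_avgpool2d_sketch(cases):
    def generator():
        for input_shape, *pool_args in cases:
            # A real implementation would yield SampleInput(make_arg(input_shape), args=...)
            yield {"input_shape": input_shape, "args": tuple(pool_args)}
    return list(generator())
```

A generator keeps the per-case unpacking in one place while still allowing cases of different lengths, which a fixed-arity list comprehension would not.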

Contributor

> I really like list comprehensions. I think whether or not to do a list comprehension or do a for-loop should not be something we suggest in reviews because it's a matter of personal taste and there is no official PyTorch Style Guide.

List comprehensions are great for short one-liners, but I also agree that it is a matter of preferred style and since we don't enforce a style guide it is ok to ignore this suggestion.

Contributor

Either style works, so this is fine :). I wouldn't say that one is nicer than the other because it depends on who is reading the code

Contributor Author

Thanks for the valuable inputs, @zou3519 and @heitorschueroff! I appreciate it. :)

@krshrimali krshrimali marked this pull request as ready for review July 30, 2021 17:06
@heitorschueroff (Contributor) left a comment

@krshrimali This is looking great, thank you for the contribution! I think there is room to explore a little more with the test cases, for instance, it looks like all test cases have size 1 for batch dimension. However, as a first PR adding OpInfo this is already good. Let's wait for the test to run. Thanks!

@krshrimali added the "module: testing" label (Issues related to the torch.testing module (not tests)) Jul 30, 2021
@krshrimali
Contributor Author

> @krshrimali This is looking great, thank you for the contribution! I think there is room to explore a little more with the test cases, for instance, it looks like all test cases have size 1 for batch dimension. However, as a first PR adding OpInfo this is already good. Let's wait for the test to run. Thanks!

Thanks, @heitorschueroff! I had two more cases in mind which I added in the recent commit:

  • Since stride has a default value of kernel_size, passing an empty tuple (for stride) has been added to the cases. None cannot be passed, since the JIT test would fail (Expected a value of type 'List[int]' for argument 'stride' but instead found type 'NoneType'.).
  • A batch size greater than 1, as you also suggested; I think it's good for completeness. I used 2, since anything larger would only add to the testing time.
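The stride fallback described in the first bullet can be sketched like this. The helper is illustrative only, not PyTorch's implementation; it mirrors avg_pool2d's documented behavior that stride defaults to kernel_size.

```python
# Illustrative only: stride defaults to kernel_size in avg_pool2d. An empty
# tuple (rather than None, which the JIT signature rejects) is falsy, so it
# triggers the documented fallback.
def resolve_stride(kernel_size, stride=()):
    return stride if stride else kernel_size
```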

cc: @zou3519 as well.

Hope this looks complete now. Thank you!

@krshrimali
Contributor Author

Gentle ping: @zou3519 and @heitorschueroff. The failures seem irrelevant to this PR, please let me know if there is anything else required here.

Thanks for your reviews and help.

@facebook-github-bot
Contributor

@heitorschueroff has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@heitorschueroff
Contributor

> Gentle ping: @zou3519 and @heitorschueroff. The failures seem irrelevant to this PR, please let me know if there is anything else required here.
>
> Thanks for your reviews and help.

LGTM! We appreciate your contribution.

@facebook-github-bot
Contributor

@heitorschueroff merged this pull request in 5dbcd56.

@ngimel
Collaborator

ngimel commented Aug 6, 2021

Reverting, causes flakiness (adaptive_avg_pool backward is nondeterministic)


Labels

cla signed · Merged · module: testing (Issues related to the torch.testing module (not tests)) · open source


7 participants