
[primTorch] Implement group norm reference #87054

Closed
rdspring1 wants to merge 27 commits into pytorch:master from rdspring1:ref_group_norm

Conversation

@rdspring1
Contributor

Add group norm reference
Split from #81191
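
For anyone skimming: a primTorch reference expresses an ATen op in terms of simpler torch operations. Below is a minimal sketch of the usual group norm decomposition (reshape into groups, normalize, apply the per-channel affine). It is illustrative only, not this PR's code, and the name group_norm_sketch is made up:

    import torch

    def group_norm_sketch(input, num_groups, weight=None, bias=None, eps=1e-5):
        # Fold each group of channels into a single reduction row:
        # (N, C, *) -> (N, num_groups, -1).
        batch, channels = input.shape[0], input.shape[1]
        x = input.reshape(batch, num_groups, -1)

        # Normalize within each group; group norm uses the biased variance.
        var, mean = torch.var_mean(x, dim=-1, unbiased=False, keepdim=True)
        x = (x - mean) * torch.rsqrt(var + eps)
        out = x.reshape(input.shape)

        # Per-channel affine parameters broadcast over the trailing dims.
        stat_shape = (1, channels) + (1,) * (input.dim() - 2)
        if weight is not None:
            out = out * weight.reshape(stat_shape)
        if bias is not None:
            out = out + bias.reshape(stat_shape)
        return out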

@pytorch-bot

pytorch-bot Bot commented Oct 17, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87054

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c362e57:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

rdspring1 marked this pull request as ready for review on October 17, 2022 06:18
Comment thread torch/_refs/__init__.py Outdated
Comment thread torch/_refs/__init__.py Outdated
Comment thread torch/_refs/nn/functional/__init__.py Outdated
    )
    ),
    PythonRefInfo(
        "_refs.nn.functional.group_norm",
Collaborator


How thorough are the existing nn.functional.group_norm sample or reference inputs to validate this consistency? Are interesting edge cases covered? Do ErrorInputs need to be added?

Collaborator

@mruberry mruberry left a comment


Exciting start, @rdspring1! I made some notes for your review. Looking forward to hearing your thoughts

@mikaylagawarecki added the "triaged" label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Oct 18, 2022
@rdspring1
Contributor Author

@mruberry I updated the sample inputs to check the permutations of weights and biases and added some basic error inputs.
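
For reference, OpInfo error cases are written as ErrorInput entries wrapping a SampleInput together with an expected error type and message pattern. Here is a sketch of what basic group_norm error inputs could look like; the import path, the error_regex strings, and the specific cases shown are assumptions, not necessarily what this PR adds:

    import torch
    from torch.testing._internal.common_methods_invocations import (
        ErrorInput,
        SampleInput,
    )

    def error_inputs_group_norm_sketch(op_info, device, **kwargs):
        make_arg = lambda shape: torch.empty(shape, device=device)

        # A channel count that is not divisible by num_groups should raise.
        yield ErrorInput(
            SampleInput(make_arg((2, 5, 4)), args=(3,)),
            error_type=RuntimeError,
            error_regex="divisible",
        )

        # A weight whose length does not match the channel count should raise.
        yield ErrorInput(
            SampleInput(make_arg((2, 6, 4)), args=(3,),
                        kwargs={"weight": make_arg((7,))}),
            error_type=RuntimeError,
            error_regex="weight",
        )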

Comment thread torch/_prims/__init__.py Outdated
Comment thread torch/_prims/__init__.py Outdated
Comment thread torch/_refs/__init__.py Outdated
Comment thread torch/testing/_internal/common_methods_invocations.py Outdated
Comment thread torch/nn/functional.py
Comment thread torch/testing/_internal/common_methods_invocations.py Outdated
@rdspring1
Contributor Author

@mruberry I updated the test cases for group_norm.

Comment thread torch/_refs/__init__.py Outdated
Comment thread torch/_refs/__init__.py Outdated
Comment thread torch/_refs/nn/functional/__init__.py Outdated
Comment thread torch/_refs/nn/functional/__init__.py
Comment thread torch/testing/_internal/common_methods_invocations.py Outdated
Comment thread torch/testing/_internal/common_methods_invocations.py Outdated
@mruberry
Collaborator

mruberry commented Nov 7, 2022

Just a few small comments inline for your review, @rdspring1!

@rdspring1
Contributor Author

@mruberry I addressed your comments.

Comment thread torch/_refs/__init__.py Outdated
    mean = _maybe_convert_to_dtype(mean, input.dtype)  # type: ignore[assignment]
    rstd = _maybe_convert_to_dtype(rstd, input.dtype)  # type: ignore[assignment]

    mean = prims._squeeze_aten(mean, reduction_dims)
Collaborator


Why not just call prims.squeeze?

Contributor Author


@mruberry I'm confused: earlier I couldn't use prims.expand_dims to replace torch.unsqueeze, but here I can replace torch.squeeze with prims.squeeze?

I thought the guidance was to avoid calling the prims implementations directly; prims._squeeze_aten only translates to torch functions.

Collaborator


The thinking is:

  • prefer using torch.* functions
  • use a ref if there isn't a torch function for what you want
  • use a prim if there isn't a ref for what you want

So it's just about the order of preference. If torch.unsqueeze is a pain to use and prims.expand_dims has the functionality you're looking for, then you can go ahead and use it.
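
To make that ladder concrete with the stats from the snippet above, here is a toy illustration (not this PR's code; squash_stats is a made-up helper):

    import torch
    import torch._prims as prims

    def squash_stats(mean, rstd, reduction_dims):
        # 1) Prefer the public torch.* op. torch.squeeze drops one dim at a
        #    time, so squeezing several reduction dims takes a loop.
        for dim in sorted(reduction_dims, reverse=True):
            mean = torch.squeeze(mean, dim)

        # 2) If no torch.* op fits, reach for a ref; 3) only if no ref fits
        #    either, fall back to a prim. prims.squeeze accepts all the dims
        #    at once, which is what makes it attractive here.
        rstd = prims.squeeze(rstd, reduction_dims)
        return mean, rstd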

Collaborator

@mruberry mruberry left a comment


Cool! But see the inline comment about calling prims.squeeze directly.

@rdspring1
Contributor Author

@pytorchbot merge

pytorch-bot added the "ciflow/trunk" label (trigger trunk jobs on your pull request) on Nov 10, 2022
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Add group norm reference
Split from pytorch#81191
Pull Request resolved: pytorch#87054
Approved by: https://github.com/mruberry
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
Add group norm reference
Split from pytorch#81191
Pull Request resolved: pytorch#87054
Approved by: https://github.com/mruberry