[primTorch] Implement group norm reference#87054
[primTorch] Implement group norm reference#87054rdspring1 wants to merge 27 commits intopytorch:masterfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87054
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit c362e57: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
| ) | ||
| ), | ||
| PythonRefInfo( | ||
| "_refs.nn.functional.group_norm", |
There was a problem hiding this comment.
How thorough are the existing nn.functional.group_norm sample or reference inputs to validate this consistency? Are interesting edge cases covered? Do ErrorInputs need to be added?
mruberry
left a comment
There was a problem hiding this comment.
Exciting start, @rdspring1! I made some notes for your review. Looking forward to hearing your thoughts
|
@mruberry I updated the sample inputs to check the permutations of weights and biases and added some basic error inputs. |
|
@mruberry I updated the test cases for |
|
Just a few small comments inline for your review, @rdspring1! |
|
@mruberry I addressed your comments. |
| mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] | ||
| rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] | ||
|
|
||
| mean = prims._squeeze_aten(mean, reduction_dims) |
There was a problem hiding this comment.
Why not just call prims.squeeze?
There was a problem hiding this comment.
@mruberry I'm confused why I couldn't use prims.expand_dims to replace torch.unsqueeze but I can replace torch.squeeze with prims.squeeze.
I thought it was to avoid calling the prims implementation directly. prims._squeeze_aten only translates to torch functions.
There was a problem hiding this comment.
The thinking is:
- prefer using torch.* functions
- use a ref if there isn't a torch function for what you want
- use a prim if there isn't a ref for what you want
So it's just about the order of preferences. If torch.unsqueeze is a pain to use and prims.expand_dims has the functionality you're looking for then you can go ahead and use it
mruberry
left a comment
There was a problem hiding this comment.
Cool! But see inline comment about calling prims.squeeze directly
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Add group norm reference Split from pytorch#81191 Pull Request resolved: pytorch#87054 Approved by: https://github.com/mruberry
Add group norm reference Split from pytorch#81191 Pull Request resolved: pytorch#87054 Approved by: https://github.com/mruberry
Add group norm reference
Split from #81191