Skip to content

fix type promotion for group_norm composite C++ kernel#86607

Closed
bdhirsh wants to merge 8 commits intogh/bdhirsh/321/basefrom
gh/bdhirsh/321/head
Closed

fix type promotion for group_norm composite C++ kernel#86607
bdhirsh wants to merge 8 commits intogh/bdhirsh/321/basefrom
gh/bdhirsh/321/head

Conversation

@bdhirsh
Copy link
Collaborator

@bdhirsh bdhirsh commented Oct 10, 2022

python decomp for native_group_norm is correct in more cases than the C++ composite. Updating the tests to fail properly in this case was more annoying than just fixing the C++ decomp, so I fixed it here.

When the input tensor had a dtype with less precision than float32, the C++ decomp would unconditionally set the mean/variance to float32, which was wrong.

Stack from ghstack (oldest at bottom):

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 10, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86607

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures, 1 Pending

As of commit 85d4545:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 10, 2022
python decomp for `native_group_norm` is correct in more cases than the C++ composite. Updating the tests to fail properly in this case was more annoying than just fixing the C++ decomp, so I fixed it here.

When the input tensor had a dtype with less precision than float32, the C++ decomp would unconditionally set the mean/variance to float32, which was wrong.





[ghstack-poisoned]
python decomp for `native_group_norm` is correct in more cases than the C++ composite. Updating the tests to fail properly in this case was more annoying than just fixing the C++ decomp, so I fixed it here.

When the input tensor had a dtype with less precision than float32, the C++ decomp would unconditionally set the mean/variance to float32, which was wrong.





[ghstack-poisoned]
python decomp for `native_group_norm` is correct in more cases than the C++ composite. Updating the tests to fail properly in this case was more annoying than just fixing the C++ decomp, so I fixed it here.

When the input tensor had a dtype with less precision than float32, the C++ decomp would unconditionally set the mean/variance to float32, which was wrong.





[ghstack-poisoned]
python decomp for `native_group_norm` is correct in more cases than the C++ composite. Updating the tests to fail properly in this case was more annoying than just fixing the C++ decomp, so I fixed it here.

When the input tensor had a dtype with less precision than float32, the C++ decomp would unconditionally set the mean/variance to float32, which was wrong.





[ghstack-poisoned]
python decomp for `native_group_norm` is correct in more cases than the C++ composite. Updating the tests to fail properly in this case was more annoying than just fixing the C++ decomp, so I fixed it here.

When the input tensor had a dtype with less precision than float32, the C++ decomp would unconditionally set the mean/variance to float32, which was wrong.





[ghstack-poisoned]
python decomp for `native_group_norm` is correct in more cases than the C++ composite. Updating the tests to fail properly in this case was more annoying than just fixing the C++ decomp, so I fixed it here.

When the input tensor had a dtype with less precision than float32, the C++ decomp would unconditionally set the mean/variance to float32, which was wrong.





[ghstack-poisoned]
python decomp for `native_group_norm` is correct in more cases than the C++ composite. Updating the tests to fail properly in this case was more annoying than just fixing the C++ decomp, so I fixed it here.

When the input tensor had a dtype with less precision than float32, the C++ decomp would unconditionally set the mean/variance to float32, which was wrong.





[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants