add channel last 3d support for batch_norm on CPU #97774
CaoE wants to merge 18 commits into gh/CaoE/8/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/97774
Note: Links to docs will display an error until the docs builds have been completed.
✅ 1 Unrelated Failure. As of commit d63212c: UNSTABLE. The following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 [ghstack-poisoned]
mingfeima left a comment:
Generally LGTM, the test cases just need a small update.
test/test_nn.py (Outdated)

  # test NC11 and N1HW; test mixed dtype
- for shape in [(4, 8, 10, 10), (4, 1, 9, 9), (4, 9, 1, 1)]:
+ for shape in [(4, 8, 10, 10), (4, 1, 9, 9), (4, 9, 1, 1), (4, 8, 2, 10, 10), (4, 1, 2, 9, 9), (4, 9, 1, 1, 1)]:
Adding a mod argument would be neater:
helper(self, nn.BatchNorm2d, shape, ...)
helper(self, nn.BatchNorm3d, shape, ...)
Added a mod argument.
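The refactored helper could look roughly like this. This is a hedged sketch, not the actual test in test/test_nn.py: the real helper is a method on the test class and also exercises mixed dtypes; here mixed_dtype is accepted but not exercised, and the body only checks output/gradient parity between the channels-last path and the contiguous reference path.

```python
import torch
import torch.nn as nn

def helper(mod, size, dtype, mixed_dtype=False, format=torch.channels_last):
    # Compare the channels-last path against the contiguous reference path.
    # mixed_dtype handling is elided in this sketch.
    channels = size[1]
    x = torch.randn(size, dtype=dtype).to(memory_format=format).requires_grad_(True)
    bn = mod(channels).to(dtype=dtype)

    # Reference module with identical parameters, fed a contiguous copy.
    ref_x = x.detach().clone().contiguous().requires_grad_(True)
    ref_bn = mod(channels).to(dtype=dtype)
    ref_bn.load_state_dict(bn.state_dict())

    y = bn(x)
    ref_y = ref_bn(ref_x)
    y.sum().backward()
    ref_y.sum().backward()

    # The channels-last kernel must preserve the layout and agree with
    # the contiguous kernel in both forward and backward.
    assert y.is_contiguous(memory_format=format)
    assert torch.allclose(y, ref_y, atol=1e-5)
    assert torch.allclose(x.grad, ref_x.grad, atol=1e-5)
    return True

# 4D shapes use BatchNorm2d/channels_last; 5D shapes use BatchNorm3d/channels_last_3d
helper(nn.BatchNorm2d, (4, 8, 10, 10), torch.float32)
helper(nn.BatchNorm3d, (4, 8, 2, 10, 10), torch.float32, format=torch.channels_last_3d)
```

Passing the module class as an argument lets the same body cover both the 4D/channels_last and 5D/channels_last_3d cases instead of duplicating the loop.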
ghstack-source-id: 7d858c7 Pull Request resolved: pytorch#97774
ghstack-source-id: 248cb00 Pull Request resolved: pytorch#97774
@ngimel Could you please review this PR? Thank you.
@malfet Could you please review this PR? Thank you.
@cpuhrsch Could you please review this PR? Thank you.
@CaoE - Which PR introduced the bug that this is fixing?
@cpuhrsch This PR is not fixing a specific PR. The current BN kernel can actually support …
Adding @mikaylagawarecki since this also touches torch.nn. @mikaylagawarecki - I think this looks good, but maybe you can double check that the tests are sufficient?
  def test_batchnorm_nhwc_cpu(self):
- def helper(self, size, dtype, mixed_dtype=False):
+ def helper(self, mod, size, dtype, mixed_dtype=False, format=torch.channels_last):
We actually have a test, test_memory_format, that is run on modules with ModuleInfo registrations in module_db when the corresponding arg is set in the ModuleInfo.
My understanding is that the effect it achieves is similar to that of this test, except for:
- the functionality of the mixed_dtype arg
- the checks on backward
but do correct me if I have missed anything here!
These tests should be run for channels_last_3d if there is a sample input that is 5D. But curiously they do not seem to be catching that BatchNorm3d (which has 5D sample inputs) did not support channels_last_3d prior to this PR (I don't see the test being xfailed in the ModuleInfo). Do you happen to know why this is the case?
Separately, if this is indeed replicated logic, it would be nice if it could be unified by extending test_memory_format for modularity, but feel free to disagree if I have missed anything!
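For illustration, a generic check in the spirit of test_memory_format might look like the sketch below. This is an assumption about shape, not the actual ModuleInfo-driven test; the function name check_memory_format and the module/format pairings are hypothetical.

```python
import torch
import torch.nn as nn

def check_memory_format(mod_cls, shape, memory_format):
    # Feed an input in the given memory format and verify that the
    # module's output preserves that format and matches the output
    # for a contiguous input.
    m = mod_cls(shape[1]).eval()  # eval: deterministic running-stat path
    x = torch.randn(shape)
    y_ref = m(x)  # contiguous reference
    y = m(x.to(memory_format=memory_format))
    assert y.is_contiguous(memory_format=memory_format)
    assert torch.allclose(y, y_ref, atol=1e-6)
    return True

# 4D input -> channels_last, 5D input -> channels_last_3d
check_memory_format(nn.BatchNorm2d, (2, 8, 4, 4), torch.channels_last)
check_memory_format(nn.BatchNorm3d, (2, 8, 2, 4, 4), torch.channels_last_3d)
```

A driver like this, iterating over registered modules and their sample inputs, is the kind of logic the two tests would share if they were unified.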
> the effect it achieves is similar to that of this test except for
> - the functionality of the mixed_dtype arg
> - the checks on backward

Yes.

> BatchNorm3d (which has 5d sample inputs) does not seem to have supported channels_last_3d prior to this PR (I don't see the test being xfailed in the ModuleInfo)

Actually, BatchNorm3d could not use the channels_last_3d kernel prior to this PR, but it did have channels_last_3d support through another code path using TensorIterator: https://github.com/pytorch/pytorch/pull/97774/files#diff-9e36531b6ea57776251e04cf2e1b84a94145020d870fed3d7f82388c47371483L177.
Ah I see, so this PR allows BatchNorm3d to use a fast kernel for channels_last_3d inputs without transforming the inputs, is that correct?
Would you be open to updating test_memory_format? I don't want to block this PR if the change is urgent, but it would be some super nice cleanup, since it seems like you are updating other functions to support channels_last_3d as well, and it could catch cases where backward kernels do not properly support channels_last_3d :)
Yes, I'm willing to add checks on backward to test_memory_format.
@mikaylagawarecki I submitted a PR adding a backward check to test_memory_format: #106104. Can we land the PRs of this stack first?
@CaoE awesome, thank you! I will review the other PR. We can land this stack first; stamping to unblock.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10