
add channel last 3d support for batch_norm on CPU#97774

Closed
CaoE wants to merge 18 commits into gh/CaoE/8/base from gh/CaoE/8/head

Conversation

@CaoE (Collaborator) commented Mar 28, 2023

@github-actions github-actions bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Mar 28, 2023
@pytorch-bot pytorch-bot bot added the release notes: nn release notes category label Mar 28, 2023
pytorch-bot bot commented Mar 28, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/97774

Note: Links to docs will display an error until the docs builds have been completed.

✅ 1 Unrelated Failure

As of commit d63212c:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@CaoE CaoE marked this pull request as draft March 28, 2023 13:48
@CaoE CaoE added the topic: not user facing topic category label Mar 29, 2023
CaoE added 3 commits March 28, 2023 20:43
cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
@CaoE CaoE requested a review from mingfeima April 7, 2023 02:12
@CaoE CaoE added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 7, 2023
@mingfeima (Collaborator) left a comment


Generally LGTM, just need to update the test cases a little bit.

test/test_nn.py Outdated

# test NC11 and N1HW; test mixed dtype
for shape in [(4, 8, 10, 10), (4, 1, 9, 9), (4, 9, 1, 1)]:
for shape in [(4, 8, 10, 10), (4, 1, 9, 9), (4, 9, 1, 1), (4, 8, 2, 10, 10), (4, 1, 2, 9, 9), (4, 9, 1, 1, 1)]:
Collaborator

Adding a mod argument is neater:

helper(self, nn.BatchNorm2d, shape, ...)
helper(self, nn.BatchNorm3d, shape, ...)

Collaborator Author

Added a mod argument.
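The refactor discussed above can be sketched roughly as follows. This is a torch-free illustration only: StubBatchNorm is a hypothetical stand-in for nn.BatchNorm2d / nn.BatchNorm3d, and the real helper in test/test_nn.py also takes dtype, mixed_dtype, and format arguments.

```python
# Illustrative sketch only; StubBatchNorm is a hypothetical stand-in for
# nn.BatchNorm2d / nn.BatchNorm3d so a single helper covers 4-D and 5-D shapes.
class StubBatchNorm:
    def __init__(self, num_features):
        self.num_features = num_features

def helper(mod_cls, shape):
    # shape[1] is the channel dimension in both NCHW and NCDHW layouts
    return mod_cls(shape[1])

# One loop now exercises both the 2-D and 3-D shapes, including the
# degenerate NC11 / N1HW cases from the updated test.
for shape in [(4, 8, 10, 10), (4, 1, 9, 9), (4, 9, 1, 1),
              (4, 8, 2, 10, 10), (4, 1, 2, 9, 9), (4, 9, 1, 1, 1)]:
    mod = helper(StubBatchNorm, shape)
    assert mod.num_features == shape[1]
```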

@CaoE CaoE changed the title fix channel last 3d support for batch_norm fix channel last 3d support for batch_norm on CPU Apr 13, 2023
@CaoE CaoE requested a review from jgong5 April 13, 2023 06:39
@CaoE CaoE marked this pull request as ready for review May 9, 2023 07:53
@CaoE CaoE added the ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR label May 10, 2023
CaoE added a commit to CaoE/pytorch that referenced this pull request May 10, 2023
CaoE added a commit to CaoE/pytorch that referenced this pull request May 15, 2023
@CaoE (Collaborator Author) commented Jul 3, 2023

@ngimel Could you please review this PR ? Thank you.

@CaoE (Collaborator Author) commented Jul 3, 2023

@malfet Could you please review this PR ? Thank you.

@CaoE CaoE requested a review from cpuhrsch July 4, 2023 03:26
@CaoE (Collaborator Author) commented Jul 4, 2023

@cpuhrsch Could you please review this PR ? Thank you.

@cpuhrsch (Contributor) commented:

@CaoE - Which PR introduced the bug that this is fixing?

@CaoE CaoE changed the title fix channel last 3d support for batch_norm on CPU Add channel last 3d support for batch_norm on CPU Jul 11, 2023
@CaoE (Collaborator Author) commented Jul 11, 2023

@CaoE - Which PR introduced the bug that this is fixing?

@cpuhrsch This PR does not fix a specific PR. The current BN kernel can actually support ChannelsLast3d, but ChannelsLast3d was not used because the dispatch condition did not include it. The title of this PR may have been a bit confusing; I have modified it.
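For reference, ChannelsLast3d keeps the logical N,C,D,H,W shape but stores elements in NDHWC order (channels innermost), which is what the dispatch condition needs to recognize. A small torch-free sketch of the resulting strides:

```python
def channels_last_3d_strides(shape):
    """Strides of an N,C,D,H,W tensor stored in NDHWC (channels-last-3d)
    memory order: the channel dimension gets stride 1."""
    n, c, d, h, w = shape
    return (c * d * h * w, 1, c * h * w, c * w, c)

# e.g. one of the 5-D shapes from the updated test:
print(channels_last_3d_strides((4, 8, 2, 10, 10)))  # (1600, 1, 800, 80, 8)
```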

@CaoE CaoE changed the title Add channel last 3d support for batch_norm on CPU add channel last 3d support for batch_norm on CPU Jul 11, 2023
@cpuhrsch (Contributor) commented:

Adding @mikaylagawarecki since this also touches torch.nn.

@mikaylagawarecki - I think this looks good, but maybe you can double check the tests are sufficient?


def test_batchnorm_nhwc_cpu(self):
def helper(self, size, dtype, mixed_dtype=False):
def helper(self, mod, size, dtype, mixed_dtype=False, format=torch.channels_last):
@mikaylagawarecki (Contributor) commented Jul 13, 2023

We actually have a test, test_memory_format, that is run on modules with ModuleInfo registrations in module_db when the corresponding arg is set in the ModuleInfo.

My understanding is that the effect it achieves is similar to that of this test except for

  1. the functionality of the mixed_dtype arg
  2. the checks on backward

but do correct me if I have missed anything here!

These tests should be run for channels_last_3d if there is a sample input that is 5D. But curiously they do not seem to catch that BatchNorm3d (which has 5D sample inputs) did not support channels_last_3d prior to this PR (I don't see the test being xfailed in the ModuleInfo).

Do you happen to know why this is the case?

Separately, if this is indeed replicated logic, it would be nice if it could be unified to extend test_memory_format for modularity, but feel free to disagree if I have missed anything!

@CaoE (Collaborator Author) commented Jul 14, 2023

the effect it achieves is similar to that of this test except for

  1. the functionality of the mixed_dtype arg
  2. the checks on backward

Yes.

BatchNorm3d (which has 5d sample inputs) does not seem to have supported channels_last_3d prior to this PR (I don't see the test being xfailed in the ModuleInfo)

Actually, BatchNorm3d could not use the channels_last_3d kernel prior to this PR, but it did support channels_last_3d inputs through another code path via TensorIterator: https://github.com/pytorch/pytorch/pull/97774/files#diff-9e36531b6ea57776251e04cf2e1b84a94145020d870fed3d7f82388c47371483L177.
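The situation can be illustrated with a hypothetical sketch of the dispatch decision (names are illustrative, not the actual ATen code): before this PR only channels_last inputs reached the dedicated fast kernel, while channels_last_3d inputs fell through to the generic TensorIterator path.

```python
def pick_batch_norm_path(memory_format, include_cl3d=True):
    # Hypothetical dispatch sketch; not the actual ATen condition.
    # include_cl3d models whether the condition recognizes channels_last_3d,
    # which is what this PR adds.
    fast_formats = {"channels_last"}
    if include_cl3d:
        fast_formats.add("channels_last_3d")
    return "channels_last_kernel" if memory_format in fast_formats else "tensor_iterator"

print(pick_batch_norm_path("channels_last_3d", include_cl3d=False))  # tensor_iterator
print(pick_batch_norm_path("channels_last_3d", include_cl3d=True))   # channels_last_kernel
```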

@mikaylagawarecki (Contributor) commented Jul 14, 2023

Ah I see, so this PR allows BatchNorm3d to use a fast kernel for channels_last_3d inputs without transforming the inputs, is that correct?

Would you be open to updating test_memory_format? I don't want to block this PR if the change is urgent, but it would be some super nice cleanup, since it seems you are updating other functions to support channels_last_3d as well, and it could catch cases where backward kernels do not properly support channels_last_3d :)

Collaborator Author

Yes, I'm willing to add checks on backward for test_memory_format.
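In spirit, a backward check of the kind discussed would verify that gradient strides match the input's memory format. A torch-free sketch of such a predicate (a hypothetical helper, not the actual test_memory_format code):

```python
def matches_channels_last_3d(shape, strides):
    # True if (shape, strides) describe NDHWC storage of an N,C,D,H,W tensor.
    # (A real check would also need to handle size-1 dims, which make strides
    # ambiguous; that subtlety is skipped here.)
    n, c, d, h, w = shape
    return tuple(strides) == (c * d * h * w, 1, c * h * w, c * w, c)

shape = (4, 8, 2, 10, 10)
print(matches_channels_last_3d(shape, (1600, 1, 800, 80, 8)))   # True: NDHWC grad
print(matches_channels_last_3d(shape, (1600, 200, 100, 10, 1))) # False: contiguous NCDHW grad
```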

@CaoE (Collaborator Author) commented Aug 1, 2023

@mikaylagawarecki I submitted a PR to add a backward check to test_memory_format: #106104. Can we land the PRs of this stack first?

@mikaylagawarecki (Contributor) left a comment

@CaoE awesome thank you! I will review the other PR. We can land this stack first, stamping to unblock

@CaoE (Collaborator Author) commented Aug 3, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


Labels

ciflow/periodic (Trigger jobs ran periodically on master (periodic.yml) on the PR)
ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
module: cpu (CPU specific problem (e.g., perf, algorithm))
open source
release notes: nn (release notes category)
topic: not user facing (topic category)


7 participants