Motivation
A model that uses SyncBN does not support CPU inference, and SyncBN can cause other issues during inference as well. We introduce revert_sync_batchnorm from @kapily's work (pytorch/pytorch#41081 (comment)), which can convert the SyncBN layers in any model to BN.
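The conversion can be sketched as follows. This is a minimal adaptation of @kapily's recipe, not the exact mmcv code: each SyncBatchNorm is replaced by a dimension-agnostic BatchNorm that copies over its parameters and running statistics.

```python
import torch
from torch import nn


class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
    """BatchNorm without the per-dimension input check, since
    SyncBatchNorm does not record the original BatchNormXd type."""

    def _check_input_dim(self, input):
        return


def revert_sync_batchnorm(module):
    """Recursively replace all SyncBatchNorm layers in `module` with
    plain BatchNorm layers, copying parameters and running stats."""
    module_output = module
    if isinstance(module, nn.SyncBatchNorm):
        module_output = _BatchNormXd(module.num_features, module.eps,
                                     module.momentum, module.affine,
                                     module.track_running_stats)
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, revert_sync_batchnorm(child))
    return module_output


# A SyncBN model cannot run on CPU, but the reverted one can:
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
model = revert_sync_batchnorm(model)
out = model(torch.randn(2, 3, 8, 8))
```

The recursion rebuilds the module tree in place, so nested models (backbones inside detectors, etc.) are handled without any knowledge of their structure.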
Modification
Added revert_sync_batchnorm to mmcv/cnn/utils/sync_bn.py and its unittest.
BC-breaking (Optional)
No, but there could be a potential minor risk -
PyTorch provides convert_sync_batchnorm, which converts BatchNorm1d, BatchNorm2d, and BatchNorm3d to SyncBatchNorm, but it does not provide an inverse function. The reason is that SyncBatchNorm neither strictly checks the input dimension nor stores the expected one, whereas BatchNormXd strictly validates the input dimension (and this check is the only difference between BatchNorm1d, 2d, and 3d). Therefore, if one converts BNXd to SyncBN using PyTorch's implementation and then converts it back to BN using this implementation, the input dimension check is no longer retained.
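A small demonstration of the risk, assuming the reverted layer behaves like a dimension-agnostic _BatchNorm (as in @kapily's recipe): BatchNorm1d rejects 4D input, while the reverted layer silently accepts it.

```python
import torch
from torch import nn


class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
    # Stand-in for a reverted layer: overriding _check_input_dim with a
    # no-op removes the only behavioral difference between BatchNorm1d,
    # 2d, and 3d.
    def _check_input_dim(self, input):
        return


strict_bn = nn.BatchNorm1d(4)
loose_bn = _BatchNormXd(4)

# BatchNorm1d validates the input dimension and rejects 4D input ...
try:
    strict_bn(torch.randn(2, 4, 8, 8))
    strict_raised = False
except ValueError:
    strict_raised = True

# ... while the reverted layer accepts it without complaint.
out = loose_bn(torch.randn(2, 4, 8, 8))
```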
Merging #1253 (35df788) into master (9341856) will increase coverage by 0.69%.
The diff coverage is 63.05%.
❗ Current head 35df788 differs from pull request most recent head 7f85e59. Consider uploading reports for the commit 7f85e59 to get more accurate results.
@ZwwWayne A model with MMSyncBN layers cannot be built in a non-distributed environment. I think the only solution is to regenerate a Config replacing MMSyncBN with BN and use it to build the model, but that approach would be very different from the current implementation. I'm also not sure whether the checkpoint is still compatible after such a conversion. BTW, what was the goal of MMSyncBN?
It was written a long time ago, when PyTorch did not provide SyncBN. It also supports some corner cases that are currently not supported in PyTorch.
I suggest implementing a function that simply modifies the config to change SyncBN/MMSyncBN back to BN. It works for both SyncBN and MMSyncBN and meets the demands of this PR. It is a little late to convert the model back after the SyncBN layers have already been built.
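A config-level conversion along these lines could look like the following sketch. The helper name is hypothetical, and it assumes norm layers follow the usual OpenMMLab dict(type='SyncBN', ...) convention.

```python
def revert_sync_bn_in_cfg(cfg):
    """Recursively replace SyncBN/MMSyncBN with BN in a nested config.

    Hypothetical helper: it works on plain dicts/lists/tuples, so it
    runs before the model (and any SyncBN layer) is ever built.
    """
    if isinstance(cfg, dict):
        if cfg.get('type') in ('SyncBN', 'MMSyncBN'):
            cfg = {**cfg, 'type': 'BN'}
        return {k: revert_sync_bn_in_cfg(v) for k, v in cfg.items()}
    if isinstance(cfg, (list, tuple)):
        return type(cfg)(revert_sync_bn_in_cfg(v) for v in cfg)
    return cfg


cfg = dict(
    model=dict(
        backbone=dict(norm_cfg=dict(type='SyncBN', requires_grad=True)),
        neck=dict(norm_cfg=dict(type='MMSyncBN')),
    ))
plain = revert_sync_bn_in_cfg(cfg)
```

Because the replacement happens on the config, the model is built with ordinary BN from the start, and both SyncBN and MMSyncBN are covered by the same code path.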