
Quantization aware training: Freeze batch norm support#26624

Closed
raghuramank100 wants to merge 11 commits into gh/raghuramank100/35/base from gh/raghuramank100/35/head

Conversation

@raghuramank100
Contributor

@raghuramank100 raghuramank100 commented Sep 22, 2019

Stack from ghstack:

For QAT, we need to be able to control batch norm for all modules from the top. This adds helper functions to enable/disable freezing of batch norm statistics during training.

Differential Revision: D17512199
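The helper-function pattern described above can be sketched as follows. This is a minimal, torch-free illustration: `ConvBn2d`, `ConvBnReLU2d`, `Model.apply`, and `freeze_bn_stats` here are hypothetical stand-ins for the real fused QAT modules and `nn.Module.apply`, not the actual PyTorch API.

```python
# Sketch (hypothetical stand-ins, no torch dependency) of controlling BN
# freezing "from the top": a top-level helper applied recursively flips a
# freeze_bn flag, but only on the fused QAT modules.

class ConvBn2d:                      # stand-in for the fused QAT module
    def __init__(self):
        self.freeze_bn = False       # when True, running stats are frozen

class ConvBnReLU2d(ConvBn2d):
    pass

class Model:
    def __init__(self):
        self.children = [ConvBn2d(), ConvBnReLU2d(), object()]
    def apply(self, fn):             # mimics nn.Module.apply
        for m in self.children:
            fn(m)
        return self

def freeze_bn_stats(mod):
    # Only touch the module types we explicitly support.
    if type(mod) in (ConvBn2d, ConvBnReLU2d):
        mod.freeze_bn = True

model = Model().apply(freeze_bn_stats)
print([getattr(m, "freeze_bn", None) for m in model.children])
# → [True, True, None]
```

The unrelated third child is left untouched, which is the point of keeping the check type-specific.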

        return super(ConvReLU2d, cls).from_float(mod, qconfig)

def update_bn_stats(mod):
    if type(mod) in set([ConvBnReLU2d,ConvBn2d]):
formatting

@dzhulgakov dzhulgakov requested a review from gchanan September 24, 2019 05:29

@dzhulgakov dzhulgakov left a comment


General question, probably not for this diff - do we need freeze_bn support also in non-fused modules?

@raghuramank100 raghuramank100 added this to the 1.3 milestone Sep 27, 2019
@raghuramank100
Contributor Author

General question, probably not for this diff - do we need freeze_bn support also in non-fused modules?

Setting bn to eval() would do that and allow frozen statistics to be used during training, but it is ugly: we want the rest of the modules to be in train(), with bn alone in eval().
I think we need this supported in bn itself so that we have an independent way to control it outside of .train() and .eval(). An example of this need is the frozen batch norm module in torchvision: https://github.com/pytorch/vision/blob/master/torchvision/ops/misc.py#L135
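The "ugly" workaround described above can be sketched like this. The classes here are tiny stand-ins for a torch-like API (`Module.train`, `named_modules`), purely for illustration; they are not the real torch.nn classes.

```python
# Sketch of the workaround discussed above, with hypothetical stand-ins for
# the torch API: the whole model is in training mode, but BatchNorm alone is
# switched to eval so its frozen running statistics are used during training.

class Module:
    def __init__(self):
        self.training = True
        self._children = {}
    def add(self, name, child):
        self._children[name] = child
    def train(self, mode=True):      # recursively sets training mode
        self.training = mode
        for c in self._children.values():
            c.train(mode)
        return self
    def named_modules(self):
        yield "", self
        for name, c in self._children.items():
            for sub, m in c.named_modules():
                yield (name + "." + sub).rstrip("."), m

class Conv2d(Module): pass
class BatchNorm2d(Module): pass

model = Module()
model.add("conv", Conv2d())
model.add("bn", BatchNorm2d())

model.train()                        # everything in training mode...
for _, m in model.named_modules():
    if isinstance(m, BatchNorm2d):
        m.train(False)               # ...except BN, forced into eval

print([(n, m.training) for n, m in model.named_modules() if n])
# → [('conv', True), ('bn', False)]
```

The awkward part is exactly what the comment says: the model's mode is no longer uniform, and any later call to model.train() silently undoes the freeze, which is why an independent freeze_bn control is preferable.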

        return super(ConvReLU2d, cls).from_float(mod, qconfig)

def update_bn_stats(mod):
    if type(mod) in set([ConvBnReLU2d, ConvBn2d]):
Collaborator

you can also just do 'hasattr' (in case more modules appear in the future)

Contributor Author

Is this safer? We want to modify this field only for specific modules. Tagging based on a specific attribute could be theoretically more risky as we would set freeze_bn to false for any module that has this field. Leaving this as is so that we have explicit control over which modules we apply this modification to.
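The trade-off in this exchange can be illustrated with a small sketch (hypothetical stand-in classes, not the real modules): an hasattr check would also flip freeze_bn on any unrelated module that happens to carry an attribute of that name, while an explicit type allowlist only touches the intended modules.

```python
# Contrast of the two dispatch strategies discussed above, using stand-in
# classes: explicit type allowlist vs. attribute-based (hasattr) check.

class ConvBn2d:
    def __init__(self):
        self.freeze_bn = False

class UnrelatedModule:
    def __init__(self):
        self.freeze_bn = False   # same attribute name, unrelated meaning

def freeze_by_type(mod):
    # Only modules we explicitly listed are modified.
    if type(mod) in (ConvBn2d,):
        mod.freeze_bn = True

def freeze_by_attr(mod):
    # Any module with a freeze_bn attribute is modified, intended or not.
    if hasattr(mod, "freeze_bn"):
        mod.freeze_bn = True

a, b = ConvBn2d(), UnrelatedModule()
freeze_by_type(a); freeze_by_type(b)
print(a.freeze_bn, b.freeze_bn)    # → True False

a, b = ConvBn2d(), UnrelatedModule()
freeze_by_attr(a); freeze_by_attr(b)
print(a.freeze_bn, b.freeze_bn)    # → True True
```

The hasattr variant is more future-proof for new fused modules, as the reviewer notes, at the cost of the accidental matches shown on the last line.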

@facebook-github-bot
Contributor

This pull request has been merged in 84ee8ac.

jamesr66a pushed a commit that referenced this pull request Oct 3, 2019
Summary:
Pull Request resolved: #26624

For QAT we need to be able to control batch norm for all modules from the top. Adding helper functions to enable/disable batch norm freezing during training
ghstack-source-id: 91008297

Test Plan: buck test caffe2/test:quantization -- --print-passing-details

Differential Revision: D17512199

fbshipit-source-id: f7b981e2b1966ab01c4dbb161030177274a998b6
jamesr66a pushed a commit that referenced this pull request Oct 3, 2019
jamesr66a pushed a commit that referenced this pull request Oct 4, 2019
jamesr66a pushed a commit that referenced this pull request Oct 4, 2019
soumith pushed a commit that referenced this pull request Oct 7, 2019
@facebook-github-bot facebook-github-bot deleted the gh/raghuramank100/35/head branch October 28, 2019 22:19
pdlive215 pushed a commit to pdlive215/pytorch that referenced this pull request Nov 27, 2019
xxtEchjovs44 pushed a commit to xxtEchjovs44/pytorch that referenced this pull request Jan 29, 2020

Labels

Merged · module: nn (Related to torch.nn)


7 participants