Quantization aware training: Freeze batch norm support #26624
raghuramank100 wants to merge 11 commits into gh/raghuramank100/35/base
Conversation
For QAT we need to be able to control batch norm for all modules from the top. Adding helper functions to enable/disable batch norm freezing during training. Differential Revision: [D17512199](https://our.internmc.facebook.com/intern/diff/D17512199/)
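The pattern the description refers to can be sketched as follows. This is a minimal, torch-free illustration: `Module`, `ConvBn2d`, and `ConvBnReLU2d` here are toy stand-ins, not the real fused QAT modules (which live under `torch.nn.intrinsic.qat` and track a `freeze_bn`-style flag); the `apply` method mimics `nn.Module.apply`.

```python
# Minimal sketch of top-level batch norm freezing control for QAT.
# Toy stand-ins only -- not the actual PyTorch implementation.

class Module:
    def __init__(self):
        self._children = []

    def add(self, child):
        self._children.append(child)
        return child

    def apply(self, fn):
        # Mirrors nn.Module.apply: call fn on every descendant, then on self.
        for child in self._children:
            child.apply(fn)
        fn(self)
        return self

class ConvBn2d(Module):
    def __init__(self):
        super().__init__()
        self.freeze_bn = False  # when True, running stats are not updated

class ConvBnReLU2d(ConvBn2d):
    pass

def update_bn_stats(mod):
    # Re-enable batch norm statistics updates, but only on fused QAT modules.
    if type(mod) in (ConvBnReLU2d, ConvBn2d):
        mod.freeze_bn = False

def freeze_bn_stats(mod):
    # Freeze batch norm statistics, but only on fused QAT modules.
    if type(mod) in (ConvBnReLU2d, ConvBn2d):
        mod.freeze_bn = True

# Usage: flip the flag for every fused module from the top of the model.
model = Module()
fused = model.add(ConvBn2d())
model.apply(freeze_bn_stats)
```

The key point is that the caller never reaches into individual submodules; a single `apply` call from the root toggles every fused module at once.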
```python
        return super(ConvReLU2d, cls).from_float(mod, qconfig)


def update_bn_stats(mod):
    if type(mod) in set([ConvBnReLU2d, ConvBn2d]):
```
dzhulgakov left a comment:
General question, probably not for this diff - do we need freeze_bn support also in non-fused modules?
Setting bn to eval() would do that and allow frozen statistics to be used during training, but this is ugly: we want the rest of the modules to stay in train(), with bn alone in eval().
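For reference, the "ugly" alternative mentioned here would look roughly like this in plain PyTorch (a sketch, assuming an unfused model whose batch norm layers are ordinary `nn.BatchNorm2d`):

```python
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.train()  # whole model in training mode

# Put only the batch norm layers in eval mode, so their running
# statistics are frozen while every other module keeps training behavior.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
```

This works, but it leaves the model in a mixed train/eval state that any later `model.train()` call silently undoes, which is exactly why a dedicated `freeze_bn` flag on the fused modules is cleaner.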
```python
        return super(ConvReLU2d, cls).from_float(mod, qconfig)


def update_bn_stats(mod):
    if type(mod) in set([ConvBnReLU2d, ConvBn2d]):
```
You can also just do `hasattr` (in case more modules appear in the future).
Is this safer? We want to modify this field only for specific modules. Matching on an attribute would be riskier, since we would set freeze_bn for any module that happens to have this field. Leaving this as is so that we have explicit control over which modules the modification applies to.
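The trade-off being discussed can be shown with a tiny self-contained example (toy classes, not the real modules): an attribute check also matches unrelated modules that happen to define `freeze_bn`, while an explicit type check only touches the intended ones.

```python
# Toy stand-ins only -- not the real torch.nn.intrinsic.qat classes.
class ConvBn2d:
    def __init__(self):
        self.freeze_bn = False

class UserModule:
    # Unrelated module that coincidentally defines the same attribute.
    def __init__(self):
        self.freeze_bn = False

def freeze_by_type(mod):
    # Explicit allow-list: only the intended fused modules are touched.
    if type(mod) in (ConvBn2d,):
        mod.freeze_bn = True

def freeze_by_attr(mod):
    # Attribute check: matches anything that has a freeze_bn field.
    if hasattr(mod, 'freeze_bn'):
        mod.freeze_bn = True

fused, user = ConvBn2d(), UserModule()
for m in (fused, user):
    freeze_by_type(m)
type_result = (fused.freeze_bn, user.freeze_bn)  # (True, False)

for m in (fused, user):
    freeze_by_attr(m)
attr_result = (fused.freeze_bn, user.freeze_bn)  # (True, True)
```

The attribute-based version silently flips the flag on `UserModule` too, which is the risk the reply above is pointing at.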
This pull request has been merged in 84ee8ac.
Summary:
Pull Request resolved: #26624
For QAT we need to be able to control batch norm for all modules from the top. Adding helper functions to enable/disable batch norm freezing during training.
ghstack-source-id: 91008297
Test Plan: buck test caffe2/test:quantization -- --print-passing-details
Differential Revision: D17512199
fbshipit-source-id: f7b981e2b1966ab01c4dbb161030177274a998b6