Fix SpectralNorm with DataParallel #12671
Conversation
Autograd in eval mode still has problems, but I decided to fix that in a later PR due to BC complications.
@ssnl I read through this and it seems solid! 👍
Looks all reasonable to me, but I lack the Distributed expertise for that to mean much. Out of curiosity: what is the BC break when you recompute the weight in eval mode instead of detaching?
@crcrpar @t-vi Thanks for looking! @t-vi The problem with recomputing in eval mode is that we only store …
facebook-github-bot left a comment:
SsnL is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
I updated my code with the latest spectral_norm implementation (I just replaced spectral_norm.py and function.py), but I got the following error: `RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation`. The error disappears if I switch back to the old spectral_norm implementation. @ssnl
@YaoshengFu I can't reproduce the error you see. Could you install the nightly and check if the error still happens?
I have re-installed the latest version of pytorch from source and it still has the same error. I have tried it on different projects and didn't see a difference. For example, you can try running the code from this repo: https://github.com/rosinality/sagan-pytorch. Just replace the spectral_norm implementation in model.py or model_resnet.py (which seems to be copied from the previous official implementation) with the official one and run it; you should see the same error (at least I did).
@YaoshengFu Thanks, I will!
Summary: Problems with SN and DP after #12671:

1. In eval mode, `weight_orig` is not getting the correct gradient (#12737). Fix: keep the `v` vector around as a buffer and always calculate `W = W_orig / (u @ W_orig @ v)`, even in eval.
2. In training mode, the `weight` buffer of the parallelized module is never updated if someone touches `weight_orig` and/or `weight` and makes them stop sharing storage, so in eval the weight used is wrong. Fix: make `weight` not a buffer anymore and always calculate it as above.
3. #12671 changed SN to update `u` in-place to make DP work correctly, but that breaks backward through two forwards (e.g., the common GAN loss `D(real) - D(fake)`) because the vectors needed to backprop the first forward are changed by the second forward. Fix: this PR clones `u` and `v` before using them.

To maintain BC, I added a hook interface for producing and loading state_dicts. This is ugly and we should really have a better interface for spectral_norm, but for the purpose of fixing this issue I make this patch. Even with a better interface, a BC mechanism for loading legacy state_dicts would still be needed.

cc crcrpar
Pull Request resolved: #13350
Differential Revision: D12931044
Pulled By: SsnL
fbshipit-source-id: 8be6f934eaa62414d76d2c644dedd7e1b7eb31ef
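For illustration, a minimal sketch of the recompute-always scheme described above (a hypothetical standalone helper, not the actual torch.nn.utils.spectral_norm code):

```python
import torch
import torch.nn.functional as F

def sn_weight(weight_orig, u, v, n_power_iterations=1, eps=1e-12):
    # View the weight as a 2-D matrix, as spectral_norm does for conv weights.
    w_mat = weight_orig.reshape(weight_orig.size(0), -1)
    with torch.no_grad():
        for _ in range(n_power_iterations):
            # Power iteration: u and v approach the leading singular vectors.
            v.copy_(F.normalize(torch.mv(w_mat.t(), u), dim=0, eps=eps))
            u.copy_(F.normalize(torch.mv(w_mat, v), dim=0, eps=eps))
        # Clone before use so a second forward's in-place update of u/v does
        # not clobber the tensors needed to backprop the first forward (fix 3).
        u, v = u.clone(), v.clone()
    # sigma = u^T @ W_orig @ v; recomputed on every forward, even in eval,
    # so weight_orig always receives the correct gradient (fix 1).
    sigma = torch.dot(u, torch.mv(w_mat, v))
    return weight_orig / sigma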
Summary: Pull Request resolved: #37032

DataParallel requires all params and buffers of child modules to be updated in place, because of how it implements model replication during the forward pass (see #12671 for context); any params or buffers not updated in place are lost and not propagated back to the master. This diff updates all observers and fake_quant modules to do their parameter updates in-place, which enables static quant and QAT to work correctly with DataParallel. Depends on #32684 and #37185.

Test Plan: the script at https://gist.github.com/vkuzo/78b06c01f23f98ee2aaaeb37e55f8d40 failed before and passes after the diff. Added integration and unit tests to cover:

```
python test/test_quantization.py TestDistributed
```

Imported from OSS
Differential Revision: D21206454
fbshipit-source-id: df6b4b04d0ae0f7ef582c82d81418163019e96f7
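A sketch of the pattern this diff enforces (a toy observer for illustration, not the actual torch.quantization module):

```python
import torch
import torch.nn as nn

class MinMaxObserverSketch(nn.Module):
    """Toy observer tracking the running min/max of activations."""
    def __init__(self):
        super().__init__()
        self.register_buffer('min_val', torch.tensor(float('inf')))
        self.register_buffer('max_val', torch.tensor(float('-inf')))

    def forward(self, x):
        # In-place update: the first replica's buffer shares storage with the
        # parallelized module, so DataParallel sees the new values.
        self.min_val.copy_(torch.min(self.min_val, x.min()))
        self.max_val.copy_(torch.max(self.max_val, x.max()))
        # Broken under DataParallel: rebinding creates a fresh tensor on the
        # replica, which is discarded after forward().
        # self.min_val = torch.min(self.min_val, x.min())
        return x
```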
Summary: There were two problems with SN + DP:

1. In SN, the updated `_u` vector is saved back to the module via a `setattr`. However, in DP everything is run on a replica, so those updates are lost.
2. In DP, the buffers are broadcast via a `broadcast_coalesced`, so on replicas they are all views. Therefore, the `detach_` call won't work.

Fixes are:

1. Update the `_u` vector in-place so that, through the storage shared between the first replica and the parallelized module, the update is retained.
2. Do not call `detach_`.
3. Added comments in SN about the subtlety.
4. Added a note to the DP doc on this particular behavior of DP.

cc crcrpar taesung89 yaoshengfu
Fixes pytorch#11476
Pull Request resolved: pytorch#12671
Differential Revision: D10410232
Pulled By: SsnL
fbshipit-source-id: c447951844a30366d8c196bf9436340e88f3b6d9
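A toy demonstration of why fix 1 works; DataParallel's replication is stood in for here by manually sharing a buffer between two modules (names are made up):

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('u', torch.zeros(3))

master, replica = M(), M()
replica.u = master.u  # stand-in for DP handing the first replica a shared buffer

replica.u = torch.ones(3)       # setattr rebinds: the update is lost to the master
print(master.u)                 # tensor([0., 0., 0.])

replica.u = master.u            # share again
replica.u.copy_(torch.ones(3))  # in-place write goes through the shared storage
print(master.u)                 # tensor([1., 1., 1.])
```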
There were two problems with SN + DP:

1. In SN, the updated `_u` vector is saved back to the module via a `setattr`. However, in DP, everything is run on a replica, so those updates are lost.
2. In DP, the buffers are broadcast via a `broadcast_coalesced`, so on replicas they are all views. Therefore, the `detach_` call won't work.

Fixes are:

1. Update the `_u` vector in-place so that, through the storage shared between the first replica and the parallelized module, the update is retained.
2. Do not call `detach_`.
3. Added comments in SN about the subtlety.
4. Added a note to the DP doc on this particular behavior of DP.

cc @crcrpar @taesung89 @t-vi @YaoshengFu
Fixes #11476
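Finally, a minimal end-to-end usage sketch of SN under DP with the fixed behavior (assumes a machine with at least two GPUs; the final check is illustrative):

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

net = nn.Sequential(spectral_norm(nn.Linear(20, 10)), nn.ReLU()).cuda()
net = nn.DataParallel(net, device_ids=[0, 1])

x = torch.randn(8, 20, device='cuda')
u_before = net.module[0].weight_u.clone()
net(x)  # forward runs on replicas; u is updated in place on device 0
# The power-iteration update survived the replicated forward:
assert not torch.equal(u_before, net.module[0].weight_u)
```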