Added update_parameters to EMA to fix calculation#4406
Added update_parameters to EMA to fix calculation#4406prabhat00155 merged 1 commit intopytorch:mainfrom
Conversation
datumbox
left a comment
There was a problem hiding this comment.
LGTM, thanks Prabhat.
As @kazhang this is an issue on PyTorch core. Ideally the base class should include all state not just the params. We use a similar approach in other averaging schemes (see this) so this is aligned to what we've seen in the past. Though this workaround will do for our use-case, I think it's still worth raising it on PyTorch core and either correcting it or introducing better control over which params we should consider.
Makes sense. Will follow-up on this. |
…65495) Summary: While implementing [EMA](pytorch/vision#4381 extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used as it did not handle state_dict(), so a custom update_parameters() needed to be defined in [EMA class](pytorch/vision#4406). This PR aims to handle this scenario removing the need for this custom update_parameters() implementation. Discussion: pytorch/vision#4406 (review) Pull Request resolved: #65495 Reviewed By: datumbox Differential Revision: D31176742 Pulled By: prabhat00155 fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
Reviewed By: datumbox Differential Revision: D31268055 fbshipit-source-id: 2bedf7cd5db0a345dffa42a9ff94ce7d425e1008
…ytorch#65495) Summary: While implementing [EMA](pytorch/vision#4381 extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used as it did not handle state_dict(), so a custom update_parameters() needed to be defined in [EMA class](pytorch/vision#4406). This PR aims to handle this scenario removing the need for this custom update_parameters() implementation. Discussion: pytorch/vision#4406 (review) Pull Request resolved: pytorch#65495 Reviewed By: datumbox Differential Revision: D31176742 Pulled By: prabhat00155 fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2 (cherry picked from commit 2ea724b)
…65495) (#65755) * Added option to update parameters using state_dict in AveragedModel (#65495) Summary: While implementing [EMA](pytorch/vision#4381 extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used as it did not handle state_dict(), so a custom update_parameters() needed to be defined in [EMA class](pytorch/vision#4406). This PR aims to handle this scenario removing the need for this custom update_parameters() implementation. Discussion: pytorch/vision#4406 (review) Pull Request resolved: #65495 Reviewed By: datumbox Differential Revision: D31176742 Pulled By: prabhat00155 fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2 (cherry picked from commit 2ea724b) * Added validation of mode parameter in AveragedModel (#65921) Summary: Discussion: #65495 (comment) Pull Request resolved: #65921 Reviewed By: albanD Differential Revision: D31310105 Pulled By: prabhat00155 fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3 (cherry picked from commit c7748fc)
Resolves #4391.
Output logs:
resnext101_32x8d_training_log_ema_fix.txt
resnet18_training_log_ema_fix.txt
cc @datumbox