Generic weight averaging callback that supports EMA #20545
Borda merged 25 commits into Lightning-AI:master from
Conversation
Hey @senarvi, this looks great! I saw you already added support for saving and resuming, which is great. There are many scenarios there (save every n steps, time-based, every epoch, etc.); let's make sure we cover them all (for inspiration, we added quite a few tests here #20379)
No, I think it's better to have one callback with configurable averaging flags, more lightning-esque
I think this is ok, but my doubt is with forcing it. Wdyt about this? I don't necessarily want to make the implementation more complex, so this is just for discussion.
It would be nice to make it configurable, and probably users will want to get to some minimum and then start averaging. The criteria to do so may be very bespoke, so maybe allowing the user to implement a custom hook to decide whether to start averaging or whether to average at a given step would be super handy. Otherwise I'm expecting users will train for some time, save a checkpoint, then reload with this callback added to the trainer and start averaging. Which is totally fine but it requires you to stop and resume.

Regarding removing the StochasticWeightAveraging callback, I don't necessarily see that happening. We have a pretty strong commitment to backward compatibility at this point, so keeping that in with a notice to just use this one will not hurt.
That's a good point. I don't know what would be a good solution.
That's an interesting idea. We could have the user pass a function. It seems that AveragedModel will copy the current model parameters when called the first time, and update the average on subsequent calls. This means that the first average is computed when it's called the second time. I checked how StochasticWeightAveraging does this and I think it doesn't work correctly: it only ever updates the average model parameters in on_train_epoch_start(), so the average is not updated after the last epoch. Just shows why I'd like to keep the logic as simple as possible.
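To make the semantics above concrete, here is a plain-Python sketch of how AveragedModel behaves (this is a stand-in class, not the real torch.optim.swa_utils.AveragedModel): the first update_parameters() call just copies the current weights, and each later call folds them into the running average via the averaging function.

```python
def ema_avg_fn(averaged, current, num_averaged, decay=0.9):
    # Exponential moving average: new = decay * old + (1 - decay) * current
    return decay * averaged + (1 - decay) * current

class TinyAveragedModel:
    """Stand-in illustrating AveragedModel's update semantics."""

    def __init__(self, avg_fn=ema_avg_fn):
        self.avg_fn = avg_fn
        self.params = None
        self.n_averaged = 0

    def update_parameters(self, current_params):
        if self.params is None:
            # First call: just copy the current parameters.
            self.params = list(current_params)
        else:
            # Subsequent calls: fold current parameters into the average.
            self.params = [
                self.avg_fn(a, c, self.n_averaged)
                for a, c in zip(self.params, current_params)
            ]
        self.n_averaged += 1

avg = TinyAveragedModel()
avg.update_parameters([1.0])  # copied; average is 1.0
avg.update_parameters([2.0])  # first real average: 0.9 * 1.0 + 0.1 * 2.0 = 1.1
```

This is why the first average only exists after the second call, as noted above.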
Hi, I have a couple questions.
Hi @cyanic-selkie. During training (stage=fit), the actual LightningModule is what we update using the optimizer (I call it the current model), and an AveragedModel is maintained in the background (I call it the average model). I assume that validation is only called during training. Before and after validation we swap the current model and the average model, so the average model will be validated.

When saving a checkpoint, we save the average model parameters in the state_dict. So if you later load the checkpoint without the WeightAveraging callback and run a test or export to ONNX, you will be using the average parameters.

When training ends, we copy the average model parameters to the current model. So if you run a test or export to ONNX after training, you will be using the average parameters.

That's the idea at least. I'm not confident that I have thought about every possible corner case. It would be great if you could test that it works in your case.
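A minimal sketch of the swap described above, with plain Python lists standing in for the two parameter sets (the actual callback exchanges tensor data in place, but the round-trip idea is the same): swapping once puts the averaged weights into the model that validation sees, and swapping again restores the training weights.

```python
def swap_parameters(current, average):
    # Exchange each parameter pair in place, so the averaged weights are
    # what validation sees; calling the function again restores them.
    for i in range(len(current)):
        current[i], average[i] = average[i], current[i]

current, average = [1.0, 2.0], [10.0, 20.0]

swap_parameters(current, average)   # before validation
assert current == [10.0, 20.0]      # model now holds the averaged weights

swap_parameters(current, average)   # after validation
assert current == [1.0, 2.0]        # training weights are restored
```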
@senarvi Ah! Thanks for the clarification, I should've checked the code out more carefully. I tried your branch out on a quantization aware training enabled model with ONNX export at the end and everything is working beautifully! I hope this gets merged quickly.
The user can now provide either the update_on_step or the update_on_epoch function. For example:

update_on_step = lambda x: x > 100

or

update_on_epoch = lambda x: x in (3, 5, 7)

I tested EMA in an actual learning task and it gave an improvement, so I'm starting to be more confident that this works. I think the biggest question that is still left is whether it's a problem that we force this.
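Assuming update_on_step acts as a predicate over the global step index (a hypothetical driver loop for illustration, not the callback's actual internals), gating the average-model update would look like this:

```python
# User-supplied predicate: only update the average after step 100.
update_on_step = lambda x: x > 100

def steps_that_update(n_steps, predicate):
    # Collect the step indices at which the average model would be updated.
    return [step for step in range(1, n_steps + 1) if predicate(step)]

updates = steps_that_update(103, update_on_step)
# Only steps 101, 102 and 103 satisfy x > 100
```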
@tchaton I think you contributed the
* A callback that updates a torch.optim.swa_utils.AveragedModel after specific steps or epochs.
* The user can provide a callback that defines after which steps or epochs the average model is updated.
Is there anything blocking this from being merged?
I marked this ready for review. There were no comments on whether it's a problem that we force this.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@           Coverage Diff            @@
##           master   #20545    +/-   ##
=========================================
- Coverage      87%      79%      -8%
=========================================
  Files         268      266       -2
  Lines       23449    23488      +39
=========================================
- Hits        20389    18475    -1914
- Misses       3060     5013    +1953
BTW: I think it's totally fine to merge this as is and open an issue to gather discussions about averaging buffers. The other question I have (for the future) is related to fitting both models on GPU. It may make sense to give the ability to keep the AveragedModel on a different device (e.g.
Hi! Thanks for this great PR. The current implementation only leverages
I think we could just pass
- The user can specify when to update the average model by overriding the should_update() method
- Any keyword arguments will be passed to the AveragedModel constructor
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
for more information, see https://pre-commit.ci
@Borda should be good now
if is_overridden("configure_model", pl_module):
    rank_zero_warn(
        "You're using the WeightAveraging callback with a model that overrides the configure_model "
        "callback. WeightAveraging doesn't support sharding model layers, so you may run out of memory."
    )
What changes do we need to support sharded models like when using deepspeed/FSDP?
@npuichigo this pull request has been merged. If you're interested in FSDP, you should open a new issue. I think the average model parameters should be sharded too, and then gathered in the end for saving to disk. Someone who has developed the FSDP code or at least uses it, knows the best way for doing that and can test how this affects the performance.
@senarvi, I believe SimpleFold was trained with Lightning + FSDP using PyTorch's AveragedModel wrapper, with all_gathers of parameters handled here: https://github.com/apple/ml-simplefold/blob/ff4b91daca2ef8cafe83e3e80140bcce6a3136d1/src/simplefold/model/simplefold.py#L718. Would you be interested in making a small new PR to support FSDP like this?
Thanks @amorehead. That solution seems simple and I'm happy to give it a try. If you're familiar with the code, help me understand it. Is model_ema on CPU? If yes, why don't they use summon_full_params(offload_to_cpu=True)? If model_ema is on GPU, does this only work when the whole model fits into GPU memory twice?
Sweet, EMA finally merged! Great work, @senarvi!
Congrats <3
@amorehead based on your comment, I am going to assume this is not something that should lead to a code change, but is it something that is worth documenting at least?
@SkafteNicki precisely
Hi @amorehead, I came across your comment and I am now concerned my weights are not correctly averaged when I decide to train for a few more epochs by increasing
Hi, @philgzl. Yes, I've observed this behavior myself. However, I'm not sure it can be considered a bug, since reaching
Thanks. So if I understand correctly, this is all because the non-averaged weights are not saved at the end of training. This seems like quite a drastic behavior for just reducing the checkpoint size. Would it make sense to at least add an option to save the non-averaged weights, so users can resume training or inspect the non-averaged weights? I would personally use that option by default in all my projects.

Just to be clear, both the averaged and non-averaged weights are saved during training, such that training can be correctly resumed in case of a crash, right? What about the interaction with a
I think so, perhaps setting the default to not save them to save space.
Yes, you can tell if you compare the size of your checkpoints mid-training and at the end of training. The latter should be smaller because it only contains the averaged weights.
No, mid-training checkpoints contain both the averaged weights (n.b., stored in the default model
The averaged weights will be loaded by default via the model checkpoint's
Great, thanks a lot for the clear explanation @amorehead!

A callback that updates an AveragedModel after every training step
What does this PR do?
This is similar to the existing StochasticWeightAveraging callback, but wraps the AveragedModel class from PyTorch. Reduced code duplication means easier maintenance. Also, any averaging function can be used. By default, the callback does averaging on every step, but this can be customized by overriding the
should_update(step_idx, epoch_idx) method.

Fixes #10914
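To show the shape of the customization point mentioned above, here is a hedged sketch of overriding should_update(step_idx, epoch_idx). The base class below is a plain-Python stand-in for the actual WeightAveraging callback (which lives in PyTorch Lightning), kept minimal so the hook's contract is visible.

```python
class WeightAveragingStandIn:
    """Stand-in for the real callback: by default, update on every step."""

    def should_update(self, step_idx=None, epoch_idx=None):
        return step_idx is not None

class EveryTenthStep(WeightAveragingStandIn):
    """Update the average model only every 10th optimizer step."""

    def should_update(self, step_idx=None, epoch_idx=None):
        return step_idx is not None and step_idx % 10 == 0

cb = EveryTenthStep()
triggered = [s for s in range(1, 31) if cb.should_update(step_idx=s)]
# Steps 10, 20 and 30 trigger an update
```

In the real callback the trainer would call should_update for you; the list comprehension here just simulates that loop.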
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--20545.org.readthedocs.build/en/20545/