Skip to content

Generic weight averaging callback that supports EMA#20545

Merged
Borda merged 25 commits intoLightning-AI:masterfrom
senarvi:generic-weight-averaging
Aug 15, 2025
Merged

Generic weight averaging callback that supports EMA#20545
Borda merged 25 commits intoLightning-AI:masterfrom
senarvi:generic-weight-averaging

Conversation

@senarvi
Copy link
Copy Markdown
Contributor

@senarvi senarvi commented Jan 14, 2025

A callback that updates an AveragedModel after every training step

What does this PR do?

This is similar to the existing StochasticWeightAveraging callback, but wraps the AveragedModel class from PyTorch. Reduced code duplication means easier maintenance. Also, any averaging function can be used. By default, the callback does averaging on every step, but this can be customized by overriding the should_update(step_idx, epoch_idx) method.

Fixes #10914

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs) => Discussed in issue Add feature Exponential Moving Average (EMA) #10914
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request? => There are none.
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--20545.org.readthedocs.build/en/20545/

@github-actions github-actions bot added docs Documentation related pl Generic label for PyTorch Lightning package labels Jan 14, 2025
@lantiga
Copy link
Copy Markdown
Contributor

lantiga commented Jan 14, 2025

Hey @senarvi, this looks great!

I saw you already added support for saving and resuming which is great. There are many scenarios there (save every n steps, time-based, every epoch, etc) let's make sure we cover them all (for inspiration, we added quite a few tests here #20379)

we could still have different callbacks ("StepwiseAveragingCallback" and "EpochwiseAveragingCallback")

No I think it's better to have one with configurable averaging flags, more lightning-esque

Constructs the AveragedModel with use_buffers=True, so that an extra step is not needed for updating the batch normalization statistics. StochasticWeightAveraging performs an extra step in the end. Consequently the implementation is significantly more complex and it's difficult to make sure that it works in all cases. Should we add this as an option in this class too?

I think this is ok, but my doubt with forcing use_buffers to be true is what happens when a user has a module with buffers in it that are not meant to be averaged. I guess at that point they will probably be the same over time (e.g. the RoPE cache), but that's not really a guarantee.

Wdyt about this? I don't necessarily want to make the implementation more complex, so this is just for discussion.

Updates the average model after every step. StochasticWeightAveraging updates the average model after every epoch, and I recall that the original paper updated it only at certain points (the learning rate minima). I guess it would be nice to be able to select whether the average model will be updated after every step, after every epoch, or after certain epochs. Then we would need only one callback and we could remove the StochasticWeightAveraging callback, but would it make this class too complex?

It would be nice to make it configurable, and probably users will want to get to some minimum and then start averaging. The criteria to do so may be very bespoke, so maybe allowing the user to implement a custom hook to decide whether to start averaging or whether to average at a given step would be super handy. Otherwise I'm expecting users will train for some time, save a checkpoint, then reload with this callback added to the trainer and start averaging. Which is totally fine but it requires you to stop and resume.

Regarding removing the StochasticWeightAveraging callback, I don't necessarily see that happening. We have a pretty strong commitment to backward compatibility at this point, so keeping that in with a notice to just use this one will not hurt.

@senarvi
Copy link
Copy Markdown
Contributor Author

senarvi commented Jan 15, 2025

I think this is ok, but my doubt with forcing use_buffers to be true is what happens when a user has a module with buffers in it that are not meant to be averaged. I guess at that point they will probably be the same over time (e.g. the RoPE cache), but that's not really a guarantee.

That's a good point. I don't know what would be a good solution.

Updates the average model after every step. StochasticWeightAveraging updates the average model after every epoch, and I recall that the original paper updated it only at certain points (the learning rate minima). I guess it would be nice to be able to select whether the average model will be updated after every step, after every epoch, or after certain epochs. Then we would need only one callback and we could remove the StochasticWeightAveraging callback, but would it make this class too complex?

It would be nice to make it configurable, and probably users will want to get to some minimum and then start averaging. The criteria to do so may be very bespoke, so maybe allowing the user to implement a custom hook to decide whether to start averaging or whether to average at a given step would be super handy. Otherwise I'm expecting users will train for some time, save a checkpoint, then reload with this callback added to the trainer and start averaging. Which is totally fine but it requires you to stop and resume.

That's an interesting idea. We could have the user pass a function update_on_step(global_step) or update_on_epoch(epoch) that returns a boolean. After each optimizer step and after each epoch we would call the function to check whether we should update the average model.

It seems that AveragedModel will copy the current model parameters when called the first time, and update the average on subsequent calls. This means that the first average is computed when update_on_step() or update_on_epoch() returns True for the second time. I don't see a better alternative.

I checked how StochasticWeightAveraging does this and I think it doesn't work correctly. It only ever updates the average model parameters in on_train_epoch_start(), so the average is not updated after the last epoch. Just shows why I'd like to keep the logic as simple as possible.

@scurkovic
Copy link
Copy Markdown

Hi, I have a couple questions.

  1. You added the on_validation_epoch_start and on_validation_epoch_end hooks to swap the weights, but shouldn't the same happen for test?
  2. In my current workflow I have a separate script that does the model exporting to ONNX. It's short, and really the only Lightning specific thing is the MyLightningModule.load_from_checkpoint(...) method. Since the averaged weights are a part of the callback, I would have to instantiate the trainer for the weights to be loaded. And even then, I wouldn't have a function I could call to explicitly swap the weights (since _swap_weights is private and not really accessible). So, my question is, can we have some sort of an API, outside of the trainer, that can load the averaged weights instead of the regular weights? Perhaps adding some sort of a parameter to the load_from_checkpoint method?

@senarvi
Copy link
Copy Markdown
Contributor Author

senarvi commented Jan 16, 2025

Hi @cyanic-selkie

During training (stage=fit), the actual LightningModule is what we update using the optimizer (I call it the current model) and an AveragedModel is maintained in the background (I call it the average model).

I assume that validation is only called during training. Before and after validation we swap the current model and the average model, so the average model will be validated.

When saving a checkpoint, we save the average model parameters in the state_dict. So if you later load the checkpoint without WeightAveraging callback and run a test or export to ONNX, you will be using the average parameters.

When training ends, we copy the average model parameters to the current model. So if you run a test or export to ONNX after training, you will be using the average parameters.

That's the idea at least. I'm not confident that I have thought about every possible corner case. It would be great if you could test that it works in your case.

@scurkovic
Copy link
Copy Markdown

@senarvi Ah! Thanks for the clarification, I should've checked the code out more carefully. I tried your branch out on a quantization aware training enabled model with ONNX export at the end and everything is working beautifully! I hope this gets merged quickly.

@senarvi senarvi force-pushed the generic-weight-averaging branch from efc77dc to 0010492 Compare January 23, 2025 16:07
@senarvi
Copy link
Copy Markdown
Contributor Author

senarvi commented Jan 23, 2025

The user can now provide either the update_on_step or the update_on_epoch argument. (In theory also both.) It should be a function that takes the step/epoch number and returns True if the average model should be updated at that point of time.

For example:

update_on_step = lambda x: x > 100

or

update_on_epoch = lambda x: x in (3, 5, 7)

Using update_on_epoch, SWA should be possible. I added one unit test for SWA.

I tested EMA in an actual learning task and it gave an improvement, so I'm starting to be more confident that this works.

I think the biggest question that is still left is whether it's a problem that we force use_buffers=True. It would be nice if we could provide the option to instead call update_bn() after training and we wouldn't have to duplicate any of that code. That function takes a data loader and iterates through the data. I can imagine that passing the Trainer's data loader might not work in all cases. We could also leave calling this function to the user.

StochasticWeightAveraging increments the number of epochs in on_fit_start() and during the extra epoch disables the backward pass. I could also copy the code from that class, but there are some details that I don't understand, and I'm not that excited of copying code that I don't fully understand.

@tchaton I think you contributed the StochasticWeightAveraging callback, maybe you have some insight?

* A callback that updates a torch.optim.swa_utils.AveragedModel after specific steps or epochs.
* The user can provide a callback that defines after which steps or epochs the average model is updated.
@senarvi senarvi force-pushed the generic-weight-averaging branch from 5f34205 to c8d50bd Compare January 23, 2025 18:00
@github-actions github-actions bot added the fabric lightning.fabric.Fabric label Jan 23, 2025
@scurkovic
Copy link
Copy Markdown

Is there anything blocking this from being merged?

@senarvi senarvi changed the title Generic weight averaging callback that supports EMA [wip] Generic weight averaging callback that supports EMA Feb 2, 2025
@senarvi senarvi marked this pull request as ready for review February 2, 2025 21:21
@senarvi
Copy link
Copy Markdown
Contributor Author

senarvi commented Feb 2, 2025

I marked this ready for review. There were no comments whether it's a problem that we force use_buffers=True. Would it make sense to merge this now and perhaps introduce such option later based on the feedback that we receive?

@codecov
Copy link
Copy Markdown

codecov bot commented Feb 2, 2025

Codecov Report

Attention: Patch coverage is 94.68085% with 5 lines in your changes missing coverage. Please review.

Project coverage is 79%. Comparing base (831870a) to head (5deb0bb).

❗ There is a different number of reports uploaded between BASE (831870a) and HEAD (5deb0bb). Click for more details.

HEAD has 349 uploads less than BASE
Flag BASE (831870a) HEAD (5deb0bb)
cpu 105 27
python3.10 24 6
lightning_fabric 26 0
pytest 57 0
python 12 3
python3.12 10 3
python3.12.7 35 9
lightning 60 15
python3.11 24 6
gpu 4 0
pytorch2.1 12 6
pytorch_lightning 23 12
pytest-full 52 27
pytorch2.2.2 6 3
pytorch2.3 6 3
pytorch2.5 6 3
pytorch2.6 6 3
pytorch2.4.1 6 3
pytorch2.5.1 5 3
pytorch2.7 5 3
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #20545     +/-   ##
=========================================
- Coverage      87%      79%     -8%     
=========================================
  Files         268      266      -2     
  Lines       23449    23488     +39     
=========================================
- Hits        20389    18475   -1914     
- Misses       3060     5013   +1953     
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown
Contributor

@lantiga lantiga left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Solid contribution @senarvi! I added a few comments (most are quick to address, let me know what you can do here vs follow up PR), but overall looks great.

@lantiga
Copy link
Copy Markdown
Contributor

lantiga commented Feb 3, 2025

BTW: I think it's totally fine to merge this as is and open an issue to gather discussions about averaging buffers.

The other question I have (for the future) is related to fitting both models on GPU. It may make sense to give the ability to keep the AveragedModel on a different device (e.g. cpu) to keep the callback usable with larger models.

@senarvi
Copy link
Copy Markdown
Contributor Author

senarvi commented Feb 4, 2025

The other question I have (for the future) is related to fitting both models on GPU. It may make sense to give the ability to keep the AveragedModel on a different device (e.g. cpu) to keep the callback usable with larger models.

There's a device argument already, and actually the default is cpu - as with StochasticWeightAveraging.

@senarvi senarvi force-pushed the generic-weight-averaging branch from c6856eb to 42d91cd Compare February 10, 2025 10:30
@h2o64
Copy link
Copy Markdown

h2o64 commented Feb 21, 2025

Hi! Thanks for this great PR. The current implementation only leverages avg_fn argument should it also consider the in-place version multi_avg_fn ?

@senarvi
Copy link
Copy Markdown
Contributor Author

senarvi commented Feb 21, 2025

Hi! Thanks for this great PR. The current implementation only leverages avg_fn argument should it also consider the in-place version multi_avg_fn ?

I think we could just pass **averaged_model_kwargs. I'll look into it over the weekend.

- The user can specify when to update the average model by overriding the should_update() method
- Any keyword arguments will be passed to the AveragedModel constructor
@SkafteNicki
Copy link
Copy Markdown
Collaborator

@Borda should be good now

if is_overridden("configure_model", pl_module):
rank_zero_warn(
"You're using the WeightAveraging callback with a model that overrides the configure_model "
"callback. WeightAveraging doesn't support sharding model layers, so you may run out of memory."
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What changes do we need to support sharded models like when using deepspeed/FSDP?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets do it in follow-up

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any update on this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@npuichigo this pull request has been merged. If you're interested in FSDP, you should open a new issue. I think the average model parameters should be sharded too, and then gathered in the end for saving to disk. Someone who has developed the FSDP code or at least uses it, knows the best way for doing that and can test how this affects the performance.

Copy link
Copy Markdown
Contributor

@amorehead amorehead Oct 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@senarvi, I believe SimpleFold was trained with Lightning + FSDP using PyTorch's AveragedModel wrapper, with all_gathers of parameters handled here: https://github.com/apple/ml-simplefold/blob/ff4b91daca2ef8cafe83e3e80140bcce6a3136d1/src/simplefold/model/simplefold.py#L718. Would you be interested in making a small new PR to support FSDP like this?

Copy link
Copy Markdown
Contributor Author

@senarvi senarvi Oct 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @amorehead . That solution seems simple and I'm happy to give it a try. If you're familiar with the code, help me understand it. Is model_ema on CPU? If yes, why don't they use summon_full_params(offload_to_cpu=True). If model_ema is on GPU, this only works when the whole model fits into GPU memory twice?

@Borda Borda merged commit 1ec459f into Lightning-AI:master Aug 15, 2025
84 checks passed
@amorehead
Copy link
Copy Markdown
Contributor

Sweet, EMA finally merged! Great work, @senarvi!

@catalpaaa
Copy link
Copy Markdown

Congrats <3

@amorehead
Copy link
Copy Markdown
Contributor

amorehead commented Oct 18, 2025

One small caveat for folks to be aware of when using this callback: If your LightningModule reaches the end of your test epoch and then calls on_train_end(), the callback will swap your current (unaveraged) model weights with your EMA (averaged) model weights prior to (potentially) saving a final model checkpoint. If such a checkpoint is saved, it will contain your averaged model weights, not the unaveraged weights. This is fine for most downstream use cases, assuming the model is truly finished training. However, if you later decide to load such a checkpoint and finetune the model for additional training epochs, your training losses will likely look fine, but your validation/test performance might suffer.

I've seen this myself, where loading a checkpoint after it has reached its trainer.max_epochs=20000 and then manually training for longer (after setting a larger value of trainer.max_epochs=40000) resulted in normal training losses with no disruption, yet I noticed a huge drop in validation performance as shown in the image below. This is likely because, from then on, the callback is treating the averaged weights as the unaveraged model weights.
image

@SkafteNicki
Copy link
Copy Markdown
Collaborator

@amorehead based on your comment, I am going to assume this is not something that should lead to a code change but is it something that is worth documenting at least?

@amorehead
Copy link
Copy Markdown
Contributor

@SkafteNicki precisely

@philgzl
Copy link
Copy Markdown

philgzl commented Dec 23, 2025

Hi @amorehead, I came across your comment and I am now concerned my weights are not correctly averaged when I decide to train for a few more epochs by increasing max_epochs. Are you saying that the averaged and non-averaged weights are not correctly loaded from the checkpoint when resuming training like this? If so isn't that something which needs fixing?

@amorehead
Copy link
Copy Markdown
Contributor

amorehead commented Dec 23, 2025

Hi, @philgzl. Yes, I've observed this behavior myself. However, I'm not sure it can be considered a bug, since reaching trainer.max_epochs should indicate to the callback that training has finished and that it can save the "final" (irreversible) checkpoint for production/inference. If you exceed the original max_epochs, you are then using the callback in an undefined manner.

@philgzl
Copy link
Copy Markdown

philgzl commented Dec 23, 2025

Thanks. So if I understand correctly this is all because the non-averaged weights are not saved at the end of training.

This seems like a quite drastic behavior for just reducing the checkpoint size. Would it make sense to at least add an option to save the non-averaged weights, so users can resume training or inspect the non-averaged weights? I would personally use that option by default in all my projects.

Just to be clear, both the averaged and non-averaged weights are saved during training such that training can be correctly resumed in case of a crash, right?

What about the interaction with a ModelCheckpoint which tracks e.g. validation loss? Is a best checkpoint saved mid-training considered ready for "production", as in it does not contain the non-averaged weights? Or do I need to manually to load the averaged weights after loading such a checkpoint?

@amorehead
Copy link
Copy Markdown
Contributor

amorehead commented Dec 23, 2025

Would it make sense to at least add an option to save the non-averaged weights

I think so, perhaps setting the default to not save them to save space.

Both the averaged and non-averaged weights are saved during training such that training can be correctly resumed in case of a crash, right?

Yes, you can tell if you compare the size of your checkpoints mid-training and at the end of training. The latter should be smaller because it only contains the averaged weights.

Is a best checkpoint saved mid-training considered ready for "production", as in it does not contain the non-averaged weights?

No, mid-training checkpoints contain both the averaged weights (n.b., stored in the default model state_dict key of the checkpoint) and non-averaged weights (n.b., stored in a separate, dedicated key). As such, these mid-training checkpoints will be larger than necessary for downstream deployment without manual editing. However, it would be impossible to correctly resume training without saving both copies of the weights, so this is a necessary tradeoff.

Do I need to manually load the averaged weights after loading such a checkpoint?

The averaged weights will be loaded by default via the model checkpoint's state_dict key (n.b., which PyTorch assumes are your non-averaged, regular weights). That's why it's necessary when resuming training to reuse the same WeightAveragingCallback, because it will run logic to correctly extract both the averaged and non-averaged weights from your model checkpoints when resuming. Otherwise, if you directly run torch.load_state_dict (or Lightning's equivalent) on your checkpoints, you will be loading the averaged weights.

@philgzl
Copy link
Copy Markdown

philgzl commented Dec 23, 2025

Great, thanks a lot for the clear explanation @amorehead!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

docs Documentation related fabric lightning.fabric.Fabric pl Generic label for PyTorch Lightning package

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add feature Exponential Moving Average (EMA)