Add automatic best model loading to Trainer #7431

Merged
sgugger merged 3 commits into master from trainer_best_model on Sep 29, 2020

Conversation

@sgugger
Collaborator

@sgugger sgugger commented Sep 28, 2020

What does this PR do?

This PR cleans up the part of Trainer that saves the training state, and adds an API that tracks which model was the best across all evaluation phases so it can be loaded back at the end of training.

When fine-tuning a model on a dataset it can easily overfit, it's quite common for the last model not to be the best one (in terms of metrics). This PR adds a TrainingArgument named load_best_model_at_end that triggers the following behavior:

  • save_steps gets ignored and the model is saved every time there is an evaluation (determined by evaluation_strategy and eval_steps)
  • It keeps track in a TrainerState of when the best model was encountered (that state is saved along the checkpoints so it can work with resuming a training)
  • The best model is determined by the new TrainingArguments metric_for_best_model (defaults to the loss) and greater_is_better (defaults to False for the loss, True otherwise).
  • The best model is loaded once the training is finished.
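The selection logic described in the bullets above can be sketched in plain Python. This is a hypothetical illustration, not the actual Trainer code: the `BestModelTracker` class and `default_greater_is_better` helper are invented names for this sketch, mirroring only the documented defaults (lower is better for `"loss"`/`"eval_loss"`, higher otherwise).

```python
# Hypothetical sketch of the best-model tracking described above.
# None of these names exist in transformers; they only illustrate the logic.

def default_greater_is_better(metric_for_best_model):
    """Mirror the documented default: lower is better for losses."""
    return metric_for_best_model not in ("loss", "eval_loss")

class BestModelTracker:
    def __init__(self, metric="loss", greater_is_better=None):
        self.metric = metric
        self.greater_is_better = (
            default_greater_is_better(metric)
            if greater_is_better is None
            else greater_is_better
        )
        self.best_metric = None
        self.best_checkpoint = None

    def update(self, metrics, checkpoint):
        """Called after each evaluation; records the checkpoint if it improved."""
        value = metrics[self.metric]
        improved = self.best_metric is None or (
            value > self.best_metric
            if self.greater_is_better
            else value < self.best_metric
        )
        if improved:
            self.best_metric = value
            self.best_checkpoint = checkpoint
        return improved

tracker = BestModelTracker(metric="eval_loss")
tracker.update({"eval_loss": 0.9}, "checkpoint-100")
tracker.update({"eval_loss": 0.7}, "checkpoint-200")
tracker.update({"eval_loss": 0.8}, "checkpoint-300")
print(tracker.best_checkpoint)  # -> checkpoint-200
```

Because this state lives alongside the checkpoints (as the PR description notes for TrainerState), the same comparison can resume correctly after a training restart.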

In passing I've added some tests of the saving API in Trainer and made sure it can handle both PreTrainedModel and regular nn.Module (a feature requested in #6901). Both are now tested in the CI, as is the new API.
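The PreTrainedModel/nn.Module dual support mentioned above can be sketched as a simple dispatch. This is a hypothetical illustration of the idea, not the actual Trainer internals: `save_model` and its injected `save_fn` (standing in for torch.save, so the sketch stays framework-free) are assumed names.

```python
# Hypothetical sketch: duck-type on `save_pretrained` to support both
# PreTrainedModel and plain nn.Module. Not the real Trainer code.
import os

def save_model(model, output_dir, save_fn):
    """save_fn stands in for torch.save; injected to keep the sketch
    free of framework dependencies."""
    os.makedirs(output_dir, exist_ok=True)
    if hasattr(model, "save_pretrained"):
        # PreTrainedModel path: delegate to the model's own serializer.
        model.save_pretrained(output_dir)
        return "pretrained"
    # Plain nn.Module path: persist only the raw state dict.
    save_fn(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
    return "state_dict"
```

The duck-typing keeps the Trainer usable with any module exposing a state_dict, while still preferring the richer save_pretrained format when available.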

Fixes #6901

Those newly introduced arguments and APIs can then be leveraged to have early stopping supported in Trainer.

@codecov

codecov bot commented Sep 28, 2020

Codecov Report

Merging #7431 into master will increase coverage by 0.53%.
The diff coverage is 76.74%.

@@            Coverage Diff             @@
##           master    #7431      +/-   ##
==========================================
+ Coverage   78.17%   78.71%   +0.53%     
==========================================
  Files         181      181              
  Lines       35800    35858      +58     
==========================================
+ Hits        27986    28224     +238     
+ Misses       7814     7634     -180     
Impacted Files Coverage Δ
src/transformers/trainer.py 63.23% <73.43%> (+7.80%) ⬆️
src/transformers/trainer_utils.py 63.30% <80.00%> (+2.66%) ⬆️
src/transformers/training_args.py 91.72% <100.00%> (+0.45%) ⬆️
src/transformers/modeling_tf_electra.py 24.25% <0.00%> (-73.56%) ⬇️
src/transformers/modeling_tf_lxmert.py 22.14% <0.00%> (-72.41%) ⬇️
src/transformers/modeling_rag.py 25.32% <0.00%> (-51.72%) ⬇️
src/transformers/modeling_tf_xlm.py 58.52% <0.00%> (-34.74%) ⬇️
src/transformers/tokenization_t5.py 61.53% <0.00%> (-33.66%) ⬇️
src/transformers/modeling_marian.py 60.00% <0.00%> (-30.00%) ⬇️
src/transformers/modeling_longformer.py 74.14% <0.00%> (-18.70%) ⬇️
... and 25 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update f62f2ff...738935a.

checkpoints_sorted[-1],
checkpoints_sorted[best_model_index],
)
print(checkpoints_sorted)
Member

bogus print

Collaborator Author


Oopsie, leftover from debug.

Member

@LysandreJik LysandreJik left a comment


Nice! Cool that we allow other nn.Modules than PreTrainedModels now!

Comment on lines +148 to +168
load_best_model_at_end (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to load the best model found during training at the end of training.

.. note::

When set to :obj:`True`, the parameters :obj:`save_steps` will be ignored and the model will be saved
after each evaluation.
metric_for_best_model (:obj:`str`, `optional`):
Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two different
models. Must be the name of a metric returned by the evaluation with or without the prefix :obj:`"eval_"`.
Will default to :obj:`"loss"` if unspecified and :obj:`load_best_model_at_end=True` (to use the evaluation
loss).

If you set this value, :obj:`greater_is_better` will default to :obj:`True`. Don't forget to set it to
:obj:`False` if your metric is better when lower.
greater_is_better (:obj:`bool`, `optional`)
Use in conjunction with :obj:`load_best_model_at_end` and :obj:`metric_for_best_model` to specify if better
models should have a greater metric or not. Will default to:

- :obj:`True` if :obj:`metric_for_best_model` is set to a value that isn't :obj:`"loss"` or
:obj:`"eval_loss"`.
Member


Very clear doc!

@sgugger sgugger merged commit 52e8392 into master Sep 29, 2020
@sgugger sgugger deleted the trainer_best_model branch September 29, 2020 14:41
@PhilipMay
Contributor

IMO this closes #4186

@PhilipMay
Contributor

PhilipMay commented Oct 3, 2020

@sgugger how does this work together with save_total_limit? If it is set, might the best model get deleted?

well - see here #7556

@sgugger
Collaborator Author

sgugger commented Oct 4, 2020

The best model is not deleted with save_total_limit. It is always put at the top of the list after sorting the checkpoints.
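The deletion-safety behavior sgugger describes can be sketched as follows. This is a hypothetical illustration only (`checkpoints_to_keep` is an invented name, not a Trainer function): sort checkpoints oldest to newest, move the best one to the end so it is treated as the most recent, then let the rotation delete from the front.

```python
# Hypothetical sketch of keeping the best checkpoint safe from
# save_total_limit rotation. Not the actual Trainer implementation.
import re

def checkpoints_to_keep(checkpoints, best_checkpoint, save_total_limit):
    # Sort by the global step encoded in the folder name, e.g. "checkpoint-200".
    ordered = sorted(checkpoints, key=lambda c: int(re.search(r"\d+$", c).group()))
    if best_checkpoint in ordered:
        # Move the best checkpoint to the end: rotation deletes from the
        # front, so the best model is never among the deleted ones.
        ordered.remove(best_checkpoint)
        ordered.append(best_checkpoint)
    # Keep only the last `save_total_limit` entries.
    return ordered[-save_total_limit:]

print(checkpoints_to_keep(
    ["checkpoint-300", "checkpoint-100", "checkpoint-200"],
    best_checkpoint="checkpoint-100",
    save_total_limit=2,
))
# -> ['checkpoint-300', 'checkpoint-100']
```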

Development

Successfully merging this pull request may close these issues.

Relaxing PreTrainedModel requirement in _save

4 participants