
Add early stopping callback to pytorch trainer#8581

Merged
sgugger merged 20 commits into huggingface:master from cbrochtrup:early-stopping-patience
Nov 23, 2020

Conversation


@cbrochtrup cbrochtrup commented Nov 17, 2020

Summary

Address the PyTorch half of #4894 by adding early stopping patience and a minimum threshold by which metrics must improve to prevent early stopping. I piggybacked heavily off of #7431 since the two functions are very similar.

Since #4186 seems to be abandoned and behind master, I figured I'd take a crack at this.
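For illustration, the patience-plus-threshold behavior I have in mind looks roughly like this standalone sketch (hypothetical names, independent of the Trainer API; assumes higher metric values are better):

```python
from typing import Optional


class EarlyStopper:
    """Toy sketch of early stopping with patience and a minimum
    improvement threshold."""

    def __init__(self, patience: int = 1, threshold: float = 0.0):
        self.patience = patience
        self.threshold = threshold
        self.best_metric: Optional[float] = None
        self.counter = 0

    def step(self, metric: float) -> bool:
        """Record a new eval metric; return True if training should stop."""
        if self.best_metric is None or metric > self.best_metric + self.threshold:
            # Improved by more than the threshold: record it and reset patience.
            self.best_metric = metric
            self.counter = 0
        else:
            # No meaningful improvement: burn one unit of patience.
            self.counter += 1
        return self.counter >= self.patience


stopper = EarlyStopper(patience=2, threshold=0.01)
history = [0.70, 0.75, 0.751, 0.752]  # improvements after 0.75 are below threshold
stops = [stopper.step(m) for m in history]  # [False, False, False, True]
```

Note the threshold means tiny improvements (here 0.001) still count against patience, which is the point of the minimum-improvement option.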

Who can review?

Anyone! But @julien-c and @sgugger seem the most appropriate.

sgugger (Collaborator) commented Nov 17, 2020

Hi there. Thanks for your PR! When I designed the callbacks, the intent was for them to be small, independent pieces of code. I would prefer if early stopping had its own callback that the user could then choose to add or not. Do you think you could amend your PR in that direction?

cbrochtrup (Author) commented Nov 17, 2020

Hello, thank you for your feedback! I will amend the PR in that direction.

Could you clarify which pieces of early stopping should be in TrainerState and which should be in the callback? I'm grappling with the similarities between best_model_checkpoint and early stopping attributes.

class EarlyStoppingCallback(TrainerCallback):
    best_metric: Optional[float] = None  # maybe not this
    best_model_checkpoint: Optional[str] = None  # maybe not this either
    early_stopping_patience: Optional[int] = None
    early_stopping_patience_counter: int = 0

    def on_evaluate(self, args, state, control, **kwargs):
        # Keep track of patience
        # End training via early stopping
        if (
            self.early_stopping_patience is not None
            and self.early_stopping_patience_counter >= self.early_stopping_patience
        ):
            control.should_training_stop = True

cbrochtrup (Author) commented

Or do you mean I just move the if statement I added to its own callback and keep TrainerState as is?

sgugger (Collaborator) commented Nov 17, 2020

The TrainerState shouldn't change, so the callback you are writing above sounds fine, without the arguments marked with # maybe not this, which should already be in the TrainerState, I think.
Does that sound right to you?

cbrochtrup (Author) commented

That makes sense. I think this block of code (to line 933) could be a callback because it's all about the best metric. Then users could customize the best model calculations. Is that desirable?

If you think that's out of scope I'll keep the early stopping callback simple and separate from the best metric calculation.

sgugger (Collaborator) commented Nov 17, 2020

I had put it in Trainer because I thought multiple callbacks could need it and it's used by load_best_model_at_end which is kind of a core feature.

cbrochtrup (Author) commented

Sounds good, you know best! I'll keep load_best_model_at_end in the Trainer and push up an early stopping callback sometime this week.

@cbrochtrup cbrochtrup changed the title Add early stopping patience to pytorch trainer Add early stopping callback to pytorch trainer Nov 19, 2020
sgugger (Collaborator) left a review


A few more things to change, but we're close to getting this into a good state. Thanks a lot for your work on this!

metric_value = metrics.get(metric_to_check)

if metric_value is None:
    logger.warning(
sgugger (Collaborator):
Good warning!

self.early_stopping_patience_counter += 1

def on_train_begin(self, args, state, control, **kwargs):
    assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True"
sgugger (Collaborator):

I still don't understand why this line is necessary. I feel we should be able to use this callback without the load_best_model_at_end option. The other sanity checks are perfectly OK.

cbrochtrup (Author):

This is necessary because we require control.should_save=True for _save_checkpoint to update the best metric. Should I move the best metric calculation into its own function and place it in the should_evaluate block?

cbrochtrup (Author) commented Nov 20, 2020

I agree that needing load_best_model_at_end is not fully intuitive, but it makes sense to me: if we don't load the best model, early stopping will still halt training, but the model we get back from training will not be the one early stopping deemed best.
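To make that concrete, here is a toy, self-contained simulation (hypothetical names, not Trainer code): without reloading the best checkpoint at the end, an early-stopped run hands back the weights from the last step, not the best one.

```python
def train(metrics, patience, load_best_at_end):
    """Simulate a training loop that early-stops once the eval metric has
    failed to improve `patience` times. Returns the step whose weights the
    caller would end up with."""
    best_step, best_metric, counter = 0, float("-inf"), 0
    last_step = 0
    for step, metric in enumerate(metrics):
        last_step = step
        if metric > best_metric:
            best_step, best_metric, counter = step, metric, 0
        else:
            counter += 1
        if counter >= patience:
            break  # early stop fires here
    # Only reloading the best checkpoint gives back the best weights.
    return best_step if load_best_at_end else last_step


eval_metrics = [0.60, 0.72, 0.71, 0.70]  # peaks at step 1, then degrades
with_reload = train(eval_metrics, patience=2, load_best_at_end=True)    # step 1
without_reload = train(eval_metrics, patience=2, load_best_at_end=False)  # step 3
```

By construction the run stops only after the metric has already degraded for `patience` evaluations, so the final weights are always at least `patience` steps past the best ones.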

sgugger (Collaborator):

OK, let's leave it as is for now, and we will re-evaluate if some users complain!


Saw this issue while debugging something. It doesn't seem intuitive how these two are related, so can we please do what @cbrochtrup suggested above?

@sgugger sgugger requested a review from LysandreJik November 20, 2020 16:13
cbrochtrup (Author) commented

Thanks for your thorough and affable review!

LysandreJik (Member) left a review


Great addition, LGTM!

@sgugger sgugger merged commit 8ffc01a into huggingface:master Nov 23, 2020
@cbrochtrup cbrochtrup deleted the early-stopping-patience branch November 23, 2020 23:04

5 participants