
[Callbacks] TST Add tests to check hook calls #33394

Merged
lesteve merged 6 commits into scikit-learn:callbacks from jeremiedbb:tst-record-hooks
Mar 6, 2026

Conversation

@jeremiedbb
Member

Towards #33324

Enhanced a testing callback to record the hook calls and added tests to check the number of callback hook calls.
A next step could be to check that the hook calls match the context tree, but let's keep this PR simple.
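The recording mechanism described above can be sketched roughly like this (a simplified illustration; the hook signatures and the `count_hooks` helper are assumptions based on the snippets quoted later in this thread, not the actual scikit-learn implementation):

```python
class RecordingCallback:
    """Illustrative sketch of a testing callback that records hook calls
    so tests can assert on call counts. Not the real scikit-learn class."""

    def __init__(self):
        # Each entry is a (hook_name, task_name, task_id) tuple.
        self.record = []

    def on_fit_begin(self, estimator):
        self.record.append(("on_fit_begin", None, None))

    def on_fit_task_end(self, task_name, task_id):
        self.record.append(("on_fit_task_end", task_name, task_id))

    def on_fit_end(self):
        self.record.append(("on_fit_end", None, None))

    def count_hooks(self, hook_name):
        # Count how many times a given hook was invoked.
        return sum(1 for hook, *_ in self.record if hook == hook_name)
```

A test can then simply assert, e.g., `callback.count_hooks("on_fit_begin") == 1` after fitting.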

ping @FrancoisPgm @StefanieSenger

@jeremiedbb jeremiedbb added No Changelog Needed Quick Review For PRs that are quick to review Callbacks labels Feb 24, 2026
Comment on lines +305 to +306
self._has_called_on_fit_begin = True

Member Author

Moved that before calling on_fit_begin, in case on_fit_begin itself raises.

# TODO(callbacks): Figure out the exact behavior we want when cloning an
# estimator with callbacks.
new_object._skl_callbacks = clone(estimator._skl_callbacks, safe=False)
new_object._skl_callbacks = estimator._skl_callbacks
Member Author

The proper behavior for clone is being worked on in #33340. This change is temporary, so that this PR can pass: otherwise clone makes a deep copy of the callback, which defeats the purpose of the Manager.

@jeremiedbb jeremiedbb moved this to In progress in Labs Feb 26, 2026
@jeremiedbb jeremiedbb added this to Labs Feb 26, 2026
@lesteve
Member

lesteve commented Feb 27, 2026

For reference, this seems to tackle the comments in #28760 (comment) below.

This looks good; two things we could do in other PRs:

Hacky code from a debug session on test_meta_estimator_autopropagated_callback_hooks_called

(Pdb) pprint([(hn, c.task_name if c is not None else None, c.task_id if c is not None else None) for hn, c, arg in callback.record])
[('on_fit_begin', None, None),
 ('on_fit_task_end', '', 0),
 ('on_fit_task_end', '', 1),
 ('on_fit_task_end', '', 2),
 ('on_fit_task_end', '', 3),
 ('on_fit_task_end', '', 4),
 ('on_fit_task_end', '', 5),
 ('on_fit_task_end', '', 6),
 ('on_fit_task_end', '', 7),
 ('on_fit_task_end', '', 8),
 ('on_fit_task_end', '', 9),
 ('on_fit_task_end', 'inner', 0),
 ('on_fit_task_end', '', 0),
 ('on_fit_task_end', '', 1),
 ('on_fit_task_end', '', 2),
 ('on_fit_task_end', '', 3),
 ('on_fit_task_end', '', 4),
 ('on_fit_task_end', '', 5),
 ('on_fit_task_end', '', 6),
 ('on_fit_task_end', '', 7),
 ('on_fit_task_end', '', 8),
 ('on_fit_task_end', '', 9),
 ('on_fit_task_end', 'inner', 1),
 ('on_fit_task_end', '', 0),
 ('on_fit_task_end', '', 1),
 ('on_fit_task_end', '', 2),
 ('on_fit_task_end', '', 3),
 ('on_fit_task_end', '', 4),
 ('on_fit_task_end', '', 5),
 ('on_fit_task_end', '', 6),
 ('on_fit_task_end', '', 7),
 ('on_fit_task_end', '', 8),
 ('on_fit_task_end', '', 9),
 ('on_fit_task_end', 'inner', 2),
 ('on_fit_task_end', 'outer', 0),
 ('on_fit_task_end', '', 0),
 ('on_fit_task_end', '', 1),
 ('on_fit_task_end', '', 2),
 ('on_fit_task_end', '', 3),
 ('on_fit_task_end', '', 4),
 ('on_fit_task_end', '', 5),
 ('on_fit_task_end', '', 6),
 ('on_fit_task_end', '', 7),
 ('on_fit_task_end', '', 8),
 ('on_fit_task_end', '', 9),
 ('on_fit_task_end', 'inner', 0),
 ('on_fit_task_end', '', 0),
 ('on_fit_task_end', '', 1),
 ('on_fit_task_end', '', 2),
 ('on_fit_task_end', '', 3),
 ('on_fit_task_end', '', 4),
 ('on_fit_task_end', '', 5),
 ('on_fit_task_end', '', 6),
 ('on_fit_task_end', '', 7),
 ('on_fit_task_end', '', 8),
 ('on_fit_task_end', '', 9),
 ('on_fit_task_end', 'inner', 1),
 ('on_fit_task_end', '', 0),
 ('on_fit_task_end', '', 1),
 ('on_fit_task_end', '', 2),
 ('on_fit_task_end', '', 3),
 ('on_fit_task_end', '', 4),
 ('on_fit_task_end', '', 5),
 ('on_fit_task_end', '', 6),
 ('on_fit_task_end', '', 7),
 ('on_fit_task_end', '', 8),
 ('on_fit_task_end', '', 9),
 ('on_fit_task_end', 'inner', 2),
 ('on_fit_task_end', 'outer', 1),
 ('on_fit_task_end', '', 0),
...

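The counts in a record with that shape can be derived with a small helper (a sketch; the loop sizes are read off the excerpt above, which is truncated, so this only covers the two completed outer tasks shown):

```python
def expected_task_end_calls(n_outer, n_inner, n_iter):
    """Expected number of on_fit_task_end calls for a nested fit with
    n_outer outer tasks, each running n_inner inner fits of n_iter
    iterations. Hypothetical helper for reasoning about the record."""
    innermost = n_outer * n_inner * n_iter  # the ('', i) entries
    inner_ends = n_outer * n_inner          # one ('inner', j) per inner fit
    outer_ends = n_outer                    # one ('outer', k) per outer task
    return innermost + inner_ends + outer_ends

# For the two completed outer tasks visible in the excerpt
# (3 inner fits each, 10 iterations per inner fit):
expected_task_end_calls(2, 3, 10)  # → 68
```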
@jeremiedbb
Member Author

I synced with the callbacks branch and made the testing callback use the global manager.

estimator failing at fit in the estimator itself in #28760 (comment). Functionality is probably there, but maybe not a test?

There's a test for an error raised in each hook. The error isn't raised by the estimator itself, but it still happens inside fit. So technically, the estimator calling a failing on_fit_task_end is not different from an explicit raise in the body of fit. What do you think?

we could try to do another PR to check the content more thoroughly than only counting the number of hooks.

Yes, the tests here are basic tests but we can definitely keep improving them in subsequent PRs.
Related to #33338 in particular.

Member

@ogrisel ogrisel left a comment

The change in the testing infra LGTM, but it reveals something that I overlooked in the current design of the callback infra w.r.t. the calls to on_fit_begin (and on_fit_end) for auto-propagated callbacks. See the points below.

Maybe @FrancoisPgm you know the answer?


def on_fit_begin(self, estimator):
    self.record.append(("on_fit_begin", None, None))
Member

I think we should also record the estimator object, or even a deep copy of it, to snapshot its internal state at the time the hook is called. This would mean that the estimator instance has to be picklable at any time during fit, but that shouldn't be a problem for most scikit-learn estimators, should it? If it is, we could make the recording of estimators an option in the constructor of the callback, enabled only in specific tests with specific kinds of estimators where we know it wouldn't cause a problem.

This could be both useful to inspect the state of the estimator at the time of the hook call, or to distinguish hook calls on different estimators in case of auto-propagated callbacks.
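That suggestion could look roughly like this (a hypothetical sketch; `record_estimator_snapshots` is an invented option name, not part of the actual testing callback):

```python
import copy


class SnapshotRecordingCallback:
    """Sketch of an opt-in estimator snapshot in the recording callback.
    Hypothetical: illustrates the suggestion above, not merged code."""

    def __init__(self, record_estimator_snapshots=False):
        self.record = []
        self.record_estimator_snapshots = record_estimator_snapshots

    def on_fit_begin(self, estimator):
        # Deep copying requires the estimator to be copyable/picklable
        # mid-fit; keep it opt-in for tests where that is known to hold.
        snapshot = (
            copy.deepcopy(estimator)
            if self.record_estimator_snapshots
            else estimator
        )
        self.record.append(("on_fit_begin", snapshot))
```

The snapshot freezes the estimator's state at the time of the hook call, so later mutations during fit don't affect what the test inspects.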

Member Author

I added the estimator in the record. It's not used for now, but we can use it in a subsequent PR to add tests, e.g. check that the estimator received by the hook is the current estimator, not always the meta-estimator.

option in the constructor of the callback to be enabled only in specific tests with specific kinds of estimators where we know that this wouldn't cause a problem.

This part can be done later if/when we use this testing callback for common tests on sklearn estimators

n_jobs=n_jobs,
).set_callbacks(callback).fit()

assert callback.count_hooks("on_fit_begin") == 1
Member

Why isn't the on_fit_begin hook of an auto-propagated callback called when entering the fit method of each MaxIterEstimator clone created by the MetaEstimator in its own fit call? I checked the source code of the context class and this seems intentional, but I don't understand the motivation behind this design choice.

I would have expected those hook calls to also show up in the propagated callback record, and to observe 1 + n_outer * n_inner total calls to on_fit_begin in the record.

Member

Shouldn't we let the auto-propagated callbacks themselves ignore nested on_fit_begin calls if they wish?

I guess to do so, we would need to pass an extra context argument to on_fit_begin so that the callback can inspect the presence of a parent attribute on the context to make this decision.
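A sketch of what that more permissive protocol might look like (hypothetical; `context.parent` is the assumed attribute mentioned above, `None` at the root and set for nested, propagated calls):

```python
class RootOnlySetupCallback:
    """Sketch of a callback that receives the context in on_fit_begin and
    decides by itself to only set up at the root of the context tree.
    Hypothetical protocol, not the current scikit-learn one."""

    def on_fit_begin(self, estimator, context):
        # Assumed attribute: context.parent is None only at the root.
        if context.parent is not None:
            return  # nested (propagated) fit: skip setup
        self._setup()

    def _setup(self):
        # Placeholder for real setup work (e.g. spawning a display thread).
        self.ready = True
```

With this shape, the framework could call on_fit_begin at every level and leave the skip-nested-calls policy to each callback.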

Contributor

Why isn't the on_fit_begin hook of an auto-propagated callback called when entering the fit method of each MaxIterEstimator clone created by the MetaEstimator in its own fit call? I checked the source code of the context class and this seems intentional, but I don't understand the motivation behind this design choice.

Yes, it's definitely an intentional design choice to have on_fit_begin and on_fit_end be called only in the root estimator for auto-propagated callbacks. I think the main motivation is that these hooks take care of the set-up and tear-down of the callback for the task, and since the callback is the same (or duplicated) at all nested levels, it is enough to do it once. For example, for the ProgressBar these two hooks serve to spawn and join the thread for the progress bar display, so there is no need to do it at each nested level. Maybe @jeremiedbb has other motivations in mind that I don't see.

Shouldn't we leave the auto-propagated callbacks themselves ignore nested on_fit_begin if they wish?

I like the idea of having a more permissive framework. We might realize for future auto-propagated callbacks that it'd be more practical to call on_fit_begin and on_fit_end at each level, but I don't have a specific case in mind. I wonder if we should bother to change it now or wait until we actually need it.

Member Author

Yes it's definitely an intentional design choice to have on_fit_begin and on_fit_end be called only in the root estimator for auto-propagated callbacks. I think the main motivation is that these hooks take care of the set-up and tear-down of the callback for the task

It's exactly that.

on_fit_begin and on_fit_end are meant for setup and teardown. An auto-propagated callback is a callback set on the outermost estimator, so it should be set up and torn down on that estimator. Propagating it to sub-estimators doesn't mean setting it on the sub-estimators (i.e. tying it to the sub-estimators' context trees), but rather extending the context tree to the sub-estimators so that the on_fit_task_end hook is called all the way down (which gives a totally different picture).
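The dispatch rule described here can be summarized in a small sketch (a simplified model of the context tree, not the actual scikit-learn context class):

```python
class CallbackContext:
    """Sketch: setup/teardown hooks fire only at the root of the context
    tree, while on_fit_task_end fires at every nesting level."""

    def __init__(self, callbacks, parent=None):
        self.callbacks = callbacks
        self.parent = parent  # None for the outermost estimator

    def eval_on_fit_begin(self, estimator):
        # Setup happens once, on the outermost estimator only.
        if self.parent is None:
            for cb in self.callbacks:
                cb.on_fit_begin(estimator)

    def eval_on_fit_task_end(self, task_name, task_id):
        # Task-end hooks are called at every level of the tree.
        for cb in self.callbacks:
            cb.on_fit_task_end(task_name, task_id)
```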

Member

For the record, during an IRL meeting we decided that @jeremiedbb would open a concurrent PR to refactor the protocol to make the callback setup/teardown methods explicit and decoupled from on_fit_begin / on_fit_end.

Let's move that design discussion there. Meanwhile, I think we can proceed with the review and merge of this PR independently and update the tests accordingly once the design evolves.

@lesteve
Member

lesteve commented Mar 4, 2026

One question, maybe for a further PR: what should happen if there are two callbacks and the first callback fails in on_fit_end? My naive expectation would be that the second callback's on_fit_end is still called, but that doesn't seem to be the case right now (eval_on_fit_begin/end is essentially a loop calling on_fit_begin/end for each callback, so it stops at the first one that fails).

Here is a test to make my expectation more precise (variation on an existing test but with two callbacks), which fails for on_fit_begin and on_fit_end.

@pytest.mark.parametrize("fail_at", ["on_fit_begin", "on_fit_task_end", "on_fit_end"])
def test_callback_error_with_two_callbacks(fail_at):
    """Check that a failing callback is properly teared down."""
    callbacks = [FailingCallback(fail_at=fail_at), TestingCallback()]

    estimator = MaxIterEstimator().set_callbacks(callbacks)
    with pytest.raises(ValueError, match=f"Failing callback failed at {fail_at}"):
        estimator.fit()

    for cb in callbacks:
        assert cb.count_hooks("on_fit_begin") == 1
        assert cb.count_hooks("on_fit_end") == 1

@jeremiedbb
Member Author

Good point, we can call each on_fit_end in a try/except and re-raise afterwards if an error occurred.
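That fix could look roughly like this (a hypothetical sketch; `eval_on_fit_end` is modeled on the loop described above, not the actual scikit-learn helper):

```python
def eval_on_fit_end(callbacks):
    """Call every callback's on_fit_end even if one of them fails,
    then re-raise the first error. Sketch of the suggested fix."""
    first_error = None
    for cb in callbacks:
        try:
            cb.on_fit_end()
        except Exception as exc:
            # Remember the first failure but keep tearing down the rest.
            if first_error is None:
                first_error = exc
    if first_error is not None:
        raise first_error
```

This way a failure in one callback's teardown no longer prevents the remaining callbacks from being torn down, which is what lesteve's test above expects.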

Member

@ogrisel ogrisel left a comment

OK to merge as is and refine the tests later in subsequent PRs that evolve the protocol design.

@lesteve lesteve merged commit ab619bc into scikit-learn:callbacks Mar 6, 2026
38 of 39 checks passed
@github-project-automation github-project-automation bot moved this from In progress to Done in Labs Mar 6, 2026
@lesteve
Member

lesteve commented Mar 6, 2026

Let's merge this one then!
