[Callbacks] TST Add tests to check hook calls #33394
lesteve merged 6 commits into scikit-learn:callbacks from
Conversation
```python
self._has_called_on_fit_begin = True
```
Moved that before calling `on_fit_begin` in case `on_fit_begin` itself raises.
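For context, a minimal sketch of the ordering being discussed; the class and method names are illustrative, not the actual implementation:

```python
class CallbackContextSketch:
    """Illustrative sketch only, not the actual scikit-learn implementation."""

    def __init__(self, estimator, callbacks):
        self.estimator = estimator
        self._callbacks = callbacks
        self._has_called_on_fit_begin = False

    def eval_on_fit_begin(self):
        # Set the flag first: if a callback's on_fit_begin raises below,
        # the teardown path can still see that setup was attempted.
        self._has_called_on_fit_begin = True
        for callback in self._callbacks:
            callback.on_fit_begin(self.estimator)  # may raise
```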
```diff
 # TODO(callbacks): Figure out the exact behavior we want when cloning an
 # estimator with callbacks.
-new_object._skl_callbacks = clone(estimator._skl_callbacks, safe=False)
+new_object._skl_callbacks = estimator._skl_callbacks
```
The proper behavior for `clone` is being worked on in #33340. This change is temporary, to make this PR pass: otherwise `clone` makes a deep copy of the callback, which breaks the purpose of the `Manager`.
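To illustrate why sharing the instance matters here (a toy sketch, independent of scikit-learn):

```python
from copy import deepcopy

class RecordingCallback:
    """Toy stand-in for a callback whose identity matters (e.g. a manager)."""
    def __init__(self):
        self.record = []

cb = RecordingCallback()

# What a deep copy (as done by clone(..., safe=False)) gives: the clone
# reports to a *different* object, so the original record stays empty.
copied = deepcopy(cb)
assert copied is not cb

# What the temporary change does: share the same instance, so hook calls on
# the cloned estimator land in the same record.
shared = cb
assert shared is cb
```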
For reference, this seems to tackle the comments in #28760 (comment) below. This looks good; two things we could do in other PRs:
Hacky code from a debug session:

```
(Pdb) pprint([(hn, c.task_name if c is not None else None, c.task_id if c is not None else None) for hn, c, arg in callback.record])
[('on_fit_begin', None, None),
 ('on_fit_task_end', '', 0),
 ('on_fit_task_end', '', 1),
 ('on_fit_task_end', '', 2),
 ('on_fit_task_end', '', 3),
 ('on_fit_task_end', '', 4),
 ('on_fit_task_end', '', 5),
 ('on_fit_task_end', '', 6),
 ('on_fit_task_end', '', 7),
 ('on_fit_task_end', '', 8),
 ('on_fit_task_end', '', 9),
 ('on_fit_task_end', 'inner', 0),
 ('on_fit_task_end', '', 0),
 ('on_fit_task_end', '', 1),
 ('on_fit_task_end', '', 2),
 ('on_fit_task_end', '', 3),
 ('on_fit_task_end', '', 4),
 ('on_fit_task_end', '', 5),
 ('on_fit_task_end', '', 6),
 ('on_fit_task_end', '', 7),
 ('on_fit_task_end', '', 8),
 ('on_fit_task_end', '', 9),
 ('on_fit_task_end', 'inner', 1),
 ('on_fit_task_end', '', 0),
 ('on_fit_task_end', '', 1),
 ('on_fit_task_end', '', 2),
 ('on_fit_task_end', '', 3),
 ('on_fit_task_end', '', 4),
 ('on_fit_task_end', '', 5),
 ('on_fit_task_end', '', 6),
 ('on_fit_task_end', '', 7),
 ('on_fit_task_end', '', 8),
 ('on_fit_task_end', '', 9),
 ('on_fit_task_end', 'inner', 2),
 ('on_fit_task_end', 'outer', 0),
 ('on_fit_task_end', '', 0),
 ('on_fit_task_end', '', 1),
 ('on_fit_task_end', '', 2),
 ('on_fit_task_end', '', 3),
 ('on_fit_task_end', '', 4),
 ('on_fit_task_end', '', 5),
 ('on_fit_task_end', '', 6),
 ('on_fit_task_end', '', 7),
 ('on_fit_task_end', '', 8),
 ('on_fit_task_end', '', 9),
 ('on_fit_task_end', 'inner', 0),
 ('on_fit_task_end', '', 0),
 ('on_fit_task_end', '', 1),
 ('on_fit_task_end', '', 2),
 ('on_fit_task_end', '', 3),
 ('on_fit_task_end', '', 4),
 ('on_fit_task_end', '', 5),
 ('on_fit_task_end', '', 6),
 ('on_fit_task_end', '', 7),
 ('on_fit_task_end', '', 8),
 ('on_fit_task_end', '', 9),
 ('on_fit_task_end', 'inner', 1),
 ('on_fit_task_end', '', 0),
 ('on_fit_task_end', '', 1),
 ('on_fit_task_end', '', 2),
 ('on_fit_task_end', '', 3),
 ('on_fit_task_end', '', 4),
 ('on_fit_task_end', '', 5),
 ('on_fit_task_end', '', 6),
 ('on_fit_task_end', '', 7),
 ('on_fit_task_end', '', 8),
 ('on_fit_task_end', '', 9),
 ('on_fit_task_end', 'inner', 2),
 ('on_fit_task_end', 'outer', 1),
 ('on_fit_task_end', '', 0),
 ...
```
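For readers following along, each record entry is a (hook_name, task_name, task_id)-like triple. A minimal, simplified sketch of the recording callback used in these tests could look as follows; the hook signatures are assumptions, not the exact test code:

```python
class TestingCallbackSketch:
    """Hypothetical sketch of the recording logic used in the tests below."""

    def __init__(self):
        # Each entry is (hook_name, task_name, task_id); the task fields are
        # None for on_fit_begin, which has no associated task.
        self.record = []

    def on_fit_begin(self, estimator):
        self.record.append(("on_fit_begin", None, None))

    def on_fit_task_end(self, estimator, task_node):
        self.record.append(
            ("on_fit_task_end", task_node.task_name, task_node.task_id)
        )

    def count_hooks(self, hook_name):
        # Count how many times a given hook was called.
        return sum(1 for name, *_ in self.record if name == hook_name)
```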
I sync'ed with
There's a test for an error raised in each hook. The error isn't raised by the estimator itself, but it's still raised inside fit. So technically, if the estimator calls on_fit_task_end and on_fit_task_end raises, it's not different from an explicit raise in the body of fit. What do you think?
Yes, the tests here are basic tests, but we can definitely keep improving them in subsequent PRs.
ogrisel left a comment
The change in the testing infra LGTM, but it reveals something that I overlooked in the current design of the callback infra w.r.t. the calls to on_fit_begin (and on_fit_end) for auto-propagated callbacks. See the points below.
Maybe @FrancoisPgm you know the answer?
sklearn/callback/tests/_utils.py (outdated)

```diff
 def on_fit_begin(self, estimator):
-    pass
+    self.record.append(("on_fit_begin", None, None))
```
I think we should also record the estimator object, or even a deep copy of it, to snapshot its internal state at the time the hook is called. This would mean that the estimator instance should be picklable at any time during fit, but that shouldn't be a problem for most scikit-learn estimators, should it? If it is, we could make this (the recording of estimators) an option in the constructor of the callback, to be enabled only in specific tests with specific kinds of estimators where we know that this wouldn't cause a problem.
This could be useful both to inspect the state of the estimator at the time of the hook call, and to distinguish hook calls on different estimators in the case of auto-propagated callbacks.
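A minimal sketch of that suggestion; the constructor flag and the deepcopy snapshot are illustrative, not code from this PR:

```python
import copy

class SnapshottingCallback:
    """Hypothetical callback recording estimator snapshots at each hook call."""

    def __init__(self, snapshot_estimators=False):
        # Opt-in flag, since snapshotting requires the estimator to be
        # copyable/picklable at any point during fit.
        self.snapshot_estimators = snapshot_estimators
        self.record = []

    def on_fit_begin(self, estimator):
        snapshot = (
            copy.deepcopy(estimator) if self.snapshot_estimators else estimator
        )
        self.record.append(("on_fit_begin", snapshot))
```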
I added the estimator to the record. It's not used for now, but we can use it in a subsequent PR to add tests, e.g. to check that the estimator received by the hook is the current estimator, not always the meta-estimator.

> an option in the constructor of the callback, to be enabled only in specific tests with specific kinds of estimators where we know that this wouldn't cause a problem.

This part can be done later, if/when we use this testing callback for common tests on sklearn estimators.
```python
    n_jobs=n_jobs,
).set_callbacks(callback).fit()

assert callback.count_hooks("on_fit_begin") == 1
```
Why isn't the on_fit_begin hook of an auto-propagated callback called when entering the fit method of each MaxIterEstimator clone created by the MetaEstimator in its own fit call? I checked the source code of the context class and this seems intentional, but I don't understand the motivation behind this design choice.
I would have expected those hook calls to also show up in the propagated callback record, and to observe 1 + n_outer * n_inner total calls to on_fit_begin in the record.
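To spell the expectation out (hypothetical numbers, just to make the arithmetic concrete):

```python
n_outer, n_inner = 2, 3
# One on_fit_begin on the meta-estimator itself, plus one per inner
# sub-estimator fit, under the expectation voiced above:
expected_on_fit_begin = 1 + n_outer * n_inner
assert expected_on_fit_begin == 7
```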
Shouldn't we let the auto-propagated callbacks themselves decide to ignore nested on_fit_begin calls if they wish?
I guess to do so, we would need to pass an extra context argument to on_fit_begin so that the callback can inspect the presence of a parent attribute on the context to make this decision.
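A sketch of what that could look like; the `context` argument and its `parent` attribute are the proposal being floated here, not the existing API:

```python
class RootOnlyCallback:
    """Hypothetical callback that opts out of nested on_fit_begin calls."""

    def on_fit_begin(self, estimator, context):
        # `context.parent` would be None only at the root of the context tree.
        if context.parent is not None:
            return  # nested call: this callback chooses to ignore it
        self._setup(estimator)

    def _setup(self, estimator):
        print(f"setting up for {estimator.__class__.__name__}")
```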
> Why isn't the on_fit_begin hook of an auto-propagated callback called when entering the fit method of each MaxIterEstimator clone created by the MetaEstimator in its own fit call? I checked the source code of the context class and this seems intentional, but I don't understand the motivation behind this design choice.
Yes, it's definitely an intentional design choice to have on_fit_begin and on_fit_end be called only on the root estimator for auto-propagated callbacks. I think the main motivation is that these hooks take care of the setup and teardown of the callback for the task, and since the callback is the same (or a duplicate) at all nested levels, it is enough to do it once. For example, for the ProgressBar these two hooks serve to spawn and join the thread for the progress bar display, so there is no need to do it at each nested level. Maybe @jeremiedbb has other motivations in mind that I don't see.
> Shouldn't we let the auto-propagated callbacks themselves decide to ignore nested on_fit_begin calls if they wish?
I like the idea of having a more permissive framework. We might realize for future auto-propagated callbacks that it'd be more practical to call on_fit_begin and on_fit_end at each level, but I don't have a specific case in mind. I wonder if we should bother to change it now or wait until we actually need it.
> Yes, it's definitely an intentional design choice to have on_fit_begin and on_fit_end be called only on the root estimator for auto-propagated callbacks. I think the main motivation is that these hooks take care of the setup and teardown of the callback for the task
It's exactly that.
on_fit_begin and on_fit_end are meant for setup and teardown. An auto-propagated callback is a callback set on the outermost estimator, so it should be set up and torn down on it. Propagating it to sub-estimators doesn't mean setting it on the sub-estimators (i.e. tying it to the sub-estimators' context trees), but rather extending the context tree to the sub-estimators such that the on_fit_task_end hook is called all the way down (which gives a totally different picture).
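To illustrate the setup/teardown argument with a ProgressBar-like example (a sketch under assumed hook signatures, not the actual ProgressBar code):

```python
import threading
import queue

class ProgressBarLikeCallback:
    """Sketch of a callback whose on_fit_begin/on_fit_end do setup/teardown."""

    def on_fit_begin(self, estimator):
        # Setup once, on the root estimator only: spawn the display thread.
        self._queue = queue.Queue()
        self._thread = threading.Thread(target=self._display_loop, daemon=True)
        self._thread.start()

    def on_fit_task_end(self, estimator, task_id):
        # Called all the way down the context tree, at every nested level.
        self._queue.put(task_id)

    def on_fit_end(self, estimator):
        # Teardown once: send the sentinel and join the thread.
        self._queue.put(None)
        self._thread.join()

    def _display_loop(self):
        while (item := self._queue.get()) is not None:
            pass  # update the progress display here
```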
For the record, during an IRL meeting, we decided that @jeremiedbb would open a concurrent PR to refactor the protocol to make the callback setup/teardown methods explicit and decoupled from on_fit_begin / on_fit_end.
Let's move that design discussion there. Meanwhile, I think we can proceed with the review and merge of this PR independently and update the tests accordingly once the design evolves.
One question, maybe for a further PR: what should happen if there are two callbacks and the first callback fails in one of its hooks? Here is a test to make my expectation more precise (a variation on an existing test, but with two callbacks), which currently fails:

```python
@pytest.mark.parametrize("fail_at", ["on_fit_begin", "on_fit_task_end", "on_fit_end"])
def test_callback_error_with_two_callbacks(fail_at):
    """Check that a failing callback is properly torn down."""
    callbacks = [FailingCallback(fail_at=fail_at), TestingCallback()]
    estimator = MaxIterEstimator().set_callbacks(callbacks)
    with pytest.raises(ValueError, match=f"Failing callback failed at {fail_at}"):
        estimator.fit()
    for cb in callbacks:
        assert cb.count_hooks("on_fit_begin") == 1
        assert cb.count_hooks("on_fit_end") == 1
```
Good point, we can call each on_fit_end in a try / finally.
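A sketch of what that could look like in the teardown path (hypothetical helper; how errors should be aggregated is an open question):

```python
def eval_on_fit_end(callbacks, estimator):
    """Sketch: call every callback's on_fit_end even if one of them raises."""
    first_error = None
    for callback in callbacks:
        try:
            callback.on_fit_end(estimator)
        except Exception as exc:
            # Keep tearing the remaining callbacks down; re-raise the first
            # error afterwards.
            first_error = first_error or exc
    if first_error is not None:
        raise first_error
```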
ogrisel left a comment
OK to merge as is, and refine the tests later with subsequent PRs that evolve the protocol design.
Let's merge this one then!
Towards #33324
Enhanced a testing callback to record the hook calls, and added tests to check the count of callback hook calls.
A next step could be to check that the hook calls match the context tree, but let's keep this PR simple.
ping @FrancoisPgm @StefanieSenger