
[Callbacks] Make clone keep callbacks #33522

Merged
jeremiedbb merged 12 commits into scikit-learn:callbacks from jeremiedbb:there-and-back-again-clone-keeps-callback
Mar 13, 2026

Conversation

@jeremiedbb
Member

This PR proposes to go back to the state before #33340 (and mostly reverts it), i.e. clone no longer discards callbacks.

The motivation for this turnaround comes from:

So it seems that a common expectation is that callbacks should work even when set on sub-estimators of meta-estimators or functions that don't implement callback support.

  • Why did we choose to make clone discard callbacks?
    My main reason was that callbacks expect the whole computation tree to work optimally. For instance, a progress bar set on a bunch of cloned estimators inside a cross-validation can't show a bar for the cross-validation itself, and spawns that many monitor threads, degrading performance.

  • Can callbacks work (without changing everything) even in meta-estimators that don't support callbacks?
    It seems so.

    • Regular callbacks (e.g. metric monitoring) are already expected to be set on all clones and to aggregate results from all of them. What's missing when the meta-estimator doesn't support callbacks is information about the task in the meta-estimator; for instance, you can't link results to cv folds after a cross-validation.
    • Propagated callbacks also work as long as their setup is implemented correctly, i.e. setups from different clones should not try to override each other. This was irrelevant before because there was only one setup: before the fit of the outermost estimator.
  • What to do with the many threads spawned by ProgressBar?

    • It degrades performance because of how it's implemented: it actively waits for content by polling the queue at high frequency, which requires acquiring the GIL. It could be implemented differently, for instance using async code to wait passively. I didn't do that in the first place because I'm not familiar with async programming at all, but we can keep that in mind for a future improvement.
    • In the meantime we can document that using callbacks outside the best-supported setting is not optimal and may degrade performance.
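To illustrate the setup constraint on propagated callbacks mentioned above, here is a hedged sketch in plain Python (hypothetical names, not the actual sklearn.callback API): when the same callback instance is set up from several clones, only the first setup creates the shared state, and later setups reuse it instead of overriding it.

```python
import threading

class SharedSetupCallback:
    """Hypothetical propagated callback whose setup is safe to call from
    several clones: the first call creates the shared state, later calls
    reuse it instead of overriding it."""

    def __init__(self):
        self._lock = threading.Lock()
        self._shared_state = None

    def setup(self):
        with self._lock:
            if self._shared_state is None:  # first clone creates the state
                self._shared_state = {"events": []}
            return self._shared_state       # later clones reuse it

cb = SharedSetupCallback()
state_from_clone_1 = cb.setup()
state_from_clone_2 = cb.setup()
print(state_from_clone_1 is state_from_clone_2)  # True: no override
```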

In the end I think it's a matter of choice. In this PR I removed the warning raised in clone, but we could keep it and change the message to inform users that they're not using callbacks optimally, guiding them to the documentation for adding callback support. I removed it because I felt that part of the concern was the unexpected coupling between clone and the callback infrastructure, so better to make things as simple as possible.

Here's an example of what happens with progress bars in such a non-optimal setting:

In [1]: from sklearn.callback.tests._utils import MaxIterEstimator
   ...: from sklearn.callback import ProgressBar
   ...: from sklearn.model_selection import cross_validate
   ...: from sklearn.datasets import load_iris
   ...: 
   ...: X, y = load_iris(return_X_y=True)
   ...: 
   ...: def predict(self, X):
   ...:     return X[:,0].astype(int)
   ...: MaxIterEstimator.predict = predict
   ...: 
   ...: cross_validate(MaxIterEstimator().set_callbacks(ProgressBar()), X, y, scoring="accuracy")
MaxIterEstimator - fit ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00
MaxIterEstimator - fit ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00
MaxIterEstimator - fit ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00
MaxIterEstimator - fit ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00
MaxIterEstimator - fit ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00
Out[1]: 
{'fit_time': array([0.08449864, 0.07760406, 0.07409215, 0.07870388, 0.07620287]),
 'score_time': array([0.00337505, 0.00109816, 0.0011282 , 0.00085163, 0.00076461]),
 'test_score': array([0., 0., 0., 0., 0.])}

We don't have the global progress of the cross-validation and we don't know which bar corresponds to which fold, but that's still something.

@FrancoisPgm @ogrisel @StefanieSenger @lesteve @adrinjalali let's discuss this option during tomorrow's meeting.

@lesteve
Member

lesteve commented Mar 11, 2026

Emotional roller-coaster 😅

Member

@adrinjalali adrinjalali left a comment


So it seems the things we talked about yesterday didn't actually come up in the test suite here:

  • making sure there are no threads left open after a fit is done
  • the need for a singleton main root node kept at the callback module level instead of the parent node creating it

Otherwise I'm happy with the PR, as long as our tests pass.

sklearn/base.py Outdated
Comment on lines +139 to +141
# callbacks are passed by reference because the same instance of a callback can
# be used by multiple clones of the same estimator.
new_object._skl_callbacks = estimator._skl_callbacks
Member


I'm happy to keep it as is here, but I also wouldn't be surprised if we figure we actually want to copy the callbacks here.

Member Author


It doesn't really matter much, since clone can be called in sub-processes, in which case we effectively get a "copy" anyway.

Member


I think it's more efficient to pass by reference whenever we can. I suspect that callback objects grow big if they accumulate data about running operations, so better not copy if we can avoid it.

Member Author


Yes, as long as we're not pickling, it's better to share the same reference.
In sub-processes we get new copies of it, with proxies still pointing to the same managed data structure, so they still share part of their state.
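A minimal plain-Python illustration of the difference discussed here (hypothetical `DummyCallback`, not the sklearn API): clones in the same process share the exact callback object, while a pickle round-trip, as happens when cloning in sub-processes, yields an independent copy.

```python
import pickle

class DummyCallback:
    """Hypothetical stand-in for a callback that accumulates state."""
    def __init__(self):
        self.log = []

callback = DummyCallback()

# Clone in the same process: the callback is shared by reference.
in_process_clone_callback = callback

# Clone in a sub-process: pickling yields an independent copy.
pickled_clone_callback = pickle.loads(pickle.dumps(callback))

in_process_clone_callback.log.append("fit")
print(callback.log)                # ['fit'] -> state is shared
print(pickled_clone_callback.log)  # []     -> state was copied
```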

Member

@StefanieSenger StefanieSenger Mar 13, 2026


But in BaseForest and in LinearModelCV we use cloned sub-estimators in Parallel(n_jobs=self.n_jobs, prefer="threads"). Wouldn't that be a problem then?

Member Author


No, when we can share the full state between threads it's even better!
For instance, progress bars can use the same manager for all shared queues.
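A small sketch of why thread-based cloning is the easy case (plain Python, hypothetical names): all clones running in threads can push their progress events into one shared in-process queue, with no pickling involved.

```python
import queue
import threading

progress_queue = queue.Queue()  # one in-process queue shared by every clone

def fit_one_clone(clone_id, n_iter=3):
    # Each "clone" reports its progress events into the shared queue.
    for i in range(n_iter):
        progress_queue.put((clone_id, i))

threads = [threading.Thread(target=fit_one_clone, args=(c,)) for c in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(progress_queue.qsize())  # 12: all events collected in one place
```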

Member Author


Look at test_progressbar_no_callback_support, which I just enhanced. It covers the case you're describing.

@jeremiedbb
Member Author

making sure there are no threads left open after a fit is done

We can add a test for that. Right now the existing tests make sure that all teardowns are called (and threads are closed in those teardowns), so that's already something.

the need for a singleton main root node kept at the callback module level instead of the parent node creating it

That (discussed between Adrin and me during a call) turned out not to be a necessity after all, which is good because it would have added a lot of complexity to make it work across processes.

@jeremiedbb
Member Author

jeremiedbb commented Mar 11, 2026

Even though I opened this PR, my preference is still that clone shouldn't keep callbacks. I'd rather not even raise a warning in clone, and instead document that if you want callbacks to work, you need to implement support in all the nested estimators.
That being said, it's just a preference and I'm not opposed to the option presented in this PR. Let's find consensus during tomorrow's meeting :)

@jeremiedbb
Member Author

jeremiedbb commented Mar 11, 2026

Actually, everything's not as simple as I stated in the PR description. Reusing the cross-validation example but running the cross-validation in parallel this time, we only see one bar progressing at a time, and then new bars appear already finished. This is because we spawn threads in the sub-processes and their stdout is not directly forwarded to the parent process. So the rendering in a parallel setting is quite bad.

It also requires some small changes in #33494, but nothing too bad (still a bit annoying, like the clean setup/teardown around fit and setting the context as an attribute of the estimator).

@StefanieSenger
Member

Thank you for being open to revision, @jeremiedbb and @FrancoisPgm.

So it seems that a common expectation is that callbacks should work even when set on sub-estimators of meta-estimators or functions that don't implement callback support.

For me, that was not the expectation. I only now see that your main intention behind not cloning callbacks by default was to prevent undefined callback behaviour, whereas fixing the pickling error in clone (making cloning work) and other issues were only side quests. There was some information loss, at least for me, about why we want to define the cloning behaviour this way.

  1. I think we could document that callback behaviour is undefined / not guaranteed to work if users set callbacks but combine them with a non-callback-supporting object. That seems a reasonable and unsurprising limit of a broad new feature.

  2. For preventing undefined callback behaviour, I can see that cutting the chain after the first non-supporting element is a great idea. I wonder if there is a better place to do so than in clone.

    Alternatively, sub-estimators could check if their parents are a CallbackSupportMixin instance or have a set_callback attribute (not sure how to do this for functions, but probably there are tricks to attach attributes to functions?).
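As an aside on the question above about functions: function objects in Python do accept arbitrary attributes, so a marker could in principle be attached to a plain function just like to an estimator class. A tiny sketch with hypothetical names (`cross_validate_like`, `supports_callbacks` are illustrative, not real sklearn API):

```python
def cross_validate_like(estimator):
    """Hypothetical stand-in for a function without callback support."""
    return estimator

# Function objects accept arbitrary attributes, so a marker attribute can be
# attached to a function just like to a class.
cross_validate_like.supports_callbacks = False
print(cross_validate_like.supports_callbacks)  # False
```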

@github-actions github-actions bot added the CI:Linter failure The linter CI is failing on this PR label Mar 11, 2026
@github-actions github-actions bot removed the CI:Linter failure The linter CI is failing on this PR label Mar 11, 2026
@jeremiedbb
Member Author

jeremiedbb commented Mar 11, 2026

fc833ab shows the necessary changes to ProgressBar to make it work.

The main change is that the callback must now be able to handle:

  • being set on a single estimator in the main process and pickled into child processes (hence a managed queue)
  • a single instance being set on several estimators directly in sub-processes (hence process-local monitor threads)

This makes implementing propagated callbacks properly a bit more complex.

@FrancoisPgm
Contributor

FrancoisPgm commented Mar 12, 2026

Alternatively, sub-estimators could check if their parents are a CallbackSupportMixin instance or have a set_callback attribute (not sure how to do this for functions, but probably there are tricks to attach attributes to functions?).

I tried to do that before, to raise a warning when a sub-estimator that supports callbacks is used in a meta-estimator that doesn't. My approach was inspecting the call stack at runtime to identify whether it contains a call to fit from a non-CallbackSupportMixin estimator, but it was a bit clunky, not very robust, and would not work in parallelized settings, so we gave up on that idea. Maybe there is a smarter way to perform the check, though.
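For reference, the kind of stack inspection described above can be sketched like this (hypothetical class names; as noted, this approach is fragile and breaks in parallelized settings, since worker threads and processes don't share the caller's stack):

```python
import inspect

class CallbackSupportMixin:
    """Hypothetical marker mixin for estimators with callback support."""

def fit_called_from_unsupported_estimator():
    """Walk the call stack and report whether an enclosing `fit` frame
    belongs to an estimator that lacks CallbackSupportMixin."""
    for frame_info in inspect.stack()[1:]:
        caller_self = frame_info.frame.f_locals.get("self")
        if (
            frame_info.function == "fit"
            and caller_self is not None
            and not isinstance(caller_self, CallbackSupportMixin)
        ):
            return True
    return False

class Sub(CallbackSupportMixin):
    def fit(self):
        return fit_called_from_unsupported_estimator()

class MetaWithoutSupport:  # deliberately no CallbackSupportMixin
    def fit(self):
        return Sub().fit()

print(Sub().fit())                 # False: no unsupported fit above us
print(MetaWithoutSupport().fit())  # True: called from a plain meta-estimator
```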

@lesteve
Member

lesteve commented Mar 12, 2026

Actually everything's not as simple as I stated in the PR description. Reusing the cross-val example but running the cross-val in parallel this time we only see 1 bar progressing at once and then new bars appears and are already finished.

Just to be sure, do you have some code so we can get an idea of how bad the behaviour is? From what you are saying, this doesn't sound horribly bad, but maybe with other callbacks it could be worse?

I tried something like:

from sklearn.base import clone
from joblib import delayed, Parallel
from sklearn.callback.tests._utils import MaxIterEstimator
from sklearn.callback import ProgressBar

n_jobs = 4
backend = "loky"

if __name__ == "__main__":
    def clone_and_fit(estimator):
        clone(estimator).fit()

    def func(estimator, n_fits, n_jobs, backend):
        Parallel(n_jobs=n_jobs, backend=backend)(
            delayed(clone_and_fit)(estimator) for _ in range(n_fits)
        )

    n_fits, max_iter = 5, 7
    callback = ProgressBar()
    estimator = MaxIterEstimator(max_iter=max_iter).set_callbacks(callback)

    func(estimator, n_fits, n_jobs, backend)

and I don't see any major issue personally.

@jeremiedbb
Member Author

jeremiedbb commented Mar 12, 2026

Just to be sure, do you have some code just so we get an idea how bad the behaviour is. From what you are saying, this doesn't sound horribly bad, but maybe with other callbacks it could be worse?

Your snippet illustrates what I mean. It's just that you need to set a higher number of iterations and fits to actually see something (otherwise it's too fast), say n_fits=12 and max_iter=100.
Then you'll only see one bar progressing and, when it's finished, 3 more bars appear already finished (we didn't see them progress). Then a new bar progresses and the same behaviour repeats.
It's because we don't get the stdout from the other 3 sub-processes right away.

The final rendering is fine and shows all bars finished, but that's not what we're interested in with progress bars, we want to see progress.

EDIT: Actually, it looks like all bars from sub-processes are stacked during progress and un-stacked when finished.

@adrinjalali
Member

As far as I can tell, this looks good to me now.

@FrancoisPgm
Contributor

LGTM thanks :)

Member

@ogrisel ogrisel left a comment


Some small improvement suggestions below, but LGTM overall. I wouldn't mind merging without addressing all of my comments.

def func(estimator):
    Parallel(n_jobs=2)(delayed(clone_and_fit)(estimator) for _ in range(4))

func(MaxIterEstimator().set_callbacks(ProgressBar()))
Member


Would it be possible to check that we do not leak rich-related resources after the end of the call to func?

I don't want to delay the merge of this PR for this, but maybe we could add a # TODO comment for this.

Member Author


I added a check for leftover threads and queues.

In the case of multiprocessing, we can just assume that threads created in the worker processes are killed anyway when the worker is terminated. Actually checking that the thread is properly finished from within the worker would add a lot of complexity.

In the case of multithreading, I added checks to make sure all threads are finished and queues are empty. This also serves as a check for the case above, applied to each worker.
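A leftover-thread check of the kind described here can be sketched as follows (hypothetical helper, not the actual test code): snapshot the live threads before the call and assert nothing new survives it.

```python
import threading

def assert_no_leftover_threads(func, *args, **kwargs):
    """Run func and fail if it leaves any new thread running."""
    before = set(threading.enumerate())
    result = func(*args, **kwargs)
    leftover = set(threading.enumerate()) - before
    assert not leftover, f"threads left running: {leftover}"
    return result

def well_behaved_fit():
    # Spawns a helper thread but joins it before returning.
    t = threading.Thread(target=lambda: None)
    t.start()
    t.join()

assert_no_leftover_threads(well_behaved_fit)
print("no leftover threads")
```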

@jeremiedbb
Member Author

I addressed the review comments. I'm merging this one with 2 approvals. Thanks!

@jeremiedbb jeremiedbb merged commit c5b7d8e into scikit-learn:callbacks Mar 13, 2026
25 checks passed
@jeremiedbb jeremiedbb added this to Labs Mar 14, 2026
@jeremiedbb jeremiedbb moved this to In progress in Labs Mar 14, 2026
@github-project-automation github-project-automation bot moved this from In progress to Done in Labs Mar 14, 2026