FIX Draw indices using sample_weight in Random Forests#31529

Merged
ogrisel merged 61 commits intoscikit-learn:mainfrom
antoinebaker:random_forest_sample_weight
Jan 16, 2026

Conversation

@antoinebaker
Contributor

@antoinebaker antoinebaker commented Jun 12, 2025

Part of #16298. Similar to #31414 (Bagging estimators) but for Forest estimators.

Also fixes #28507.

What does this implement/fix? Explain your changes.

When subsampling is activated (bootstrap=True), sample_weight is now used as probabilities to draw the bootstrap indices. Forest estimators then pass the statistical repeated/weighted equivalence test.
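A minimal sketch of the idea (the helper name is hypothetical; the actual implementation lives in the forest fitting code): the weights are normalized into sampling probabilities and the bootstrap indices are drawn from them, instead of passing the weights down to the tree fit.

```python
import numpy as np

def draw_bootstrap_indices(sample_weight, n_draws, rng):
    # Normalize weights into sampling probabilities and draw the
    # bootstrap indices from them.
    p = sample_weight / sample_weight.sum()
    return rng.choice(len(sample_weight), size=n_draws, replace=True, p=p)

rng = np.random.default_rng(0)
sw = np.array([1.0, 2.0, 1.0, 0.0])
idx = draw_bootstrap_indices(sw, n_draws=8, rng=rng)
# a zero-weight sample can never be drawn
assert 3 not in idx
```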

Comments

This PR does not fix Forest estimators when bootstrap=False (no subsampling). In that case, sample_weight is still passed directly to the decision trees, and Forest estimators fail the statistical repeated/weighted equivalence test because the individual trees also fail it (probably because of tied splits in decision trees, #23728).

TODO

@github-actions

github-actions bot commented Jun 12, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: d47066f. Link to the linter CI: here

@antoinebaker
Contributor Author

The forest estimators now pass the statistical repeated/weighted equivalence test.
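The repeated/weighted equivalence property referred to throughout this thread: fitting with an integer sample_weight should behave like fitting on a dataset where each row is repeated sample_weight times. A minimal illustration with a weighted mean (the forest tests check the same property statistically over many seeds):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
sw = np.array([2.0, 1.0, 3.0])  # integer-valued weights

# weighted "fit"
weighted_mean = np.average(x, weights=sw)

# equivalent "repeated" fit: row i appears sw[i] times
x_rep = np.repeat(x, sw.astype(int))
repeated_mean = x_rep.mean()

assert np.isclose(weighted_mean, repeated_mean)
```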

@antoinebaker
Contributor Author

Relative (float) max_samples, with the new meaning of drawing max_samples * sw_sum indices as done in #31414, also passes the statistical repeated/weighted equivalence test.
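A sketch of the new float semantics (helper name hypothetical, exact rounding in the PR may differ): the number of draws scales with the total weight sw_sum rather than with n_samples, which is what makes repeating a row and doubling its weight equivalent.

```python
import numpy as np

def n_bootstrap_draws(sample_weight, max_samples):
    # Float max_samples is interpreted relative to the total weight,
    # not the number of rows.
    sw_sum = float(np.sum(sample_weight))
    return int(round(max_samples * sw_sum))

sw = np.array([1.0, 2.0, 1.0])  # 3 rows, total weight 4
assert n_bootstrap_draws(sw, max_samples=1.0) == 4
assert n_bootstrap_draws(sw, max_samples=0.5) == 2
```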

@antoinebaker
Contributor Author

The class_weight="balanced" option, which now takes sample_weight into account as in #30057, also passes the statistical repeated/weighted equivalence test.
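A hedged sketch of what "balanced" weights that account for sample_weight look like (the helper name is hypothetical): each class is reweighted by the inverse of its weighted frequency, the weighted analogue of scikit-learn's n_samples / (n_classes * bincount(y)) formula.

```python
import numpy as np

def balanced_class_weight(y, sample_weight):
    # "balanced" weights computed from *weighted* class frequencies:
    # sw_sum / (n_classes * weighted_count[c]).
    classes, inv = np.unique(y, return_inverse=True)
    weighted_counts = np.bincount(inv, weights=sample_weight)
    return sample_weight.sum() / (len(classes) * weighted_counts)

y = np.array([0, 0, 1, 1])
sw = np.array([3.0, 1.0, 1.0, 1.0])
w = balanced_class_weight(y, sw)
# class 0: weighted count 4, class 1: 2, total weight 6
assert np.allclose(w, [6 / 8, 6 / 4])
```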

@antoinebaker
Contributor Author

The class_weight="balanced_subsample" option also passes. In that case, sample_weight is used to draw the indices; the class weights are then computed on the bootstrapped sample for every grown tree and passed as sample_weight to the tree fit.
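The per-tree flow described above can be sketched as follows (hypothetical helper, pure numpy): draw indices with probability proportional to sample_weight, then balance the classes on the bootstrapped sample.

```python
import numpy as np

def per_tree_sample_weight(y, sample_weight, rng):
    # 1) draw bootstrap indices with probability proportional to
    #    sample_weight,
    # 2) compute balanced class weights on the bootstrapped sample,
    # 3) return per-sample weights to pass to the tree fit.
    n = len(y)
    p = sample_weight / sample_weight.sum()
    idx = rng.choice(n, size=n, replace=True, p=p)
    classes, inv = np.unique(y[idx], return_inverse=True)
    counts = np.bincount(inv)
    class_w = len(idx) / (len(classes) * counts)
    return idx, class_w[inv]

rng = np.random.default_rng(0)
y = np.array([0, 0, 0, 1])
idx, tree_sw = per_tree_sample_weight(y, np.ones(4), rng)
# within the bootstrap, each class ends up with equal total weight
sums = np.array([tree_sw[y[idx] == c].sum() for c in np.unique(y[idx])])
assert np.allclose(sums, sums[0])
```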

@antoinebaker antoinebaker marked this pull request as ready for review June 27, 2025 10:14
Member

@ogrisel ogrisel left a comment


I haven't had the time to finish my review today, but this looks great: I tried running the notebook of https://github.com/snath-xoc/sample-weight-audit-nondet/ against this branch, and I confirm the statistical tests pass for RandomForestClassifier/Regressor and ExtraTreesClassifier/Regressor.

Member

@ogrisel ogrisel left a comment


More feedback.

antoinebaker and others added 2 commits July 8, 2025 17:10
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Member

@adam2392 adam2392 left a comment


Did a light pass. Mostly LGTM! Great work @antoinebaker

@lucyleeow
Member

@ogrisel do you want to weigh in on the max_sample behaviour: #28507

Not sure whether we should allow float to be >1 as well here, or deal with both float and int in a separate PR?

If we do change behaviour here, it may be worth mentioning it in the changelog. Note that you can only have one bullet point per changelog file, but you can have any number of nested bullet points.

@adrinjalali adrinjalali moved this from In progress to In progress - High Priority in Labs Jan 6, 2026
@ogrisel
Member

ogrisel commented Jan 6, 2026

I am fine with allowing max_samples > 1.0 (when float) but let's do that in a separate PR to make it easier to review (with a dedicated changelog entry and tests).

EDIT: I changed my mind after re-reading the discussion, see below.

Member

@ogrisel ogrisel left a comment


I did another pass, and LGTM.

@ogrisel
Member

ogrisel commented Jan 6, 2026

Actually, we probably need to add support for max_samples > 1.0 when float as part of this PR, to preserve consistency with max_samples > n_samples when passed as an integer.

At the moment, when calling:

RandomForestClassifier(max_samples=1.1).fit(X, y)

we get:

InvalidParameterError: The 'max_samples' parameter of RandomForestClassifier must be None, a float in the range (0.0, 1.0] or an int in the range [1, inf). Got 1.1 instead.

but the following does not raise anymore:

RandomForestClassifier(max_samples=int(1.1 * X.shape[0])).fit(X, y)

Which is weird, I agree. Let's make that behavior change consistent, and explicitly document and test it as part of this PR.

@lucyleeow
Member

Let's also mark as fixing #28507 !

@antoinebaker
Contributor Author

Thanks for the reviews @lucyleeow @adam2392 @ogrisel. max_samples > 1.0 is now supported and tested in test_max_samples_geq_one.

Member

@ogrisel ogrisel left a comment


Thanks @antoinebaker. LGTM besides the following nits.

antoinebaker and others added 3 commits January 12, 2026 15:52
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@ogrisel
Member

ogrisel commented Jan 14, 2026

@lucyleeow @adam2392 @cakedev0 ok for merge?

Contributor

@cakedev0 cakedev0 left a comment


I haven't reviewed the latest changes, but I read the discussions and commit messages and it looks good 👍

And for the rest, all good for me.

Member

@lucyleeow lucyleeow left a comment


Nits and a question, but LGTM!

sample_weight : array of shape (n_samples,) or None
Sample weights. The frequency semantics of :term:`sample_weight` is
guaranteed when `max_samples` is a float or integer, but not when
`max_samples` is None.
Member


Can we add back the line about "the effective bootstrap size is no longer guaranteed to be equivalent."?

Contributor Author


I put it in a dedicated Notes section, as it was getting a bit lengthy.

antoinebaker and others added 3 commits January 16, 2026 09:37
@ogrisel ogrisel merged commit ce1b377 into scikit-learn:main Jan 16, 2026
38 checks passed
@github-project-automation github-project-automation bot moved this from In progress - High Priority to Done in Labs Jan 16, 2026
@ogrisel
Member

ogrisel commented Jan 16, 2026

Merged! Thanks all for your work on getting this in!

dschult pushed a commit to dschult/scikit-learn that referenced this pull request Jan 25, 2026
…31529)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Lucy Liu <jliu176@gmail.com>
Co-authored-by: Arthur Lacote <arthur.lcte@gmail.com>
TejasAnalyst pushed a commit to TejasAnalyst/scikit-learn that referenced this pull request Feb 10, 2026
…31529)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Lucy Liu <jliu176@gmail.com>
Co-authored-by: Arthur Lacote <arthur.lcte@gmail.com>


Development

Successfully merging this pull request may close these issues.

Allow RandomForest* and ExtraTrees* to have a higher max_samples than 1.0 when bootstrap=True

7 participants