55# Authors: The scikit-learn developers
66# SPDX-License-Identifier: BSD-3-Clause
77
8+ import re
89from itertools import cycle , product
910
1011import joblib
@@ -696,6 +697,37 @@ def test_warning_bootstrap_sample_weight():
696697 reg .fit (X , y , sample_weight = sample_weight )
697698
698699
700+ def test_invalid_sample_weight_max_samples_bootstrap_combinations ():
701+ X , y = iris .data , iris .target
702+
703+ # Case 1: small weights and fractional max_samples would lead to sampling
704+ # less than 1 sample, which is not allowed.
705+ clf = BaggingClassifier (max_samples = 1.0 )
706+ sample_weight = np .ones_like (y ) / (2 * len (y ))
707+ expected_msg = (
708+ r"The total sum of sample weights is 0.5(\d*), which prevents resampling with "
709+ r"a fractional value for max_samples=1\.0\. Either pass max_samples as an "
710+ r"integer or use a larger sample_weight\."
711+ )
712+ with pytest .raises (ValueError , match = expected_msg ):
713+ clf .fit (X , y , sample_weight = sample_weight )
714+
715+ # Case 2: large weights and bootstrap=False would lead to sampling without
716+ # replacement more than the number of samples, which is not allowed.
717+ clf = BaggingClassifier (bootstrap = False , max_samples = 1.0 )
718+ sample_weight = np .ones_like (y )
719+ sample_weight [- 1 ] = 2
720+ expected_msg = re .escape (
721+ "max_samples=151 must be <= n_samples=150 to be able to sample without "
722+ "replacement."
723+ )
724+ with pytest .raises (ValueError , match = expected_msg ):
725+ with pytest .warns (
726+ UserWarning , match = "When fitting BaggingClassifier with sample_weight"
727+ ):
728+ clf .fit (X , y , sample_weight = sample_weight )
729+
730+
699731class EstimatorAcceptingSampleWeight (BaseEstimator ):
700732 """Fake estimator accepting sample_weight"""
701733
@@ -724,8 +756,9 @@ def predict(self, X):
724756@pytest .mark .parametrize ("bagging_class" , [BaggingRegressor , BaggingClassifier ])
725757@pytest .mark .parametrize ("accept_sample_weight" , [False , True ])
726758@pytest .mark .parametrize ("metadata_routing" , [False , True ])
759+ @pytest .mark .parametrize ("max_samples" , [10 , 0.8 ])
727760def test_draw_indices_using_sample_weight (
728- bagging_class , accept_sample_weight , metadata_routing
761+ bagging_class , accept_sample_weight , metadata_routing , max_samples
729762):
730763 X = np .arange (100 ).reshape (- 1 , 1 )
731764 y = np .repeat ([0 , 1 ], 50 )
@@ -739,7 +772,15 @@ def test_draw_indices_using_sample_weight(
739772 base_estimator = EstimatorRejectingSampleWeight ()
740773
741774 n_samples , n_features = X .shape
742- max_samples = 10
775+
776+ if isinstance (max_samples , float ):
777+ # max_samples passed as a fraction of the input data. Since
778+ # sample_weight are provided, the effective number of samples is the
779+ # sum of the sample weights.
780+ expected_integer_max_samples = int (max_samples * sample_weight .sum ())
781+ else :
782+ expected_integer_max_samples = max_samples
783+
743784 with config_context (enable_metadata_routing = metadata_routing ):
744785 # TODO(slep006): remove block when default routing is implemented
745786 if metadata_routing and accept_sample_weight :
@@ -748,7 +789,7 @@ def test_draw_indices_using_sample_weight(
748789 bagging .fit (X , y , sample_weight = sample_weight )
749790 for estimator , samples in zip (bagging .estimators_ , bagging .estimators_samples_ ):
750791 counts = np .bincount (samples , minlength = n_samples )
751- assert sum (counts ) == len (samples ) == max_samples
792+ assert sum (counts ) == len (samples ) == expected_integer_max_samples
752793 # only indices 4 and 5 should appear
753794 assert np .isin (samples , [4 , 5 ]).all ()
754795 if accept_sample_weight :
@@ -760,8 +801,8 @@ def test_draw_indices_using_sample_weight(
760801 assert_allclose (estimator .sample_weight_ , counts )
761802 else :
762803 # sampled indices represented through indexing
763- assert estimator .X_ .shape == (max_samples , n_features )
764- assert estimator .y_ .shape == (max_samples ,)
804+ assert estimator .X_ .shape == (expected_integer_max_samples , n_features )
805+ assert estimator .y_ .shape == (expected_integer_max_samples ,)
765806 assert_allclose (estimator .X_ , X [samples ])
766807 assert_allclose (estimator .y_ , y [samples ])
767808
0 commit comments