Is sample_weight missing when calling cross_val_predict in stacked model? #16537
Closed
Description
Describe the bug
When training a meta-classifier on the cross-validated folds, sample_weight is not passed to cross_val_predict via fit_params.
_BaseStacking fits every base estimator, and also the final (meta) estimator, with the sample_weight vector.
However, when it calls cross_val_predict to fit and predict the base models on folds of the input data, the sample weights are not passed along.
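For comparison, here is a minimal sketch of what forwarding the weights to cross_val_predict directly could look like. It is not the stacking implementation itself. RecordingDummyClassifier and seen_lengths are hypothetical names used only for illustration, and the keyword lookup is an assumption made so the snippet runs on both older releases (fit_params, as in 0.22.1) and newer ones (params).

```python
import inspect

import numpy as np
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import KFold, cross_val_predict

# Length of sample_weight seen by each fit call (None if it was not passed).
seen_lengths = []


class RecordingDummyClassifier(DummyClassifier):
    """Hypothetical helper: records whether fit received sample weights."""

    def fit(self, X, y, sample_weight=None):
        seen_lengths.append(None if sample_weight is None else len(sample_weight))
        return super().fit(X, y, sample_weight=sample_weight)


X, y = make_classification(
    n_samples=100, n_features=5, n_classes=4, n_informative=3, random_state=0
)
sample_weight = np.ones(X.shape[0])

# The keyword is `fit_params` in 0.22.1 (the version in this report); newer
# releases use `params`, so pick whichever the installed version accepts.
signature = inspect.signature(cross_val_predict)
keyword = "params" if "params" in signature.parameters else "fit_params"

cross_val_predict(
    RecordingDummyClassifier(),
    X,
    y,
    cv=KFold(n_splits=5),
    **{keyword: {"sample_weight": sample_weight}},
)
print(seen_lengths)
```

When the weights are forwarded this way, cross_val_predict slices them to match each training split, so every fold's fit call receives 80 weights (4/5 of the 100 samples) rather than None.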
Steps/Code to Reproduce
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import StackingClassifier


class PrintingDummyClassifier(DummyClassifier):
    def fit(self, X, y, sample_weight=None):
        # Report whether this fit call received sample weights.
        # `is not None` also works when the weights arrive as a NumPy array.
        if sample_weight is not None:
            print(len(sample_weight))
        else:
            print("Passed sample_weight=None")
        return super().fit(X, y, sample_weight=sample_weight)


X, y = make_classification(
    n_samples=26000, n_features=5, n_classes=4, n_informative=3, random_state=0
)
sample_weight = [1.0] * X.shape[0]
clf = StackingClassifier(
    estimators=[("dummy", PrintingDummyClassifier())],
    final_estimator=DummyClassifier(),
)
clf.fit(X, y, sample_weight)
Expected Results
I would expect to see the length of the sample weight vector printed on each invocation of the base estimator's fit method.
Actual Results
Sample weights are only passed when fitting the models on the entire data set. The models used as input to the final estimator are not fit using sample weights.
Program output:
26000
Passed sample_weight=None
Passed sample_weight=None
Passed sample_weight=None
Passed sample_weight=None
Passed sample_weight=None
26000
Versions
System:
    python: 3.6.9 (default, Jan 24 2020, 14:49:06) [GCC 4.2.1 Compatible Apple LLVM 11.0.0 (clang-1100.0.20.17)]
    executable: /Users/wderose/.pyenv/versions/3.6.9/envs/env/bin/python3.6
    machine: Darwin-18.7.0-x86_64-i386-64bit

Python dependencies:
    pip: 20.0.2
    setuptools: 40.6.2
    sklearn: 0.22.1
    numpy: 1.18.1
    scipy: 1.4.1
    Cython: None
    pandas: 0.25.3
    matplotlib: 3.1.2
    joblib: 0.14.0

Built with OpenMP: True
cc: @caioaao @glemaitre