
Is sample_weight missing when calling cross_val_predict in stacked model? #16537

@wderose

Describe the bug

When training a meta-classifier on the cross-validated folds, sample_weight is not passed to cross_val_predict via fit_params.

_BaseStacking fits all base estimators with the sample_weight vector. _BaseStacking also fits the final/meta-estimator with the sample_weight vector.

When we call cross_val_predict to fit and predict the base models on folds of the input data, the sample weights are not passed.
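To make the expected behavior concrete, here is a minimal, dependency-free sketch of cross-validated prediction in which each fold's fit receives the matching slice of the weight vector. It is not scikit-learn's implementation: `RecordingEstimator`, `k_fold_indices`, and `weighted_cross_val_predict` are hypothetical names invented for illustration, and the contiguous, unshuffled split is an assumption. In scikit-learn 0.22 the analogous change would presumably be to forward the weights through the `fit_params` argument of `cross_val_predict`, as suggested above.

```python
from typing import List, Optional, Tuple


class RecordingEstimator:
    """Hypothetical stand-in for a base estimator; records the weights it receives."""

    def __init__(self) -> None:
        self.n_weights_seen: Optional[int] = None

    def fit(self, X, y, sample_weight=None):
        # Explicit None check: truth-testing an array-like can be ambiguous.
        self.n_weights_seen = None if sample_weight is None else len(sample_weight)
        return self

    def predict(self, X):
        # Constant prediction, which is enough for this illustration.
        return [0 for _ in X]


def k_fold_indices(n_samples: int, n_splits: int) -> List[Tuple[List[int], List[int]]]:
    """Contiguous, unshuffled k-fold split (hypothetical helper)."""
    folds = []
    start = 0
    for i in range(n_splits):
        # The first (n_samples % n_splits) folds get one extra sample.
        size = n_samples // n_splits + (1 if i < n_samples % n_splits else 0)
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        folds.append((train, test))
        start += size
    return folds


def weighted_cross_val_predict(make_estimator, X, y, sample_weight=None, n_splits=5):
    """Out-of-fold predictions; each fold's fit gets its slice of sample_weight."""
    predictions = [None] * len(X)
    weights_seen = []
    for train, test in k_fold_indices(len(X), n_splits):
        est = make_estimator()
        # Slice the weights with the same training indices used for X and y.
        fold_weight = None if sample_weight is None else [sample_weight[i] for i in train]
        est.fit([X[i] for i in train], [y[i] for i in train], sample_weight=fold_weight)
        weights_seen.append(est.n_weights_seen)
        for i, pred in zip(test, est.predict([X[i] for i in test])):
            predictions[i] = pred
    return predictions, weights_seen


X = [[float(i)] for i in range(10)]
y = [i % 2 for i in range(10)]
sample_weight = [1.0] * len(X)

preds, seen = weighted_cross_val_predict(RecordingEstimator, X, y, sample_weight)
print(seen)  # one entry per fold: the number of weights that fold's fit received
```

With weights forwarded this way, every per-fold fit sees a weight vector as long as its training split, rather than `None`.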

Steps/Code to Reproduce

from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import StackingClassifier


class PrintingDummyClassifier(DummyClassifier):
    def fit(self, X, y, sample_weight=None):
        # Compare against None explicitly so array-like weights also work.
        if sample_weight is not None:
            print(len(sample_weight))
        else:
            print("Passed sample_weight=None")
        return super().fit(X, y, sample_weight=sample_weight)


X, y = make_classification(
    n_samples=26000, n_features=5, n_classes=4, n_informative=3, random_state=0
)
sample_weight = [1.0] * X.shape[0]
clf = StackingClassifier(estimators=[("dummy", PrintingDummyClassifier())], final_estimator=DummyClassifier())
clf.fit(X, y, sample_weight)

Expected Results

I would expect to see the length of the sample weight vector printed on each invocation of the base estimator's fit method.

Actual Results

Sample weights are only passed when fitting the models on the entire data set. The cross-validated models whose predictions are used as input to the final estimator are fit without sample weights.

Program output:

26000
Passed sample_weight=None
Passed sample_weight=None
Passed sample_weight=None
Passed sample_weight=None
Passed sample_weight=None
26000

Versions

System:
    python: 3.6.9 (default, Jan 24 2020, 14:49:06)  [GCC 4.2.1 Compatible Apple LLVM 11.0.0 (clang-1100.0.20.17)]
executable: /Users/wderose/.pyenv/versions/3.6.9/envs/env/bin/python3.6
   machine: Darwin-18.7.0-x86_64-i386-64bit

Python dependencies:
       pip: 20.0.2
setuptools: 40.6.2
   sklearn: 0.22.1
     numpy: 1.18.1
     scipy: 1.4.1
    Cython: None
    pandas: 0.25.3
matplotlib: 3.1.2
    joblib: 0.14.0

Built with OpenMP: True

cc: @caioaao @glemaitre
