BinMapper within HGBT does not handle sample weights #29640

@snath-xoc

Describe the bug

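_BinMapper in sklearn.ensemble._hist_gradient_boosting.binning does not accept sample weights in fit. As a result, the bin thresholds computed for a weighted dataset cannot match those computed for the equivalent dataset in which each sample is repeated according to its weight. Linked to issue #27117.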
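To make the expected weighted/repeated equivalence concrete, here is a minimal sketch of a weight-aware quantile that agrees with the quantile of the repeated data. It uses an inverted-CDF rule for simplicity, which is an assumption chosen for illustration and not necessarily the interpolation rule used by the binning code. Both `weighted_quantile` and `repeated_quantile` are hypothetical helpers, not scikit-learn API.

```python
import numpy as np


def weighted_quantile(values, weights, q):
    """Hypothetical helper: inverted-CDF quantile of `values` under `weights`."""
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    # Zero-weight samples must not influence the result.
    keep = weights > 0
    values, weights = values[keep], weights[keep]
    order = np.argsort(values)
    values, weights = values[order], weights[order]
    cum_weights = np.cumsum(weights)
    # Smallest value whose cumulative weight reaches q * total weight.
    idx = np.searchsorted(cum_weights, q * cum_weights[-1], side="left")
    return values[min(idx, len(values) - 1)]


def repeated_quantile(values, counts, q):
    """Hypothetical reference: same rule applied to explicitly repeated data."""
    repeated = np.sort(np.repeat(values, counts))
    k = max(int(np.ceil(q * len(repeated))) - 1, 0)
    return repeated[k]


values = np.array([3.0, 1.0, 2.0, 5.0, 4.0])
counts = np.array([2, 0, 3, 1, 4])  # integer weights, one of them zero

for q in (0.0, 0.25, 0.5, 0.75, 1.0):
    assert weighted_quantile(values, counts, q) == repeated_quantile(values, counts, q)
```

A weight-aware _BinMapper would need to satisfy a property of this kind for every feature, whatever quantile rule it ends up using.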

Steps/Code to Reproduce

import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble._hist_gradient_boosting import binning

n_samples = 50
n_features = 2
rng = np.random.RandomState(42)

X, y = make_regression(
    n_samples=n_samples,
    n_features=n_features,
    n_informative=n_features,
    random_state=0,
)

# Create dataset with repetitions and corresponding sample weights
sample_weight = rng.randint(0, 10, size=X.shape[0])
X_resampled_by_weights = np.repeat(X, sample_weight, axis=0)

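# _BinMapper.fit takes no sample_weight argument, so the "weighted" fit
# can only be run on the unweighted X.
bins_fit_weighted = binning._BinMapper(255).fit(X)
bins_fit_resampled = binning._BinMapper(255).fit(X_resampled_by_weights)

np.testing.assert_allclose(
    bins_fit_resampled.bin_thresholds_, bins_fit_weighted.bin_thresholds_
)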

Expected Results

No error is thrown: both fits produce the same bin thresholds.

Actual Results

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0

(shapes (2, 47), (2, 49) mismatch)
 ACTUAL: array([[-2.12963 , -1.668234, -1.622048, -1.433347, -1.208973, -1.117951,
        -1.059653, -0.977926, -0.901382, -0.891626, -0.879291, -0.841972,
        -0.742803, -0.653391, -0.572564, -0.510229, -0.456415, -0.395252,...
 DESIRED: array([[-2.12963 , -1.668234, -1.622048, -1.433347, -1.208973, -1.117951,
        -1.059653, -0.977926, -0.901382, -0.891626, -0.879291, -0.841972,
        -0.742803, -0.653391, -0.572564, -0.510229, -0.456415, -0.395252,...
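
The shape mismatch (47 versus 49 thresholds per feature) looks consistent with two samples having drawn a weight of zero, since `rng.randint(0, 10, ...)` can return 0 and `np.repeat` then drops those rows. A minimal sketch of that effect follows, assuming that when a feature has fewer distinct values than bins the thresholds are the midpoints between consecutive distinct values. That assumption reflects a reading of the binning code rather than a confirmed description, and `midpoint_thresholds` is a hypothetical stand-in, not scikit-learn API.

```python
import numpy as np


def midpoint_thresholds(col):
    """Hypothetical stand-in: midpoints between consecutive distinct values."""
    distinct = np.unique(col)
    return (distinct[:-1] + distinct[1:]) * 0.5


rng = np.random.RandomState(0)
col = rng.normal(size=10)
weights = np.array([0, 3, 1, 2, 0, 1, 4, 1, 2, 1])  # two zero weights

unweighted = midpoint_thresholds(col)                    # sees all 10 values
repeated = midpoint_thresholds(np.repeat(col, weights))  # zero-weight rows are gone
masked = midpoint_thresholds(col[weights > 0])           # drop zero-weight rows only

print(len(unweighted), len(repeated), len(masked))
```

Under this assumption, masking out zero-weight samples would be enough to match the repeated dataset in the small-data case. With more distinct values than bins, the thresholds come from percentiles, and matching the repeated dataset would presumably require weighted percentiles.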

Versions

System:
    python: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ]
executable: /Users/shrutinath/micromamba/envs/scikit-learn/bin/python
   machine: macOS-14.3-arm64-arm-64bit

Python dependencies:
      sklearn: 1.6.dev0
          pip: 24.0
   setuptools: 70.1.1
        numpy: 2.0.0
        scipy: 1.14.0
       Cython: 3.0.10
       pandas: None
   matplotlib: 3.9.0
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libopenblas
...
    num_threads: 8
         prefix: libomp
       filepath: /Users/shrutinath/micromamba/envs/scikit-learn/lib/libomp.dylib
        version: None
