Closed
Description
Describe the bug
_BinMapper in sklearn.ensemble._hist_gradient_boosting does not accept sample weights as input. As a result, the bin thresholds computed for a weighted dataset do not match those computed for the equivalent dataset in which each sample is repeated according to its weight. Linked to issue #27117.
Steps/Code to Reproduce
from sklearn.ensemble._hist_gradient_boosting import binning
from sklearn.datasets import make_regression
import numpy as np
n_samples = 50
n_features = 2
rng = np.random.RandomState(42)
X, y = make_regression(
    n_samples=n_samples,
    n_features=n_features,
    n_informative=n_features,
    random_state=0,
)
# Create dataset with repetitions and corresponding sample weights
sample_weight = rng.randint(0, 10, size=X.shape[0])
X_resampled_by_weights = np.repeat(X, sample_weight, axis=0)
# _BinMapper.fit takes no sample_weight argument, so the weights cannot be passed here
bins_fit_weighted = binning._BinMapper(255).fit(X)
bins_fit_resampled = binning._BinMapper(255).fit(X_resampled_by_weights)
np.testing.assert_allclose(bins_fit_resampled.bin_thresholds_, bins_fit_weighted.bin_thresholds_)
Expected Results
No error is thrown: both fits produce the same bin thresholds.
Actual Results
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0
(shapes (2, 47), (2, 49) mismatch)
ACTUAL: array([[-2.12963 , -1.668234, -1.622048, -1.433347, -1.208973, -1.117951,
-1.059653, -0.977926, -0.901382, -0.891626, -0.879291, -0.841972,
-0.742803, -0.653391, -0.572564, -0.510229, -0.456415, -0.395252,...
DESIRED: array([[-2.12963 , -1.668234, -1.622048, -1.433347, -1.208973, -1.117951,
-1.059653, -0.977926, -0.901382, -0.891626, -0.879291, -0.841972,
-0.742803, -0.653391, -0.572564, -0.510229, -0.456415, -0.395252,...
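The shape mismatch reported above, (2, 47) versus (2, 49), is consistent with thresholds being placed at midpoints between distinct feature values when a feature has fewer distinct values than bins. Samples drawn with weight 0 vanish from the repeated dataset but still count in the unweighted fit. The following is a minimal NumPy-only sketch of that reading, not scikit-learn's implementation: midpoint_thresholds and its sample_weight parameter are hypothetical names introduced for illustration, and the midpoint rule is an assumption about the small-data case.

```python
import numpy as np


def midpoint_thresholds(col, sample_weight=None):
    """Hypothetical helper (not a scikit-learn API).

    Returns the midpoints between consecutive distinct values of `col`,
    ignoring samples whose weight is zero when `sample_weight` is given.
    """
    col = np.asarray(col, dtype=float)
    if sample_weight is not None:
        # A zero-weight sample contributes nothing, exactly as if it had
        # been repeated zero times.
        col = col[np.asarray(sample_weight) > 0]
    distinct = np.unique(col)
    return 0.5 * (distinct[:-1] + distinct[1:])


rng = np.random.RandomState(42)
x = rng.normal(size=50)
w = rng.randint(0, 10, size=x.shape[0])  # integer weights, zeros allowed
x_repeated = np.repeat(x, w)

unweighted = midpoint_thresholds(x)                 # weights ignored
weighted = midpoint_thresholds(x, sample_weight=w)  # zero weights dropped
repeated = midpoint_thresholds(x_repeated)          # explicit repetition

# The weight-aware thresholds agree with the repeated-sample thresholds.
np.testing.assert_allclose(weighted, repeated)
print(len(unweighted), len(weighted), len(repeated))
```

In this sketch, dropping zero-weight samples is enough to recover the repeated-sample thresholds. Features with more distinct values than bins would presumably need weighted quantiles instead, which the sketch does not cover.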
Versions
System:
python: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ]
executable: /Users/shrutinath/micromamba/envs/scikit-learn/bin/python
machine: macOS-14.3-arm64-arm-64bit
Python dependencies:
sklearn: 1.6.dev0
pip: 24.0
setuptools: 70.1.1
numpy: 2.0.0
scipy: 1.14.0
Cython: 3.0.10
pandas: None
matplotlib: 3.9.0
joblib: 1.4.2
threadpoolctl: 3.5.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: openblas
num_threads: 8
prefix: libopenblas
...
num_threads: 8
prefix: libomp
filepath: /Users/shrutinath/micromamba/envs/scikit-learn/lib/libomp.dylib
version: None