Describe the bug
Following up on PR #29907, we realised that when sample weights are passed, any resampling should be done with those weights and with replacement before the data is passed on to other operations.
MiniBatchKMeans has a similar bug: minibatch_indices are not resampled according to the sample weights; instead, the weights are passed straight through to the subsequent minibatch_step. As a result, sample-weight equivalence is not respected (i.e., weighting a sample by n should behave the same as repeating that sample n times).
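A minimal sketch of the difference (not scikit-learn's actual implementation; names like `uniform_indices` and `weighted_indices` are illustrative): today the minibatch indices are drawn uniformly and the weights are only applied downstream, whereas weight-aware resampling would draw indices with replacement and with probability proportional to `sample_weight`, so a sample of weight n is selected as often as n duplicated copies would be.

```python
import numpy as np

rng = np.random.RandomState(0)
n_samples = 10
sample_weight = rng.randint(1, 10, size=n_samples).astype(float)

# Current behaviour (simplified): minibatch indices drawn uniformly,
# weights only applied in the subsequent update step.
uniform_indices = rng.randint(0, n_samples, size=5)

# Weight-aware alternative: draw indices with replacement, with probability
# proportional to sample_weight, mimicking a dataset where each sample is
# repeated sample_weight[i] times.
p = sample_weight / sample_weight.sum()
weighted_indices = rng.choice(n_samples, size=5, replace=True, p=p)
```

Under the weighted draw, the expected selection frequency of each sample matches its relative weight, which is exactly the equivalence the repro below checks.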
Steps/Code to Reproduce
from sklearn.cluster import MiniBatchKMeans, KMeans
import matplotlib.pyplot as plt
from scipy.stats import kstest
from sklearn.datasets import make_blobs
import numpy as np
rng = np.random.RandomState(0)
centres = np.array([[0, 0, 0], [0, 5, 5], [3, 1, 1], [2, 4, 4], [100, 8, 800]])
X, y = make_blobs(
    n_samples=300,
    cluster_std=1,
    centers=centres,
    random_state=10,
)
# Create dataset with repetitions and corresponding sample weights
sample_weight = rng.randint(0, 10, size=X.shape[0])
X_resampled_by_weights = np.repeat(X, sample_weight, axis=0)
y_resampled_by_weights = np.repeat(y, sample_weight)
predictions_sw = []
predictions_dup = []
predictions_sw_mini = []
predictions_dup_mini = []
prediction_rank = np.argsort(y)[-1:]
for seed in range(100):
    ## Fit estimators
    est_sw = KMeans(random_state=seed, n_clusters=5).fit(X, y, sample_weight=sample_weight)
    est_dup = KMeans(random_state=seed, n_clusters=5).fit(X_resampled_by_weights, y_resampled_by_weights)
    est_sw_mini = MiniBatchKMeans(random_state=seed, n_clusters=5).fit(X, y, sample_weight=sample_weight)
    est_dup_mini = MiniBatchKMeans(random_state=seed, n_clusters=5).fit(X_resampled_by_weights, y_resampled_by_weights)
    ## Get predictions
    predictions_sw.append(est_sw.predict(X[prediction_rank]))
    predictions_dup.append(est_dup.predict(X[prediction_rank]))
    predictions_sw_mini.append(est_sw_mini.predict(X[prediction_rank]))
    predictions_dup_mini.append(est_dup_mini.predict(X[prediction_rank]))
fig = plt.figure(figsize=(10, 5))
ax1 = fig.add_subplot(1, 2, 1)
ax2 = fig.add_subplot(1, 2, 2)
predictions_sw = np.asarray(predictions_sw).flatten()
predictions_dup = np.asarray(predictions_dup).flatten()
ax1.hist(predictions_sw)
ax1.hist(predictions_dup, alpha=0.5)
ax1.set_title("KMeans: %.2f" % (kstest(predictions_sw, predictions_dup).pvalue))
predictions_sw_mini = np.asarray(predictions_sw_mini).flatten()
predictions_dup_mini = np.asarray(predictions_dup_mini).flatten()
ax2.hist(predictions_sw_mini, label="weighted")
ax2.hist(predictions_dup_mini, label="repeated", alpha=0.5)
ax2.set_title("MiniBatchKMeans: %.2f" % (kstest(predictions_sw_mini, predictions_dup_mini).pvalue))
plt.legend()
Expected Results
KMeans and MiniBatchKMeans produce similar histograms for the weighted and repeated datasets (high KS-test p-values in both panels).
Actual Results
Versions
System:
python: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ]
executable: /Users/shrutinath/micromamba/envs/scikit-learn/bin/python
machine: macOS-14.3-arm64-arm-64bit
Python dependencies:
sklearn: 1.7.dev0
pip: 24.0
setuptools: 75.8.0
numpy: 2.0.0
scipy: 1.14.0
Cython: 3.0.10
pandas: 2.2.2
matplotlib: 3.9.0
joblib: 1.4.2
threadpoolctl: 3.5.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: openblas
num_threads: 8
prefix: libopenblas
filepath: /Users/shrutinath/micromamba/envs/scikit-learn/lib/libopenblas.0.dylib
version: 0.3.27
threading_layer: openmp
architecture: VORTEX
user_api: openmp
internal_api: openmp
num_threads: 8
prefix: libomp
filepath: /Users/shrutinath/micromamba/envs/scikit-learn/lib/libomp.dylib
version: None
