MiniBatchKMeans not handling sample weights as expected #30750

@snath-xoc

Describe the bug

Following up on PR #29907, we realised that when sample weights are passed, any resampling should be done with weights and with replacement before the data is passed on to other operations.

MiniBatchKMeans has a similar bug: minibatch_indices are not resampled with weights. Instead, the weights are passed on to the subsequent minibatch_step. As a result, sample weight equivalence is not respected (i.e., repeating a sample n times and giving it a weight of n should behave the same, with similar outputs).
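To make the proposed behaviour concrete, here is a minimal sketch of drawing minibatch indices with replacement and with probability proportional to the sample weights. It is not scikit-learn's internal code: the helper name `sample_minibatch_indices` and its signature are hypothetical, and it assumes non-negative weights with a positive sum.

```python
import numpy as np


def sample_minibatch_indices(rng, sample_weight, batch_size):
    """Hypothetical helper: weighted resampling of minibatch indices.

    Indices are drawn with replacement, with probability proportional
    to sample_weight, so a sample with weight n is drawn about as often
    as a sample repeated n times.
    """
    sample_weight = np.asarray(sample_weight, dtype=float)
    probabilities = sample_weight / sample_weight.sum()
    return rng.choice(
        sample_weight.shape[0], size=batch_size, replace=True, p=probabilities
    )


rng = np.random.RandomState(0)
sample_weight = np.array([0, 1, 3, 6])
indices = sample_minibatch_indices(rng, sample_weight, batch_size=10_000)

# Empirical draw frequencies; these should approach
# sample_weight / sample_weight.sum() as batch_size grows.
frequencies = np.bincount(indices, minlength=sample_weight.shape[0]) / indices.shape[0]
print(frequencies)
```

Under this scheme the weights are consumed by the resampling itself, so the step that follows would presumably treat the drawn samples as unweighted.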

Steps/Code to Reproduce

from sklearn.cluster import MiniBatchKMeans, KMeans

import matplotlib.pyplot as plt
from scipy.stats import kstest
from sklearn.datasets import make_blobs
import numpy as np

rng = np.random.RandomState(0)
    
centres = np.array([[0, 0, 0], [0, 5, 5], [3, 1, 1], [2, 4, 4], [100, 8, 800]])
X, y = make_blobs(
    n_samples=300,
    cluster_std=1,
    centers=centres,
    random_state=10,
)
# Create dataset with repetitions and corresponding sample weights
sample_weight = rng.randint(0, 10, size=X.shape[0])
X_resampled_by_weights = np.repeat(X, sample_weight, axis=0)
y_resampled_by_weights = np.repeat(y,sample_weight)

predictions_sw = []
predictions_dup = []
predictions_sw_mini = []
predictions_dup_mini = []

prediction_rank = np.argsort(y)[-1:]

for seed in range(100):

    ## Fit estimator
    est_sw = KMeans(random_state=seed,n_clusters=5).fit(X,y,sample_weight=sample_weight)
    est_dup = KMeans(random_state=seed,n_clusters=5).fit(X_resampled_by_weights,y_resampled_by_weights)
    est_sw_mini = MiniBatchKMeans(random_state=seed,n_clusters=5).fit(X,y,sample_weight=sample_weight)
    est_dup_mini = MiniBatchKMeans(random_state=seed,n_clusters=5).fit(X_resampled_by_weights,y_resampled_by_weights)
    
    ##Get predictions
    predictions_sw.append(est_sw.predict(X[prediction_rank]))
    predictions_dup.append(est_dup.predict(X[prediction_rank]))
    predictions_sw_mini.append(est_sw_mini.predict(X[prediction_rank]))
    predictions_dup_mini.append(est_dup_mini.predict(X[prediction_rank]))

fig = plt.figure(figsize=(10,5))
ax1=fig.add_subplot(1,2,1)
ax2=fig.add_subplot(1,2,2)

predictions_sw = np.asarray(predictions_sw).flatten()
predictions_dup = np.asarray(predictions_dup).flatten()
ax1.hist(predictions_sw)
ax1.hist(predictions_dup,alpha=0.5)
ax1.set_title("KMeans: %.2f"%(kstest(predictions_sw,predictions_dup).pvalue))

predictions_sw_mini = np.asarray(predictions_sw_mini).flatten()
predictions_dup_mini = np.asarray(predictions_dup_mini).flatten()
ax2.hist(predictions_sw_mini,label="weighted")
ax2.hist(predictions_dup_mini,label="repeated",alpha=0.5)
ax2.set_title("MiniBatchKMeans: %.2f"%(kstest(predictions_sw_mini,predictions_dup_mini).pvalue))
plt.legend()

Expected Results

For both KMeans and MiniBatchKMeans, fitting with sample weights and fitting on the repeated dataset should return similar prediction histograms.
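The equivalence being tested can be illustrated in its simplest deterministic form. This is a sketch on synthetic data, not part of the original report: a weighted mean computed with integer weights should agree, up to floating-point error, with the plain mean of the dataset in which each row is repeated according to its weight.

```python
import numpy as np

rng = np.random.RandomState(0)
X = rng.normal(size=(20, 3))
sample_weight = rng.randint(0, 10, size=X.shape[0])

# Weighted statistic on the original data.
weighted_mean = np.average(X, axis=0, weights=sample_weight)

# Same statistic on the dataset with rows repeated by their weights.
repeated_mean = np.repeat(X, sample_weight, axis=0).mean(axis=0)

print(np.allclose(weighted_mean, repeated_mean))
```

For a stochastic estimator such as MiniBatchKMeans, the two fits cannot be compared exactly, which is why the reproduction compares the distributions of predictions over many seeds instead.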

Actual Results

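[Image: side-by-side histograms of weighted vs. repeated predictions for KMeans (left) and MiniBatchKMeans (right), with the KS-test p-value in each title]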

Versions

System:
    python: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ]
executable: /Users/shrutinath/micromamba/envs/scikit-learn/bin/python
   machine: macOS-14.3-arm64-arm-64bit

Python dependencies:
      sklearn: 1.7.dev0
          pip: 24.0
   setuptools: 75.8.0
        numpy: 2.0.0
        scipy: 1.14.0
       Cython: 3.0.10
       pandas: 2.2.2
   matplotlib: 3.9.0
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libopenblas
       filepath: /Users/shrutinath/micromamba/envs/scikit-learn/lib/libopenblas.0.dylib
        version: 0.3.27
threading_layer: openmp
   architecture: VORTEX

       user_api: openmp
   internal_api: openmp
    num_threads: 8
         prefix: libomp
       filepath: /Users/shrutinath/micromamba/envs/scikit-learn/lib/libomp.dylib
        version: None
