
Fix MinibatchKMeans minibatch_indices creation #30751

Merged
jeremiedbb merged 40 commits into scikit-learn:main from snath-xoc:fix_minibatchkmeans
Feb 6, 2026

Conversation

@snath-xoc
Contributor

@snath-xoc snath-xoc commented Feb 2, 2025

Reference Issues/PRs

Tries (although not successfully) to fix #30750

What does this implement/fix? Explain your changes.

When creating minibatch_indices before the mini_batch_step we employ weighted resampling (with replacement)
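A minimal numpy sketch of weighted resampling with replacement along these lines (the names here are illustrative, not the actual scikit-learn internals):

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, batch_size = 10, 4
sample_weight = np.array([1.0, 1, 1, 1, 1, 1, 1, 1, 1, 5])

# Draw minibatch indices with probability proportional to sample_weight,
# with replacement: heavily weighted points are drawn more often, so each
# drawn point can then be treated as having unit weight inside the batch.
p = sample_weight / sample_weight.sum()
minibatch_indices = rng.choice(n_samples, size=batch_size, replace=True, p=p)
```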

Any other comments?

This does not solve the issue: I am still getting histograms similar to those shown in the issue, even when using init="random". I did not change how sample weights are passed into the mini_batch_step, so currently they are double-counted. This is probably a problem, although I see that the sample weight is used in the _minibatch_update_dense function. Any further thoughts on this would help.

TO DO:

  • Sample weights are double-counted, since they are passed on to the mini-batch step after already being used to select the mini-batch indices. Needs further discussion to see whether we can leave them out of the mini-batch step altogether.
  • I had to add dummy sample weights of ones to the minibatch step, otherwise I was getting errors and exits during testing. It turns out that _check_sample_weight returns an array of ones with the X dtype when sample_weight is None. Please check whether the current implementation makes sense now.
  • The test is still not returning results similar to KMeans; with init="random" both methods return results that do not seem to respect sample weight equivalence.
  • test_scaled_weights is broken now and needs fixing.

@github-actions

github-actions bot commented Feb 2, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 980f146. Link to the linter CI: here

@ogrisel
Member

ogrisel commented Feb 3, 2025

I did not change the sample weight passing into the mini_batch_step, so currently they are double accounted for.

I agree, let's not reuse the weights in the computation of the step as they are already used to sample the minibatch. Let's remove this and simplify the code of _minibatch_step accordingly.

Member

@jeremiedbb jeremiedbb left a comment


This slightly changes the behavior of MiniBatchKMeans, even without using sample weights, so it needs 2 changelog entries: one for the bug fix and one for the change of behavior.

Some doctests fail because of the change. You just need to update the expected results with the new values that these snippets produce.

Comment on lines +2149 to +2152
# Note, I am not sure how sample weights are used here
# So left it in, it seems like the weight sums are updated using
# sample weights so need some help here to understand the
# _minibatch_update_dense/sparse code
Member


Like in KBinsDiscretizer, we should not use sample weights after using them for weighted sampling.
So let's pass an array of ones as sample weights here.

I am not sure how sample weights are used here
So left it in, it seems like the weight sums are updated using
sample weights so need some help here to understand

weight_sum is the sum of the weights of all points belonging to each cluster. It's used to track clusters with very few points (more precisely, points that add up to a small total weight) and reassign them to a different cluster.
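For illustration, this per-cluster weight tracking can be sketched with np.bincount (a simplified stand-in for what the _minibatch_update_dense/sparse code maintains incrementally):

```python
import numpy as np

# Toy assignment of 6 points to 3 clusters, with per-point weights.
labels = np.array([0, 0, 1, 2, 2, 2])
sample_weight = np.array([1.0, 2.0, 0.5, 1.0, 1.0, 1.0])

# weight_sum[k] is the total weight of the points assigned to cluster k;
# clusters whose weight_sum stays very small are reassignment candidates.
weight_sum = np.bincount(labels, weights=sample_weight, minlength=3)
print(weight_sum)  # [3.  0.5 3. ]
```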

I don't think we have to modify _mini_batch_step. It's still useful that it handles sample weights because it's also used in partial_fit where there is no sampling so sample weights must be passed.

Member


I don't think we have to modify _mini_batch_step. It's still useful that it handles sample weights because it's also used in partial_fit where there is no sampling so sample weights must be passed.

That's a very good point indeed.

Member

@jeremiedbb jeremiedbb left a comment


I think that the convergence check is also broken (_mini_batch_convergence). It uses n_samples instead of sample_weight.sum(). Let's first make sure that the rest is fixed. You can disable convergence check by setting max_no_improvement=None and tol=0.

Then, when we are confident that sample weights are correctly handled by the core of the algorithm, we'll enable the early convergence check again and fix it.
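Disabling the early convergence check as suggested looks like this (standard MiniBatchKMeans parameters; the data is just a toy example):

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans

X = np.random.default_rng(0).normal(size=(500, 2))

# max_no_improvement=None disables the no-improvement early-stopping
# heuristic; tol=0 disables the center-change tolerance check, so every
# scheduled minibatch step is executed.
km = MiniBatchKMeans(
    n_clusters=3, max_no_improvement=None, tol=0, n_init=1, random_state=0
).fit(X)
```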

snath-xoc and others added 3 commits February 7, 2025 01:43
@ogrisel
Member

ogrisel commented Feb 7, 2025

I think that the convergence check is also broken (_mini_batch_convergence). It uses n_samples instead of sample_weight.sum(). Let's first make sure that the rest is fixed. You can disable convergence check by setting max_no_improvement=None and tol=0.

Indeed, it depends both on n_samples and self._batch_size. I wonder if the value of self._batch_size should be rescaled by X.shape[0] / sample_weight.sum() before we start the iteration.

@jeremiedbb
Member

jeremiedbb commented Feb 7, 2025

I wonder if the value of self._batch_size should be rescaled

I don't think so: since we're sampling with weights, the self._batch_size sampled points have unit weight, so their total weight is self._batch_size.

EDIT:

X.shape[0] / sample_weight.sum() before we start the iteration

I read your comment too quickly. I don't think so either: since we're doing weighted sampling, the sampled points have unit weights, so we need the same batch size to be equivalent to the repeated case.
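The equivalence being argued here can be sketched numerically: repeating a point w times and drawing it with probability proportional to w both yield batches whose total weight equals the batch size (toy data, integer weights for the repeated case):

```python
import numpy as np

rng = np.random.default_rng(0)
X = np.array([[0.0], [1.0], [2.0]])
sample_weight = np.array([1, 1, 3])

# Repeated case: duplicate each point according to its integer weight.
# Every row of X_rep has unit weight, so total weight == number of rows.
X_rep = np.repeat(X, sample_weight, axis=0)

# Weighted-sampling case: draw from the original points with probability
# proportional to their weight; each drawn point also has unit weight, so
# a batch of size B carries total weight B, as in the repeated case.
batch_size = 4
idx = rng.choice(len(X), size=batch_size, p=sample_weight / sample_weight.sum())
batch = X[idx]
```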

@ogrisel
Member

ogrisel commented Feb 7, 2025

I tried to run our statistical testing notebook against this branch and the test passes (while it fails on main)!

✅ MiniBatchKMeans: (min_p_value: 0.239, mean_p_value=0.724)

EDIT: using the following config:

    MiniBatchKMeans: {"reassignment_ratio": 0.9},

@ogrisel
Member

ogrisel commented Feb 7, 2025

But surprisingly, it fails with:

❌ MiniBatchKMeans: (min_p_value: 0.001, mean_p_value=0.369)

with

    MiniBatchKMeans: {"max_no_improvement":None, "tol": 0},

or:

❌ MiniBatchKMeans: (min_p_value: 0.016, mean_p_value=0.616)

with the default hparams (default convergence criterion and default reassignment_ratio=0.01).

but it passes (barely) with:

✅ MiniBatchKMeans: (min_p_value: 0.071, mean_p_value=0.537)

with

    MiniBatchKMeans: {"max_no_improvement":None, "tol": 0, "reassignment_ratio": 0.9},

so there might still be a problem only visible with lower values of "reassignment_ratio".

snath-xoc and others added 2 commits February 8, 2025 11:52
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Member

@jeremiedbb jeremiedbb left a comment


I found another issue with the current implementation:

  • We use n_samples to compute the number of steps to run, i.e. the number of minibatches to process.
  • The max_iter parameter says that we run the necessary number of steps in order to loop through the whole dataset max_iter times.
  • Using n_samples to compute the total number of steps leads to a smaller number of steps in the weighted case than in the repeated case.

The suggestion below ensures that both run the same number of steps. It requires some adjustments to compute the fitted attributes at the end, n_iter_ and inertia_.

One major drawback is that it breaks the equivalence with scaled sample weights, i.e. the equivalence between fit(X, sample_weight=1) and fit(X, sample_weight=2). I haven't been able to find a way to preserve both equivalence properties, short of changing the meaning of max_iter or something similar.

Another thing. With this modification and the following parameters: max_no_improvement=None, tol=0, n_init=1, reassignment_ratio=0, init_size=100000 (to make sure to take all points); the statistical test passes, but with a quite small min p-value, around 0.10 to 0.30, so not that great. It may mean that we're still missing something, or it could be inherent to MiniBatchKMeans: the problem is not convex, so small modifications of the input can easily end up in a different local minimum.

Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
@snath-xoc
Contributor Author

snath-xoc commented Feb 13, 2025

Thank you @jeremiedbb for the insights

One major drawback is that it breaks the equivalence with scaled sample weights, i.e. the equivalence between fit(X, sample_weight=1) and fit(X, sample_weight=2). I haven't been able to find a way to preserve both equivalence properties, short of changing the meaning of max_iter or something similar.

EDIT: yes I see the tests failing, how would you suggest changing max_iter?

Another thing. With this modification and the following parameters: max_no_improvement=None, tol=0, n_init=1, reassignment_ratio=0, init_size=100000 (to make sure to take all points); the statistical test passes, but with a quite small min p-value, around 0.10 to 0.30, so not that great. It may mean that we're still missing something, or it could be inherent to MiniBatchKMeans: the problem is not convex, so small modifications of the input can easily end up in a different local minimum.

Hmmmm interesting, it passes on my side as well. It may be what you say, although I am wondering how random_reassign interacts with sample weight, since we do n_since_last_reassign += batch_size. This should be O.K. since we resample with weights. I can't really think of anything else in the code that could be causing this discrepancy.

EDIT: I did some tests it is not a problem with random_reassign

@snath-xoc
Contributor Author

Oh no, sorry, I seem to have pushed and broken the linter as well.

@github-actions github-actions bot removed the CI:Linter failure The linter CI is failing on this PR label Jan 23, 2026
@snath-xoc
Contributor Author

Agh so I added the sample weight scaling test to the estimator checks a while ago for this PR and it's failing on a lot of estimators, perhaps I should remove it for now... I think we had discussed perhaps adding the scaling relationship to the sample-weight-audit-nondet repo?

Comment on lines +2157 to +2158
# Rescaling step for sample weights, otherwise it does not pass test_scaled_weights
n_steps = int((self.max_iter * n_effective_samples)) // (self._batch_size)
Member


There's actually an easy way to make MinibatchKMeans pass the scaled weights test: make n_steps independent of the weights as it was before (since max_iter doesn't take weights into account).

That is

n_steps = (self.max_iter * n_samples) // self._batch_size

note that this is the same as

max_iter * sum(sample_weights) / (batch_size * mean(sample_weights))

which is a reasonable expectation.

That way, the same number of batches is processed no matter the scaling of the sample weights.
The counterpart is that the total weight seen during fit is scaled by the sample weight scaling, but I don't think that's an issue.
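A quick numeric check of this identity (illustrative values only): since sum(w) = n_samples * mean(w), defining n_steps from n_samples matches the weighted expression, and rescaling all weights leaves it unchanged:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, batch_size, max_iter = 1000, 256, 100
sample_weight = rng.uniform(0.5, 2.0, size=n_samples)

# n_steps defined independently of the sample weights ...
n_steps = (max_iter * n_samples) // batch_size

# ... equals max_iter * sum(w) / (batch_size * mean(w)), because
# sum(w) = n_samples * mean(w) and the weight scale cancels.
n_steps_w = int(max_iter * sample_weight.sum() // (batch_size * sample_weight.mean()))

# Scaling all weights by 2 does not change the number of steps.
scaled = 2.0 * sample_weight
n_steps_2w = int(max_iter * scaled.sum() // (batch_size * scaled.mean()))
```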

Member


The counterpart is that the total weight seen during fit is scaled by the sample weight scaling but I don't think that it's an issue.

Even that is expected to me actually. If we scale weights by a factor of 2, the total weight of the full dataset is multiplied by 2 and so a full iteration should see twice the total weight.

So I actually think that there's no issue with defining n_steps independently of sample weights.

Contributor Author


Thank you @jeremiedbb for that, glad to see that the weights scaling test can now pass with this.

@jeremiedbb
Member

I pushed a commit to implement what I tried to explain in my previous comment. I wanted to test https://github.com/snath-xoc/sample-weight-audit-nondet against this branch but there's a bug for clusterers. We should merge snath-xoc/sample-weight-audit-nondet#36 first.

@jeremiedbb
Member

Here's the result of the sample weight test on the fixed branch (snath-xoc/sample-weight-audit-nondet#36)
(screenshot of the test results omitted)

@jeremiedbb
Member

I think this PR is good to go. It just needs a changelog entry.

@snath-xoc
Contributor Author

Thank you @jeremiedbb just merged the branch, this looks good to go now?

@snath-xoc
Contributor Author

@ogrisel and @adrinjalali this should be good to merge now?

@jeremiedbb
Member

It just needs a changelog entry before merging

@ogrisel
Member

ogrisel commented Feb 4, 2026

@snath-xoc the CI is still red because of the missing changelog entry: https://github.com/scikit-learn/scikit-learn/blob/main/doc/whats_new/upcoming_changes/README.md

@snath-xoc
Contributor Author

It just needs a changelog entry before merging

Sorry, I had forgotten about that; I've added it now.

Member

@ogrisel ogrisel left a comment


LGTM besides the following:


n_steps = (self.max_iter * n_samples) // self._batch_size

n_effective_samples = np.sum(sample_weight)
Member


I would rather rename this variable to "sum_of_weights" as it's more descriptive.

n_effective_samples could be interpreted differently in different contexts, in particular in the presence of repeated data points.

Member

@jeremiedbb jeremiedbb left a comment


I just pushed a commit to clean up the formatting and remove a bit of implementation detail.

LGTM. Thanks !

@snath-xoc
Contributor Author

renaming all done, this should be good to merge @adrinjalali

@jeremiedbb jeremiedbb merged commit 0932d7e into scikit-learn:main Feb 6, 2026
40 checks passed
@github-project-automation github-project-automation bot moved this from In progress to Done in Labs Feb 6, 2026


Development

Successfully merging this pull request may close these issues.

MiniBatchKMeans not handling sample weights as expected

4 participants