Skip to content

ENH Add sample_weight support to NearestCentroid#33477

Open
prem-479 wants to merge 5 commits intoscikit-learn:mainfrom
prem-479:enh-nearest-centroid-weights
Open

ENH Add sample_weight support to NearestCentroid#33477
prem-479 wants to merge 5 commits intoscikit-learn:mainfrom
prem-479:enh-nearest-centroid-weights

Conversation

@prem-479
Copy link
Copy Markdown

@prem-479 prem-479 commented Mar 7, 2026

Reference Issues/PRs

Fixes #33457

What does this implement/fix? Explain your changes.

Proposed Changes
This PR introduces sample_weight support to the NearestCentroid classifier, allowing for weighted means (Euclidean) and weighted medians (Manhattan), with strict parity between dense and sparse inputs.

What was built:

  • NearestCentroid.fit(): Now accepts sample_weight and applies it correctly through all four code paths: weighted Euclidean mean (dense and sparse) and weighted median (dense Manhattan via _weighted_percentile and sparse Manhattan via a new utility).
  • Shrinkage & Priors: Successfully applies weights to empirical priors, weighted pooled variance, and the shrinkage m-parameter.
  • Sparse Utilities: Added csc_weighted_median_axis_0 and _get_weighted_median to sklearn/utils/sparsefuncs.py to handle implicit zeros correctly without spurious zero insertion.

Key guarantees and edge cases handled:

  • Uniform Parity: Uniform weights produce results mathematically identical to the unweighted path.
  • Fractional Weights: Supported seamlessly, provided sum(weights) > n_classes.
  • Zero-Weight Classes: Immediately raises a ValueError with a clear message to prevent downstream NaNs.
  • Float Arithmetic: Clamped sum_w_zeros to prevent tiny negative values caused by floating-point inaccuracies.
  • NumPy 2.0 Compatibility: Replaced deprecated np.ptp with X.max(axis=0) - X.min(axis=0).

Files Changed:

  • sklearn/neighbors/_nearest_centroid.py
  • sklearn/utils/sparsefuncs.py
  • sklearn/neighbors/tests/test_nearest_centroid.py
  • sklearn/utils/tests/test_sparsefuncs.py

Test Coverage:
Added test_csc_weighted_median_axis_0 and test_nearest_centroid_sample_weight to verify identical behavior between dense/sparse inputs, correct centroid shifting, and proper error handling. All 24 tests pass cleanly.

AI usage disclosure

I used AI assistance for:

  • Code generation (e.g., when writing an implementation or fixing a bug)
  • Test/benchmark generation
  • Documentation (including examples)
  • Research and understanding

Any other comments?

All 24 tests pass locally, including parity checks between sparse and dense formats and against zero-weight edge cases. Looking forward to any feedback!

@github-actions github-actions bot added module:neighbors module:utils CI:Linter failure The linter CI is failing on this PR labels Mar 7, 2026
@github-actions github-actions bot removed the CI:Linter failure The linter CI is failing on this PR label Mar 7, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

META track sample_weight support for all estimators

1 participant