Closed
Description
Describe the bug
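LinearSVC does not pass sample weights through when it computes class weights under the "balanced" strategy. As a result, fitting with integer sample weights does not produce the same model as fitting on a dataset where each row is repeated as many times as its weight, which breaks sample-weight invariance. Cross-linked to meta-issue #16298.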
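The "balanced" heuristic sets each class weight to n_samples / (n_classes * np.bincount(y)). The minimal NumPy sketch below assumes that a weighted version should replace the sample count and the per-class counts with sums of sample weights. It shows that weights computed from `y` alone differ from those computed on the repeated data, while the weighted variant matches them. The helper `balanced_class_weight` is hypothetical and is not a scikit-learn API.

```python
import numpy as np


def balanced_class_weight(y, sample_weight=None):
    """Hypothetical helper: "balanced" class weights, optionally sample-weighted.

    Computes total_weight / (n_classes * per-class weight sum), which reduces to
    n_samples / (n_classes * np.bincount(y)) when every weight is 1.
    A class whose weights sum to zero would get an infinite weight (NumPy emits a
    RuntimeWarning); that edge case is not handled in this sketch.
    """
    classes, y_idx = np.unique(y, return_inverse=True)
    if sample_weight is None:
        sample_weight = np.ones(len(y))
    sample_weight = np.asarray(sample_weight, dtype=float)
    # Sum of sample weights per class (plain counts when all weights are 1)
    class_totals = np.bincount(y_idx, weights=sample_weight, minlength=len(classes))
    weights = sample_weight.sum() / (len(classes) * class_totals)
    return dict(zip(classes.tolist(), weights.tolist()))


y = np.array([0, 0, 1, 1, 2, 2])
sample_weight = np.array([1, 3, 1, 1, 2, 0])
y_repeated = np.repeat(y, sample_weight)  # each label repeated sample_weight times

ignoring_weights = balanced_class_weight(y)             # weights ignored, as reported
with_weights = balanced_class_weight(y, sample_weight)  # weighted variant
on_repeated = balanced_class_weight(y_repeated)         # reference: repeated data
print(ignoring_weights)
print(with_weights)
print(on_repeated)
```

Here `ignoring_weights` gives every class a weight of 1.0, while both `with_weights` and `on_repeated` give class 0 a weight of 2/3 and classes 1 and 2 a weight of 4/3.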
Steps/Code to Reproduce
from sklearn.svm import LinearSVC
from sklearn.datasets import make_classification
import numpy as np

rng = np.random.RandomState()  # unseeded: the exact numbers vary between runs
X, y = make_classification(
    n_samples=100,
    n_features=5,
    n_informative=3,
    n_classes=4,
    random_state=0,
)

# Create dataset with repetitions and corresponding sample weights
sample_weight = rng.randint(0, 10, size=X.shape[0])
X_resampled_by_weights = np.repeat(X, sample_weight, axis=0)
y_resampled_by_weights = np.repeat(y, sample_weight)

# Fit once with integer sample weights and once on the explicitly repeated data
est_sw = LinearSVC(dual=False, class_weight="balanced").fit(
    X, y, sample_weight=sample_weight
)
est_dup = LinearSVC(dual=False, class_weight="balanced").fit(
    X_resampled_by_weights, y_resampled_by_weights, sample_weight=None
)

# The two fits should be equivalent, so coefficients and decisions should match
np.testing.assert_allclose(est_sw.coef_, est_dup.coef_, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(
    est_sw.decision_function(X_resampled_by_weights),
    est_dup.decision_function(X_resampled_by_weights),
    rtol=1e-10,
    atol=1e-10,
)
Expected Results
No error is thrown: the model fitted with sample weights matches the model fitted on the repeated data.
Actual Results
AssertionError:
Not equal to tolerance rtol=1e-10, atol=1e-10
Mismatched elements: 20 / 20 (100%)
Max absolute difference among violations: 0.00818953
Max relative difference among violations: 0.10657042
ACTUAL: array([[ 0.157045, -0.399979, -0.050654, 0.236997, -0.313416],
[-0.038369, -0.169516, -0.239528, -0.164231, 0.29698 ],
[ 0.069654, 0.250218, 0.268922, -0.065565, -0.195888],
[-0.117921, 0.185563, 0.005148, 0.006144, 0.130577]])
DESIRED: array([[ 0.157595, -0.401087, -0.051018, 0.23653 , -0.313528],
[-0.041687, -0.169006, -0.243102, -0.16373 , 0.302628],
[ 0.065096, 0.245549, 0.260732, -0.061577, -0.188419],
[-0.117224, 0.184116, 0.004652, 0.005555, 0.130453]])
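One way to read the mismatch above: assuming, as in scikit-learn's liblinear wrapper, that each row's loss term is scaled by C * class_weight[y_i] * sample_weight_i, the weighted fit and the repeated fit minimize the same objective only if the total penalty attached to each original row agrees. The NumPy sketch below compares those per-row totals for balanced weights computed from `y` alone (the reported behaviour) and for balanced weights computed with sample weights. The helpers `balanced` and `per_row_penalty` are hypothetical, and the regularization term is ignored because it is identical for both fits.

```python
import numpy as np


def balanced(y, sample_weight):
    """Hypothetical: total_weight / (n_classes * per-class weight sum), keyed by class."""
    classes, idx = np.unique(y, return_inverse=True)
    totals = np.bincount(idx, weights=sample_weight, minlength=len(classes))
    weights = sample_weight.sum() / (len(classes) * totals)
    return dict(zip(classes.tolist(), weights.tolist()))


def per_row_penalty(y, sample_weight, class_weight, C=1.0):
    """Assumed multiplier on each row's loss: C * class_weight[y_i] * sample_weight_i."""
    return C * np.array([class_weight[c] for c in y.tolist()]) * sample_weight


y = np.array([0, 0, 1, 1, 2, 2])
sample_weight = np.array([1.0, 3.0, 1.0, 1.0, 2.0, 1.0])
y_repeated = np.repeat(y, sample_weight.astype(int))

# Reported behaviour: class weights computed as if every sample had weight 1
current = per_row_penalty(y, sample_weight, balanced(y, np.ones(len(y))))
# Weighted variant: class weights computed from the sample weights
fixed = per_row_penalty(y, sample_weight, balanced(y, sample_weight))
# Repeated fit: row i appears sample_weight_i times with unit weight, so its
# total penalty is cw_repeated[y_i] * sample_weight_i
cw_repeated = balanced(y_repeated, np.ones(len(y_repeated)))
reference = per_row_penalty(y, sample_weight, cw_repeated)

print(np.allclose(current, reference))  # False
print(np.allclose(fixed, reference))    # True
```

If this reading is right, one possible user-side workaround (not verified here) is to compute the weighted dictionary and pass it as an explicit `class_weight` to `LinearSVC`, which accepts a dict for that parameter.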
Versions
System:
python: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ]
executable: /Users/shrutinath/micromamba/envs/scikit-learn/bin/python
machine: macOS-14.3-arm64-arm-64bit
Python dependencies:
sklearn: 1.6.dev0
pip: 24.0
setuptools: 70.1.1
numpy: 2.0.0
scipy: 1.14.0
Cython: 3.0.10
pandas: 2.2.2
matplotlib: 3.9.0
joblib: 1.4.2
threadpoolctl: 3.5.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: openblas
num_threads: 8
prefix: libopenblas
...
num_threads: 8
prefix: libomp
filepath: /Users/shrutinath/micromamba/envs/scikit-learn/lib/libomp.dylib
version: None