
LinearSVC does not correctly handle sample_weight under class_weight strategy 'balanced' #30056

@snath-xoc

Describe the bug

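LinearSVC does not pass sample_weight through when computing class weights under the "balanced" strategy, so the class frequencies used for balancing ignore the sample weights. This breaks sample weight invariance: fitting with integer sample weights should give the same model as fitting on the data with each row repeated according to its weight. Cross-linked to meta-issue #16298.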
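For reference, scikit-learn's "balanced" heuristic gives each class c the weight n_samples / (n_classes * n_c), where n_c is the number of samples in class c. The sketch below assumes the natural weighted generalization (total sample weight divided by n_classes times the total weight of class c); `balanced_class_weight` is a hypothetical helper, not a scikit-learn function. On a toy example it shows that ignoring sample_weight gives different class weights from the duplicated dataset, while the weighted version matches it.

```python
import numpy as np


def balanced_class_weight(y, sample_weight=None):
    """Hypothetical helper: 'balanced' class weights from weighted class totals.

    With sample_weight=None this reduces to n_samples / (n_classes * n_c).
    A class whose total weight is zero would get an infinite weight.
    """
    classes, y_ind = np.unique(y, return_inverse=True)
    if sample_weight is None:
        sample_weight = np.ones(len(y), dtype=np.float64)
    sample_weight = np.asarray(sample_weight, dtype=np.float64)
    # Total weight carried by each class (plain counts when all weights are 1).
    class_totals = np.bincount(y_ind, weights=sample_weight, minlength=len(classes))
    weights = sample_weight.sum() / (len(classes) * class_totals)
    return dict(zip(classes.tolist(), weights.tolist()))


y = np.array([0, 0, 1, 1])
sample_weight = np.array([1, 1, 3, 1])

# Ignoring sample_weight: class counts are [2, 2].
unweighted = balanced_class_weight(y)
# Using sample_weight: class totals are [2, 4] out of a total weight of 6.
weighted = balanced_class_weight(y, sample_weight)
# Reference: plain counts on the rows repeated according to sample_weight.
duplicated = balanced_class_weight(np.repeat(y, sample_weight))

print(unweighted)  # {0: 1.0, 1: 1.0}
print(weighted)    # {0: 1.5, 1: 0.75}
print(duplicated)  # {0: 1.5, 1: 0.75}
```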

Steps/Code to Reproduce

from sklearn.svm import LinearSVC
from sklearn.datasets import make_classification
import numpy as np

rng = np.random.RandomState(0)  # seeded so the reproducer is repeatable

X, y = make_classification(
    n_samples=100,
    n_features=5,
    n_informative=3,
    n_classes=4,
    random_state=0,
)

# Create a dataset with repetitions and the corresponding integer sample weights:
# fitting with sample_weight should be equivalent to fitting on the repeated rows.
sample_weight = rng.randint(0, 10, size=X.shape[0])
X_resampled_by_weights = np.repeat(X, sample_weight, axis=0)
y_resampled_by_weights = np.repeat(y, sample_weight)

est_sw = LinearSVC(dual=False, class_weight="balanced").fit(
    X, y, sample_weight=sample_weight
)
est_dup = LinearSVC(dual=False, class_weight="balanced").fit(
    X_resampled_by_weights, y_resampled_by_weights, sample_weight=None
)

np.testing.assert_allclose(est_sw.coef_, est_dup.coef_, rtol=1e-10, atol=1e-10)
np.testing.assert_allclose(
    est_sw.decision_function(X_resampled_by_weights),
    est_dup.decision_function(X_resampled_by_weights),
    rtol=1e-10,
    atol=1e-10,
)
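A possible workaround, sketched under the assumption that the discrepancy comes only from the class weights: precompute weighted "balanced" weights with a hypothetical helper (`balanced_class_weight`, not part of scikit-learn) and pass them to LinearSVC as a dict. The assertion checks that these weights equal what `sklearn.utils.class_weight.compute_class_weight` returns on the repeated data. The printed coefficient difference is not guaranteed to fall below 1e-10, since liblinear stops at its own tolerance (tol=1e-4 by default).

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.svm import LinearSVC
from sklearn.utils.class_weight import compute_class_weight


def balanced_class_weight(y, sample_weight):
    """Hypothetical helper: 'balanced' class weights from weighted class totals."""
    classes, y_ind = np.unique(y, return_inverse=True)
    sample_weight = np.asarray(sample_weight, dtype=np.float64)
    class_totals = np.bincount(y_ind, weights=sample_weight, minlength=len(classes))
    weights = sample_weight.sum() / (len(classes) * class_totals)
    return dict(zip(classes.tolist(), weights.tolist()))


rng = np.random.RandomState(0)
X, y = make_classification(
    n_samples=100, n_features=5, n_informative=3, n_classes=4, random_state=0
)
sample_weight = rng.randint(0, 10, size=X.shape[0])
X_dup = np.repeat(X, sample_weight, axis=0)
y_dup = np.repeat(y, sample_weight)

# Weighted "balanced" weights should equal plain "balanced" weights on the repeated data.
cw = balanced_class_weight(y, sample_weight)
classes = np.unique(y_dup)
cw_array = np.array([cw[c] for c in classes.tolist()])
reference = compute_class_weight("balanced", classes=classes, y=y_dup)
assert np.allclose(cw_array, reference)

# Pass the precomputed dict instead of "balanced" when fitting with sample_weight.
est_sw = LinearSVC(dual=False, class_weight=cw).fit(X, y, sample_weight=sample_weight)
est_dup = LinearSVC(dual=False, class_weight="balanced").fit(X_dup, y_dup)
print("max |coef difference|:", np.abs(est_sw.coef_ - est_dup.coef_).max())
```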

Expected Results

No error is raised; both assertions pass.

Actual Results

AssertionError: 
Not equal to tolerance rtol=1e-10, atol=1e-10

Mismatched elements: 20 / 20 (100%)
Max absolute difference among violations: 0.00818953
Max relative difference among violations: 0.10657042
 ACTUAL: array([[ 0.157045, -0.399979, -0.050654,  0.236997, -0.313416],
       [-0.038369, -0.169516, -0.239528, -0.164231,  0.29698 ],
       [ 0.069654,  0.250218,  0.268922, -0.065565, -0.195888],
       [-0.117921,  0.185563,  0.005148,  0.006144,  0.130577]])
 DESIRED: array([[ 0.157595, -0.401087, -0.051018,  0.23653 , -0.313528],
       [-0.041687, -0.169006, -0.243102, -0.16373 ,  0.302628],
       [ 0.065096,  0.245549,  0.260732, -0.061577, -0.188419],
       [-0.117224,  0.184116,  0.004652,  0.005555,  0.130453]])

Versions

System:
    python: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ]
executable: /Users/shrutinath/micromamba/envs/scikit-learn/bin/python
   machine: macOS-14.3-arm64-arm-64bit

Python dependencies:
      sklearn: 1.6.dev0
          pip: 24.0
   setuptools: 70.1.1
        numpy: 2.0.0
        scipy: 1.14.0
       Cython: 3.0.10
       pandas: 2.2.2
   matplotlib: 3.9.0
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libopenblas
...
    num_threads: 8
         prefix: libomp
       filepath: /Users/shrutinath/micromamba/envs/scikit-learn/lib/libomp.dylib
        version: None