BUG SGDRegressor only using the first two samples for validation error calculation #23255

@MaxwellLZH

Describe the bug

BaseSGD._make_validation_split returns a 0-1 integer array as validation mask, which is used to calculate the validation score with:

_ValidationScoreCallback(
    self,
    X[validation_mask],
    y[validation_mask],
    sample_weight[validation_mask],
    classes=classes,
)

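Here, X[validation_mask] does not select the validation samples. Because validation_mask is an integer array of 0s and 1s, NumPy treats each entry as a row index, so the result just repeats the first two samples (rows 0 and 1) and has the same length as the mask, as the reproduction in the next section shows. A quick fix would be to convert validation_mask to a boolean array.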
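To make the indexing difference concrete, here is a minimal NumPy-only sketch. The array `X`, the mask `int_mask`, and the chosen validation positions (3 and 7) are illustrative assumptions, not values taken from scikit-learn.

```python
import numpy as np

X = np.arange(20).reshape(10, 2)  # 10 samples, 2 features

# A 0-1 integer mask marking samples 3 and 7 as validation samples
# (illustrative values, not produced by scikit-learn).
int_mask = np.zeros(10, dtype=np.uint8)
int_mask[[3, 7]] = 1

# Integer indexing: every mask entry is read as a row index,
# so the result contains only copies of row 0 and row 1.
by_int = X[int_mask]

# Boolean indexing: only the rows where the mask is True are kept.
by_bool = X[int_mask.astype(bool)]

print(by_int.shape)   # (10, 2)
print(by_bool.shape)  # (2, 2)
```

The integer-indexed result has as many rows as the mask has entries, which matches the symptom reported in this issue.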

Steps/Code to Reproduce

I added a print statement to _ValidationScoreCallback to print the shape of the validation data:

class _ValidationScoreCallback:
    """Callback for early stopping based on validation score"""

    def __init__(self, estimator, X_val, y_val, sample_weight_val, classes=None):
        self.estimator = clone(estimator)
        self.estimator.t_ = 1  # to pass check_is_fitted
        if classes is not None:
            self.estimator.classes_ = classes
        self.X_val = X_val
        self.y_val = y_val
        self.sample_weight_val = sample_weight_val
        print(X_val.shape, y_val.shape)

and then ran the following code:

import numpy as np
from sklearn.linear_model import SGDRegressor

X = np.random.randn(1000, 5)
y = np.random.randn(1000)

sgd = SGDRegressor(early_stopping=True, max_iter=1, validation_fraction=0.1)
sgd.fit(X, y)

Expected Results

With 1000 samples and validation_fraction=0.1, the shapes of the validation data should be (100, 5) and (100,).

Actual Results

The print statement outputs (1000, 5) (1000,) instead.
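The expected and actual shapes can also be compared without patching scikit-learn. The sketch below is an assumption-laden stand-in: it builds a 0-1 `uint8` mask from a random permutation rather than calling `BaseSGD._make_validation_split`, and the names `idx_val` and `bool_mask` are hypothetical. It only illustrates what the proposed boolean conversion would change.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, validation_fraction = 1000, 0.1

X = rng.standard_normal((n_samples, 5))
y = rng.standard_normal(n_samples)

# Hypothetical stand-in for the validation split: pick 10% of the
# sample indices at random and mark them with 1 in an integer mask.
n_val = int(round(n_samples * validation_fraction))
idx_val = rng.permutation(n_samples)[:n_val]
validation_mask = np.zeros(n_samples, dtype=np.uint8)
validation_mask[idx_val] = 1

# Indexing with the integer mask reproduces the reported shapes.
print(X[validation_mask].shape, y[validation_mask].shape)  # (1000, 5) (1000,)

# Proposed quick fix: convert the mask to boolean before indexing.
bool_mask = validation_mask.astype(bool)
print(X[bool_mask].shape, y[bool_mask].shape)  # (100, 5) (100,)
```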

Versions

System:
    python: 3.8.8 (default, Apr 13 2021, 12:59:45)  [Clang 10.0.0 ]
executable: /Users/mac/Documents/GitHub/.ENV/bin/python
   machine: macOS-10.16-x86_64-i386-64bit

Python dependencies:
      sklearn: 1.2.dev0
          pip: 20.2.3
   setuptools: 60.9.3
        numpy: 1.21.2
        scipy: 1.8.0
       Cython: 0.29.28
       pandas: 1.3.3
   matplotlib: 3.5.1
       joblib: 1.1.0
threadpoolctl: 3.0.0

Built with OpenMP: False
