Allow GridSearchCV() to account for model variance to select .best_estimator_ #12865

@PedroUria

Description

Say we have a machine learning model and we want to find the best value of one of its hyperparameters out of two candidates. To do this, we use the cross_val_score() function with 3-fold cross-validation on our training set and get these scores:

Model  Acc_1  Acc_2  Acc_3  CV_mean  CV_std
A      0.71   0.70   0.69   0.7000   0.0082
B      0.91   0.70   0.50   0.7033   0.1674

where Acc_1 is the accuracy using the first two folds for training and the third for testing, CV_mean is the mean of the accuracies and CV_std is their standard deviation.
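To make the table reproducible, here is a minimal sketch that recomputes CV_mean and CV_std from the three fold accuracies. It assumes the standard deviation is the population one (dividing by the number of folds), which is what matches the figures above; the names `fold_scores` and `summary` are only illustrative.

```python
from statistics import mean, pstdev

# Per-fold accuracies for the two hyperparameter values, copied from the table.
fold_scores = {
    "A": [0.71, 0.70, 0.69],
    "B": [0.91, 0.70, 0.50],
}

# Population standard deviation (divide by n, not n - 1) is assumed here
# because it reproduces the CV_std column of the table.
summary = {
    name: (round(mean(scores), 4), round(pstdev(scores), 4))
    for name, scores in fold_scores.items()
}

for name, (cv_mean, cv_std) in summary.items():
    print(f"{name}: CV_mean={cv_mean:.4f}  CV_std={cv_std:.4f}")
```

The two means differ by about 0.003, while B's spread across folds is roughly twenty times larger than A's.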

Now, I believe anyone would prefer model A over model B, at least under normal circumstances. However, when using GridSearchCV(), the result would be .best_estimator_ = B, because its CV_mean is slightly higher. That is, GridSearchCV() doesn't pay any attention to the standard deviation of the scores when choosing .best_estimator_; if it did, we would surely have gotten .best_estimator_ = A. I agree that the model with the highest CV_mean is usually also the model with the lowest, or one of the lowest, CV_std, and even the model that overfits the training data the least. But this is not always the case, and so selecting the model based only on the highest CV_mean may lead to a worse model.

I am also aware that GridSearchCV() keeps track of the standard deviations, storing them in the .cv_results_ dictionary under the key std_test_score. However, I would like the option to tell GridSearchCV() to select the best model based also on its variability over the different train-test splits of the cross-validation loop, instead of having to do it myself semi-manually.
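As a rough sketch of what such a selection rule could look like, the function below ranks candidates by mean score minus a multiple of the standard deviation. This particular rule is only one possible choice and an assumption made for illustration, as are the names `mean_minus_std_index` and `penalty`. It reads only the mean_test_score and std_test_score entries of a .cv_results_-style dictionary.

```python
def mean_minus_std_index(cv_results, penalty=1.0):
    """Return the index of the candidate with the best variance-penalised score.

    `cv_results` is expected to look like GridSearchCV's .cv_results_:
    a mapping with "mean_test_score" and "std_test_score" sequences.
    `penalty` (hypothetical) controls how strongly variability is punished.
    """
    means = cv_results["mean_test_score"]
    stds = cv_results["std_test_score"]
    adjusted = [m - penalty * s for m, s in zip(means, stds)]
    # Index of the largest adjusted score, returned as a plain int.
    return max(range(len(adjusted)), key=adjusted.__getitem__)


# Stand-in for .cv_results_, using the values from the table (A first, then B).
cv_results = {
    "mean_test_score": [0.7000, 0.7033],
    "std_test_score": [0.0082, 0.1674],
}

# Selecting on the mean alone picks index 1 (model B).
by_mean = max(range(2), key=cv_results["mean_test_score"].__getitem__)
# Penalising the standard deviation picks index 0 (model A).
by_penalised = mean_minus_std_index(cv_results)

print(by_mean, by_penalised)

# In scikit-learn versions where `refit` accepts a callable that maps
# cv_results_ to an index, the function could presumably be plugged in as:
#   GridSearchCV(estimator, param_grid, cv=3, refit=mean_minus_std_index)
```

With the table's numbers, the adjusted scores are about 0.692 for A and 0.536 for B, so the penalised rule selects A while the plain mean selects B.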

Is this something more people would be interested in? If so, I would like to try to add this feature myself. Although I have never contributed to an open source project before, and I am fairly new to programming and machine learning, I feel like this would be a good place to start.
