Allow GridSearchCV() to account for model variance to select .best_estimator_ #12865
Description
Say we have a machine learning model and we want to find the best value of one of its hyperparameters out of two candidate values (call the two resulting models A and B). To do this, we use the cross_val_score() function with 3-fold cross-validation on our training set and get these scores:
| Model | Acc_1 | Acc_2 | Acc_3 | CV_mean | CV_std |
|---|---|---|---|---|---|
| A | 0.71 | 0.70 | 0.69 | 0.7000 | 0.0082 |
| B | 0.91 | 0.70 | 0.50 | 0.7033 | 0.1674 |
where Acc_1 is the accuracy obtained using the first two folds for training and the third for testing (Acc_2 and Acc_3 are defined analogously for the other splits), CV_mean is the mean of the three accuracies, and CV_std is their (population) standard deviation.
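As a quick sanity check on the table, the CV_mean and CV_std columns can be recomputed from the per-fold accuracies. This is a minimal sketch using NumPy; note that `np.std` defaults to `ddof=0` (population standard deviation), which is what reproduces the CV_std values shown (with `ddof=1` they would be larger).

```python
import numpy as np

# Per-fold accuracies copied from the table (one entry per model).
fold_scores = {
    "A": np.array([0.71, 0.70, 0.69]),
    "B": np.array([0.91, 0.70, 0.50]),
}

for name, scores in fold_scores.items():
    # np.std uses ddof=0 by default, i.e. the population standard deviation,
    # which matches the CV_std column.
    print(f"{name}: CV_mean={scores.mean():.4f} CV_std={scores.std():.4f}")
```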
Now, I believe anyone would prefer model A over model B, at least under normal circumstances. However, GridSearchCV() would return .best_estimator_ = B, because its CV_mean is slightly higher. That is, GridSearchCV() ignores the standard deviation of the scores when choosing .best_estimator_; if it took the standard deviation into account, it would surely pick A instead. I agree that the model with the highest CV_mean is usually also the model with the lowest (or one of the lowest) CV_std, and even the one that overfits our training data the least. But this is not always the case, so selecting a model based only on the highest CV_mean may lead to a worse model.
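To make the difference concrete, here is a small sketch comparing two selection rules on the scores from the table. The "mean minus k standard deviations" rule and the value `k=1.0` are hypothetical choices for illustration only; they are not something GridSearchCV() provides, just one simple way of penalizing variability.

```python
from statistics import fmean, pstdev

# Per-fold accuracies copied from the table.
fold_scores = {
    "A": [0.71, 0.70, 0.69],
    "B": [0.91, 0.70, 0.50],
}

def best_by_mean(scores):
    # Rule described in the issue: the highest mean test score wins.
    return max(scores, key=lambda name: fmean(scores[name]))

def best_by_penalized_mean(scores, k=1.0):
    # Hypothetical variance-aware rule: mean minus k population std devs.
    return max(scores, key=lambda name: fmean(scores[name]) - k * pstdev(scores[name]))

print(best_by_mean(fold_scores))            # B
print(best_by_penalized_mean(fold_scores))  # A
```

With `k=1.0` the penalized scores are roughly 0.6918 for A and 0.5359 for B, so A wins; B only wins for very small `k` (below about 0.02), because its mean advantage is just 0.0033.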
I am also aware that GridSearchCV() keeps track of the standard deviations, storing them in the .cv_results_ dictionary under the key std_test_score. However, I would like an option to tell GridSearchCV() to select the best model based also on its variability across the different train-test splits of the cross-validation loop, instead of having to do it myself semi-manually.
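For reference, recent scikit-learn versions (possibly newer than this issue) accept a callable for `refit` in GridSearchCV(): it is called with `cv_results_` and must return the integer index of the chosen candidate, which becomes `best_index_`. The sketch below assumes that mechanism and reuses the hypothetical "mean minus k standard deviations" rule. The function name `penalized_best_index` and the `fake_cv_results` dictionary (including its `"model"` parameter) are illustrative stand-ins, not part of scikit-learn.

```python
import numpy as np

def penalized_best_index(cv_results, k=1.0):
    """Hypothetical refit callable: index of the highest (mean - k * std) score."""
    # Both keys are present in GridSearchCV's cv_results_ for single-metric scoring.
    mean = np.asarray(cv_results["mean_test_score"])
    std = np.asarray(cv_results["std_test_score"])
    # Return a plain Python int, since it is used as best_index_.
    return int(np.argmax(mean - k * std))

# Stand-in for cv_results_, filled with the two candidates from the table.
fake_cv_results = {
    "params": [{"model": "A"}, {"model": "B"}],
    "mean_test_score": [0.7000, 0.7033],
    "std_test_score": [0.0082, 0.1674],
}

idx = penalized_best_index(fake_cv_results)
print(fake_cv_results["params"][idx])  # {'model': 'A'}
```

In use, this would be passed as `GridSearchCV(estimator, param_grid, refit=penalized_best_index)`, with `functools.partial` available to fix a different `k`. Note that when `refit` is a callable, `best_score_` is not set, since the selected candidate need not have the highest mean score.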
Is this something more people would be interested in? If so, I would like to try to add this feature myself. Although I have never contributed to an open source project before, and I am fairly new to programming and machine learning, I feel like this would be a good place to start.