Suggestion: Remove prediction from plot_confusion_matrix and just pass predicted labels #15880
Description
The signature of plot_confusion_matrix is currently:
sklearn.metrics.plot_confusion_matrix(estimator, X, y_true, labels=None, sample_weight=None, normalize=None, display_labels=None, include_values=True, xticks_rotation='horizontal', values_format=None, cmap='viridis', ax=None)
The function takes an estimator and raw data, and cannot be used with labels that have already been predicted. This has some downsides:
- If a confusion matrix should be plotted but the predictions are also needed elsewhere (e.g. to calculate accuracy_score), the prediction step has to be run several times. That takes longer and can produce different values if the estimator is randomized.
- If no estimator is available (e.g. the predictions were loaded from a file), the plot cannot be used at all.
Suggestion: allow passing predicted labels y_pred to plot_confusion_matrix, to be used instead of estimator and X. In my opinion, the cleanest solution would be to remove the prediction step from the function and use a signature similar to that of accuracy_score, e.g. (y_true, y_pred, labels=None, sample_weight=None, ...). However, to maintain backwards compatibility, y_pred can instead be added as an optional keyword argument.
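To make the backwards-compatible variant concrete, here is a minimal pure-Python sketch. It is not scikit-learn code: the names confusion_counts, confusion_from_inputs and ConstantEstimator are hypothetical and exist only for illustration, and plotting is left out so that only the argument handling is shown. The idea is that the prediction step runs only when y_pred is not supplied.

```python
from collections import Counter


def confusion_counts(y_true, y_pred, labels=None):
    """Count (true, predicted) pairs; rows are true labels, columns predicted."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if labels is None:
        labels = sorted(set(y_true) | set(y_pred))
    pair_counts = Counter(zip(y_true, y_pred))
    return [[pair_counts[(t, p)] for p in labels] for t in labels]


def confusion_from_inputs(estimator=None, X=None, y_true=None, *,
                          y_pred=None, labels=None):
    """Hypothetical entry point: y_pred is an optional keyword argument."""
    if y_pred is None:
        # Old behaviour: fall back to predicting from the estimator and X.
        if estimator is None or X is None:
            raise ValueError("pass either y_pred or both estimator and X")
        y_pred = estimator.predict(X)
    return confusion_counts(y_true, y_pred, labels=labels)


class ConstantEstimator:
    """Toy estimator (assumption for the demo) that always predicts class 1."""

    def predict(self, X):
        return [1 for _ in X]


y_true = [0, 1, 1, 0, 1]
stored_pred = [0, 1, 0, 0, 1]  # e.g. predictions loaded from a file

# The same stored predictions feed both the matrix and the accuracy.
from_labels = confusion_from_inputs(y_true=y_true, y_pred=stored_pred)
accuracy = sum(t == p for t, p in zip(y_true, stored_pred)) / len(y_true)

# The estimator-based call style keeps working unchanged.
from_model = confusion_from_inputs(ConstantEstimator(), [[0]] * 5, y_true)
```

With this shape, existing positional calls (estimator, X, y_true) are unaffected, while callers that already hold predictions skip the prediction step entirely.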
TODO:
- Introduce the class methods for the currently existing plots:
  - ConfusionMatrixDisplay: ENH/DEP add class method and deprecate plot function for confusion matrix #18543
  - PrecisionRecallDisplay: API add from_estimator and from_preditions to PrecisionRecallDisplay #20552
  - RocCurveDisplay: API add from_estimator and from_predictions to RocCurveDisplay #20569
  - DetCurveDisplay: API deprecate plot_det_curve in favor of display class methods #19278
  - PartialDependenceDisplay. For this one, we don't want to introduce the from_predictions classmethod because it would not make sense, we only want from_estimator.
- For all Display listed above, deprecate their corresponding plot_... function. We don't need to deprecate plot_det_curve because it hasn't been released yet, we can just remove it.
- for new PRs like ENH Add CalibrationDisplay plotting class #17443 and FEA add PredictionErrorDisplay #18020 we can implement the class methods right away instead of introducing a plot function.
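The class-method layout in the TODO list can be sketched in plain Python as follows. This is an illustrative stand-in, not scikit-learn's implementation: SketchConfusionDisplay, its render method and ThresholdEstimator are hypothetical names, and text output replaces the Matplotlib drawing. It assumes from_estimator only needs to predict and then delegate to from_predictions.

```python
from collections import Counter


class SketchConfusionDisplay:
    """Hypothetical Display-style class holding a precomputed matrix."""

    def __init__(self, matrix, display_labels):
        self.matrix = matrix
        self.display_labels = display_labels

    @classmethod
    def from_predictions(cls, y_true, y_pred, labels=None):
        """Build the display from labels that were already predicted."""
        if labels is None:
            labels = sorted(set(y_true) | set(y_pred))
        counts = Counter(zip(y_true, y_pred))
        matrix = [[counts[(t, p)] for p in labels] for t in labels]
        return cls(matrix, labels)

    @classmethod
    def from_estimator(cls, estimator, X, y_true, labels=None):
        """Predict once, then reuse the from_predictions code path."""
        return cls.from_predictions(y_true, estimator.predict(X), labels=labels)

    def render(self):
        """Return a tab-separated text table (stand-in for real plotting)."""
        lines = ["true/pred\t" + "\t".join(str(l) for l in self.display_labels)]
        for label, row in zip(self.display_labels, self.matrix):
            lines.append(str(label) + "\t" + "\t".join(str(v) for v in row))
        return "\n".join(lines)


class ThresholdEstimator:
    """Toy estimator (assumption for the demo): 'pos' above 0.5, else 'neg'."""

    def predict(self, X):
        return ["pos" if x > 0.5 else "neg" for x in X]


X = [0.1, 0.9, 0.4, 0.7]
y_true = ["neg", "pos", "pos", "pos"]

disp_a = SketchConfusionDisplay.from_estimator(ThresholdEstimator(), X, y_true)
disp_b = SketchConfusionDisplay.from_predictions(
    y_true, ["neg", "pos", "neg", "pos"]
)
print(disp_a.render())
```

Because from_estimator is a thin wrapper, both entry points produce the same display for the same predictions, and a class such as PartialDependenceDisplay could offer from_estimator alone.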