Skip to content

Commit fde9212

Browse files
authored
ENH add a parameter pos_label in roc_auc_score (scikit-learn#17594)
1 parent c889845 commit fde9212

4 files changed

Lines changed: 73 additions & 9 deletions

File tree

doc/whats_new/v0.24.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,12 @@ Changelog
114114
:pr:`17309` by :user:`Swier Heeres <swierh>`
115115

116116
- |Enhancement| Add `sample_weight` parameter to
117-
:class:`metrics.median_absolute_error`. :pr:`17225` by
117+
:func:`metrics.median_absolute_error`. :pr:`17225` by
118118
:user:`Lucy Liu <lucyleeow>`.
119119

120+
- |Enhancement| Add `pos_label` parameter to :func:`metrics.roc_auc_score`.
121+
:pr:`17594` by :user:`Guillaume Lemaitre <glemaitre>`.
122+
120123
:mod:`sklearn.model_selection`
121124
..............................
122125

sklearn/metrics/_ranking.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -218,14 +218,16 @@ def _binary_uninterpolated_average_precision(
218218
average, sample_weight=sample_weight)
219219

220220

221-
def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None):
221+
def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None,
222+
pos_label=None):
222223
"""Binary roc auc score"""
223224
if len(np.unique(y_true)) != 2:
224225
raise ValueError("Only one class present in y_true. ROC AUC score "
225226
"is not defined in that case.")
226227

227-
fpr, tpr, _ = roc_curve(y_true, y_score,
228-
sample_weight=sample_weight)
228+
fpr, tpr, _ = roc_curve(
229+
y_true, y_score, sample_weight=sample_weight, pos_label=pos_label,
230+
)
229231
if max_fpr is None or max_fpr == 1:
230232
return auc(fpr, tpr)
231233
if max_fpr <= 0 or max_fpr > 1:
@@ -248,7 +250,8 @@ def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None):
248250

249251
@_deprecate_positional_args
250252
def roc_auc_score(y_true, y_score, *, average="macro", sample_weight=None,
251-
max_fpr=None, multi_class="raise", labels=None):
253+
max_fpr=None, multi_class="raise", labels=None,
254+
pos_label=None):
252255
"""Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC)
253256
from prediction scores.
254257
@@ -327,6 +330,13 @@ def roc_auc_score(y_true, y_score, *, average="macro", sample_weight=None,
327330
If ``None``, the numerical or lexicographical order of the labels in
328331
``y_true`` is used.
329332
333+
pos_label : int or str, default=None
334+
The label of the positive class in the binary case. When
335+
`pos_label=None`, if `y_true` is in {-1, 1} or {0, 1}, `pos_label` is
336+
set to 1, otherwise an error will be raised.
337+
338+
.. versionadded:: 0.24
339+
330340
Returns
331341
-------
332342
auc : float
@@ -388,10 +398,9 @@ def roc_auc_score(y_true, y_score, *, average="macro", sample_weight=None,
388398
return _multiclass_roc_auc_score(y_true, y_score, labels,
389399
multi_class, average, sample_weight)
390400
elif y_type == "binary":
391-
labels = np.unique(y_true)
392-
y_true = label_binarize(y_true, classes=labels)[:, 0]
393401
return _average_binary_score(partial(_binary_roc_auc_score,
394-
max_fpr=max_fpr),
402+
max_fpr=max_fpr,
403+
pos_label=pos_label),
395404
y_true, y_score, average,
396405
sample_weight=sample_weight)
397406
else: # multilabel-indicator

sklearn/metrics/tests/test_common.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,17 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
319319
# Metrics with a "pos_label" argument
320320
METRICS_WITH_POS_LABEL = {
321321
"roc_curve",
322+
323+
"roc_auc_score",
324+
"weighted_roc_auc",
325+
"samples_roc_auc",
326+
"micro_roc_auc",
327+
"ovr_roc_auc",
328+
"weighted_ovr_roc_auc",
329+
"ovo_roc_auc",
330+
"weighted_ovo_roc_auc",
331+
"partial_roc_auc",
332+
322333
"precision_recall_curve",
323334

324335
"brier_score_loss",

sklearn/metrics/tests/test_ranking.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
from sklearn import datasets
88
from sklearn import svm
99

10-
from sklearn.utils.extmath import softmax
1110
from sklearn.datasets import make_multilabel_classification
11+
from sklearn.datasets import load_breast_cancer
12+
from sklearn.linear_model import LogisticRegression
13+
from sklearn.model_selection import train_test_split
1214
from sklearn.random_projection import _sparse_random_matrix
15+
from sklearn.utils import shuffle
16+
from sklearn.utils.extmath import softmax
1317
from sklearn.utils.validation import check_array, check_consistent_length
1418
from sklearn.utils.validation import check_random_state
1519

@@ -1469,3 +1473,40 @@ def test_partial_roc_auc_score():
14691473
assert_almost_equal(
14701474
roc_auc_score(y_true, y_pred, max_fpr=max_fpr),
14711475
_partial_roc_auc_score(y_true, y_pred, max_fpr))
1476+
1477+
1478+
@pytest.mark.parametrize(
1479+
"decision_method", ["predict_proba", "decision_function"]
1480+
)
1481+
def test_roc_auc_score_pos_label(decision_method):
1482+
X, y = load_breast_cancer(return_X_y=True)
1483+
# create a highly imbalanced version of the dataset
1484+
idx_positive = np.flatnonzero(y == 1)
1485+
idx_negative = np.flatnonzero(y == 0)
1486+
idx_selected = np.hstack([idx_negative, idx_positive[:25]])
1487+
X, y = X[idx_selected], y[idx_selected]
1488+
X, y = shuffle(X, y, random_state=42)
1489+
# only use 2 features to make the problem even harder
1490+
X = X[:, :2]
1491+
y = np.array(
1492+
["cancer" if c == 1 else "not cancer" for c in y], dtype=object
1493+
)
1494+
X_train, X_test, y_train, y_test = train_test_split(
1495+
X, y, stratify=y, random_state=0,
1496+
)
1497+
1498+
classifier = LogisticRegression()
1499+
classifier.fit(X_train, y_train)
1500+
1501+
# sanity check to be sure the positive class is classes_[0] and that we
1502+
# are betrayed by the class imbalance
1503+
assert classifier.classes_.tolist() == ["cancer", "not cancer"]
1504+
pos_label = "cancer"
1505+
1506+
y_pred = getattr(classifier, decision_method)(X_test)
1507+
y_pred = y_pred[:, 0] if y_pred.ndim == 2 else -y_pred
1508+
1509+
fpr, tpr, _ = roc_curve(y_test, y_pred, pos_label=pos_label)
1510+
roc_auc = roc_auc_score(y_test, y_pred, pos_label=pos_label)
1511+
1512+
assert roc_auc == pytest.approx(np.trapz(tpr, fpr))

0 commit comments

Comments
 (0)