Skip to content

Commit 84cac35

Browse files
committed
Added a unit test to ensure that roc_curve does not return spurious repeated threshold values caused by machine-precision round-off, plus a quick stab at a fix.
1 parent dcf48a5 commit 84cac35

2 files changed

Lines changed: 26 additions & 1 deletion

File tree

sklearn/metrics/metrics.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,8 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
725725
# y_score typically has many tied values. Here we extract
726726
# the indices associated with the distinct values. We also
727727
# concatenate a value for the end of the curve.
728-
distinct_value_indices = np.where(np.diff(y_score))[0]
728+
y_round = np.round(y_score, 6) # round to 6 decimals so float round-off is not treated as a new threshold; caps threshold resolution at 1e-6
729+
distinct_value_indices = np.where(np.diff(y_round))[0]
729730
threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
730731

731732
# accumulate the true positives with decreasing threshold

sklearn/metrics/tests/test_metrics.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from sklearn import datasets
99
from sklearn import svm
10+
from sklearn import ensemble
1011

1112
from sklearn.preprocessing import LabelBinarizer, MultiLabelBinarizer
1213
from sklearn.datasets import make_multilabel_classification
@@ -466,6 +467,29 @@ def test_roc_returns_consistency():
466467
assert_equal(fpr.shape, thresholds.shape)
467468

468469

470+
def test_roc_nonrepeating_thresholds():
    """Check that roc_curve thresholds contain no near-duplicate values.

    Scores that differ only by floating point round-off must not show
    up as separate thresholds.
    """
    digits = datasets.load_digits()
    features, labels = digits['data'], digits['target']

    # Fit on the even-indexed samples and score the odd-indexed ones.
    forest = ensemble.RandomForestClassifier(n_estimators=100, random_state=0)
    forest.fit(features[::2], labels[::2])
    class_probas = forest.predict_proba(features[1::2])

    # Binary task: is the digit less than 5?  The score is the sum of
    # the first five class-probability columns, and that summation is
    # what introduces the round-off noise.
    probas_pred = class_probas[:, :5].sum(axis=1)
    y_true = [label < 5 for label in labels[1::2]]

    # Thresholds must still be pairwise distinct after rounding to two
    # decimals (per the original author, this forest only yields
    # probabilities significant to two decimal places).
    fpr, tpr, thresholds = roc_curve(y_true, probas_pred)
    n_distinct = np.unique(np.round(thresholds, 2)).size
    assert_equal(thresholds.size, n_distinct)
491+
492+
469493
def test_roc_curve_multi():
470494
"""roc_curve not applicable for multi-class problems"""
471495
y_true, _, probas_pred = make_prediction(binary=False)

0 commit comments

Comments
 (0)