|
7 | 7 |
|
8 | 8 | from sklearn import datasets |
9 | 9 | from sklearn import svm |
| 10 | +from sklearn import ensemble |
10 | 11 |
|
11 | 12 | from sklearn.preprocessing import LabelBinarizer, MultiLabelBinarizer |
12 | 13 | from sklearn.datasets import make_multilabel_classification |
@@ -466,6 +467,29 @@ def test_roc_returns_consistency(): |
466 | 467 | assert_equal(fpr.shape, thresholds.shape) |
467 | 468 |
|
468 | 469 |
|
| 470 | +def test_roc_nonrepeating_thresholds(): |
| 471 | + """Test to ensure that we don't return spurious repeating thresholds |
| 472 | + due to machine precision issues |
| 473 | + """ |
| 474 | + dataset = datasets.load_digits() |
| 475 | + X = dataset['data'] |
| 476 | + y = dataset['target'] |
| 477 | + |
| 478 | + # This random forest classifier can only return probabilities |
| 479 | + # significant to two decimal places |
| 480 | + clf = ensemble.RandomForestClassifier(n_estimators=100, random_state=0) |
| 481 | + |
| 482 | + # How well can the classifier predict whether a digit is less than 5? |
| 483 | + # This task contributes floating point roundoff errors to the probabilities |
| 484 | + probas_pred = clf.fit(X[::2], y[::2]).predict_proba(X[1::2]) |
| 485 | + probas_pred = probas_pred[:, :5].sum(axis=1) |
| 486 | + y_true = [yy < 5 for yy in y[1::2]] |
| 487 | + |
| 488 | + # Check for repeating values in the thresholds |
| 489 | + fpr, tpr, thresholds = roc_curve(y_true, probas_pred) |
| 490 | + assert_equal(thresholds.size, np.unique(np.round(thresholds, 2)).size) |
| 491 | + |
| 492 | + |
469 | 493 | def test_roc_curve_multi(): |
470 | 494 | """roc_curve not applicable for multi-class problems""" |
471 | 495 | y_true, _, probas_pred = make_prediction(binary=False) |
|
0 commit comments