Skip to content

Commit 31eafdd

Browse files
committed
Warnings related to deprecation of pos_label
1 parent a852787 commit 31eafdd

2 files changed

Lines changed: 70 additions & 2 deletions

File tree

sklearn/metrics/metrics.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1394,7 +1394,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels='compat',
13941394
When `labels` is None, all labels in `y_true` and `y_pred` are used in
13951395
sorted order. By default, binary classification is handled specially
13961396
for backwards compatibility, but this feature will be removed in
1397-
version 0.16.
1397+
release 0.16.
13981398
13991399
average : string, [None (default), 'micro', 'macro', 'samples', 'weighted']
14001400
If ``None``, the scores for each class are returned. Otherwise,
@@ -1471,9 +1471,19 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels='compat',
14711471
y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred)
14721472
present_labels = unique_labels(y_true, y_pred)
14731473

1474+
if pos_label != '!deprecated':
1475+
warnings.warn('The `pos_label` parameter to precision, recall and '
1476+
'F-score is deprecated, and will be removed in release '
1477+
'0.16. The `labels` parameter may be used instead.',
1478+
DeprecationWarning)
1479+
14741480
if not isinstance(labels, np.ndarray) and labels == 'compat':
14751481
if y_type == 'binary' and (average is not None and
14761482
pos_label is not None):
1483+
warnings.warn('From release 0.16, binary classification will not '
1484+
'be handled specially for precision, recall and '
1485+
'F-score. Instead, specify a single positive label '
1486+
'with the `labels` parameter.', FutureWarning)
14771487

14781488
if pos_label == '!deprecated':
14791489
pos_label = 1
@@ -1491,6 +1501,14 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels='compat',
14911501
else:
14921502
n_labels = len(labels)
14931503
labels = np.hstack([labels, np.setdiff1d(present_labels, labels)])
1504+
if n_labels == 2 and len(labels) == 2 and (pos_label is not None and
1505+
average is not None):
1506+
warnings.warn('Precision, recall and F-score behaviour has '
1507+
'changed: providing two classes to the `labels` '
1508+
'parameter no longer returns results only for the '
1509+
'positive label. Use `labels=[positive_label]` for '
1510+
'former behaviour, or `labels=None` for all labels '
1511+
'present in the data to be considered equally.')
14941512

14951513
### Calculate tp_sum, pred_sum, true_sum ###
14961514

sklearn/metrics/tests/test_metrics.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,7 @@ def test_auc_score_non_binary_class():
553553
y_pred)
554554

555555

556+
@ignore_warnings
556557
def test_precision_recall_f1_score_binary():
557558
"""Test Precision Recall and F1 Score for binary classification task"""
558559
y_true, y_pred, _ = make_prediction(binary=True)
@@ -679,6 +680,7 @@ def test_average_precision_score_tied_values():
679680
assert_not_equal(average_precision_score(y_true, y_score), 1.)
680681

681682

683+
@ignore_warnings
682684
def test_precision_recall_fscore_support_errors():
683685
y_true, y_pred, _ = make_prediction(binary=True)
684686

@@ -737,7 +739,8 @@ def test_precision_recall_f1_score_multiclass():
737739
assert_array_equal(s, [24, 31, 20])
738740

739741
# averaging tests
740-
ps = precision_score(y_true, y_pred, pos_label=1, average='micro')
742+
ps = assert_warns(DeprecationWarning, precision_score,
743+
y_true, y_pred, pos_label=1, average='micro')
741744
assert_array_almost_equal(ps, 0.53, 2)
742745

743746
rs = recall_score(y_true, y_pred, average='micro')
@@ -780,6 +783,7 @@ def test_precision_recall_f1_score_multiclass():
780783
assert_array_equal(s, [24, 20, 31])
781784

782785

786+
@ignore_warnings
783787
def test_precision_recall_f1_score_multiclass_pos_label_none():
784788
"""Test Precision Recall and F1 Score for multiclass classification task
785789
@@ -1118,6 +1122,7 @@ def test_r2_one_case_error():
11181122
assert_raises(ValueError, r2_score, [0], [0])
11191123

11201124

1125+
@ignore_warnings
11211126
def test_symmetry():
11221127
"""Test the symmetry of score and loss functions"""
11231128
y_true, y_pred, _ = make_prediction(binary=True)
@@ -1155,6 +1160,7 @@ def test_symmetry():
11551160
zero_one_score(y_pred, y_true))
11561161

11571162

1163+
@ignore_warnings
11581164
def test_sample_order_invariance():
11591165
y_true, y_pred, _ = make_prediction(binary=True)
11601166

@@ -1169,6 +1175,7 @@ def test_sample_order_invariance():
11691175
% name)
11701176

11711177

1178+
@ignore_warnings
11721179
def test_format_invariance_with_1d_vectors():
11731180
y1, y2, _ = make_prediction(binary=True)
11741181

@@ -1243,6 +1250,7 @@ def test_format_invariance_with_1d_vectors():
12431250
assert_raises(ValueError, metric, y1_row, y2_row)
12441251

12451252

1253+
@ignore_warnings
12461254
def test_invariance_string_vs_numbers_labels():
12471255
"""Ensure that classification metrics with string labels"""
12481256
y1, y2, _ = make_prediction(binary=True)
@@ -1285,6 +1293,7 @@ def test_invariance_string_vs_numbers_labels():
12851293
assert_raises(ValueError, metrics, y1_str, y2_str)
12861294

12871295

1296+
@ignore_warnings
12881297
def test_clf_single_sample():
12891298
"""Non-regression test: scores should work with a single sample.
12901299
@@ -1948,6 +1957,47 @@ def test_prf_warnings():
19481957
'being set to 0.0 due to no true samples.')
19491958

19501959

1960+
def test_prf_pos_label_deprecation_warnings():
1961+
with warnings.catch_warnings(record=True) as record:
1962+
warnings.simplefilter('always')
1963+
# need deprecation warning as long as pos_label is explicitly set
1964+
recall_score([1, 2, 3, 2], [2, 2, 1, 3], pos_label=None)
1965+
assert_equal(str(record.pop().message),
1966+
'The `pos_label` parameter to precision, recall and '
1967+
'F-score is deprecated, and will be removed in release '
1968+
'0.16. The `labels` parameter may be used instead.')
1969+
recall_score([1, 2, 3, 2], [2, 2, 1, 3], pos_label=1)
1970+
assert_equal(str(record.pop().message),
1971+
'The `pos_label` parameter to precision, recall and '
1972+
'F-score is deprecated, and will be removed in release '
1973+
'0.16. The `labels` parameter may be used instead.')
1974+
1975+
# warning that default binary behaviour will be removed in the future
1976+
recall_score([1, 2, 1], [2, 2, 1], average='macro')
1977+
assert_equal(str(record.pop().message),
1978+
'From release 0.16, binary classification will not be '
1979+
'handled specially for precision, recall and F-score. '
1980+
'Instead, specify a single positive label with the '
1981+
'`labels` parameter.')
1982+
1983+
# but no warning for the follwing
1984+
recall_score([1, 2, 1], [2, 2, 1], average=None)
1985+
assert_equal(len(record), 0)
1986+
recall_score([1, 2, 1], [2, 2, 1], labels=[2], average='macro')
1987+
assert_equal(len(record), 0)
1988+
1989+
# warning that behaviour has changed when labels is specified as binary
1990+
# for binary data, with pos_label non-None and average non-None
1991+
recall_score([1, 2, 1], [2, 2, 1], labels=[1, 2], average='macro')
1992+
assert_equal(str(record.pop().message),
1993+
'Precision, recall and F-score behaviour has changed: '
1994+
'providing two classes to the `labels` parameter no '
1995+
'longer returns results only for the positive label. '
1996+
'Use `labels=[positive_label]` for former behaviour, '
1997+
'or `labels=None` for all labels present in the data '
1998+
'to be considered equally.')
1999+
2000+
19512001
def test__check_clf_targets():
19522002
"""Check that _check_clf_targets correctly merges target types, squeezes
19532003
output and fails if input lengths differ."""

0 commit comments

Comments
 (0)