@@ -61,22 +61,28 @@ def test_ovr_fit_predict():
6161def test_ovr_fit_predict_sparse ():
6262 for sparse in [sp .csr_matrix , sp .csc_matrix , sp .coo_matrix , sp .dok_matrix ,
6363 sp .lil_matrix ]:
64- # A classifier which implements decision_function.
65- ovr = OneVsRestClassifier (LinearSVC (random_state = 0 ))
66- pred = ovr .fit (iris .data ,
67- sparse (iris .target )).predict (sparse (iris .data ))
68- assert_equal (len (ovr .estimators_ ), n_classes )
69- assert_true (sp .issparse (pred ))
70-
71- clf = LinearSVC (random_state = 0 )
72- pred2 = clf .fit (iris .data , iris .target ).predict (iris .data )
73- assert_equal (np .mean (iris .target == pred .toarray ()),
74- np .mean (iris .target == pred2 ))
75-
76- # A classifier which implements predict_proba.
77- ovr = OneVsRestClassifier (MultinomialNB ())
78- pred = ovr .fit (iris .data , iris .target ).predict (iris .data )
79- assert_greater (np .mean (iris .target == pred ), 0.65 )
64+ base_clf = MultinomialNB (alpha = 1 )
65+ for au , prec , recall in zip ((True , False ), (0.65 , 0.74 ), (0.72 , 0.84 )):
66+ make_mlb = datasets .make_multilabel_classification
67+ X , Y = make_mlb (n_samples = 100 ,
68+ n_features = 20 ,
69+ n_classes = 5 ,
70+ n_labels = 2 ,
71+ length = 50 ,
72+ allow_unlabeled = au ,
73+ return_indicator = True ,
74+ random_state = 0 )
75+
76+ X_train , Y_train = X [:80 ], Y [:80 ]
77+ X_test , Y_test = X [80 :], Y [80 :]
78+ clf = OneVsRestClassifier (base_clf ).fit (X_train , sparse (Y_train ))
79+ Y_pred = clf .predict (X_test )
80+
81+ assert_true (clf .multilabel_ )
82+ assert_almost_equal (precision_score (Y_test , Y_pred .toarray (),
83+ average = "micro" ), prec , decimal = 2 )
84+ assert_almost_equal (recall_score (Y_test , Y_pred .toarray (),
85+ average = "micro" ), recall , decimal = 2 )
8086
8187
8288def test_ovr_always_present ():
0 commit comments