predict fails for multioutput ensemble models with non-numeric DVs #12831
Closed
Description
Multioutput forest models assume that the dependent variables (DVs) are numeric. Calling predict on a model fit with string DVs raises the following error:
ValueError: could not convert string to float:
I'm going to take a stab at submitting a fix today, but I wanted to file an issue to document the problem in case I'm not able to finish a fix.
Steps/Code to Reproduce
I wrote a test based on ensemble/tests/test_forest:test_multioutput which currently fails:
def check_multioutput_string(name):
    # Check estimators on multi-output problems with string outputs.
    X_train = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-2, 1],
               [-1, 1], [-1, 2], [2, -1], [1, -1], [1, -2]]
    y_train = [["red", "blue"], ["red", "blue"], ["red", "blue"], ["green", "green"],
               ["green", "green"], ["green", "green"], ["red", "purple"],
               ["red", "purple"], ["red", "purple"], ["green", "yellow"],
               ["green", "yellow"], ["green", "yellow"]]
    X_test = [[-1, -1], [1, 1], [-1, 1], [1, -1]]
    y_test = [["red", "blue"], ["green", "green"], ["red", "purple"], ["green", "yellow"]]

    est = FOREST_ESTIMATORS[name](random_state=0, bootstrap=False)
    y_pred = est.fit(X_train, y_train).predict(X_test)
    # Exact comparison: assert_array_almost_equal cannot compare string labels.
    assert_array_equal(y_pred, y_test)

    if name in FOREST_CLASSIFIERS:
        with np.errstate(divide="ignore"):
            # One probability array per output, one column per class.
            proba = est.predict_proba(X_test)
            assert_equal(len(proba), 2)
            assert_equal(proba[0].shape, (4, 2))
            assert_equal(proba[1].shape, (4, 4))

            log_proba = est.predict_log_proba(X_test)
            assert_equal(len(log_proba), 2)
            assert_equal(log_proba[0].shape, (4, 2))
            assert_equal(log_proba[1].shape, (4, 4))


@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
@pytest.mark.parametrize('name', FOREST_CLASSIFIERS_REGRESSORS)
def test_multioutput_string(name):
    check_multioutput_string(name)
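The same failure can also be checked outside the test suite. The script below is a minimal sketch, not part of the original report: it assumes RandomForestClassifier as one representative forest estimator and picks n_estimators=10 arbitrarily. It records whether predict succeeds or raises the ValueError described above, so it runs on both affected and unaffected versions.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Same data as the test above: two string-valued outputs per sample.
X_train = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-2, 1],
           [-1, 1], [-1, 2], [2, -1], [1, -1], [1, -2]]
y_train = [["red", "blue"], ["red", "blue"], ["red", "blue"],
           ["green", "green"], ["green", "green"], ["green", "green"],
           ["red", "purple"], ["red", "purple"], ["red", "purple"],
           ["green", "yellow"], ["green", "yellow"], ["green", "yellow"]]
X_test = [[-1, -1], [1, 1], [-1, 1], [1, -1]]

# Assumption: RandomForestClassifier stands in for the forest estimators;
# n_estimators=10 is an arbitrary choice for this sketch.
est = RandomForestClassifier(n_estimators=10, random_state=0, bootstrap=False)
est.fit(X_train, y_train)

try:
    y_pred = est.predict(X_test)
    outcome = "ok"
    print(y_pred)
except ValueError as exc:
    # This is the branch reported in this issue.
    y_pred = None
    outcome = "ValueError"
    print("predict failed:", exc)
```

If the script prints "predict failed", the installed version is affected by this issue.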
Expected Results
No error is raised, and predict runs for all multioutput ensemble models.
Actual Results
ValueError: could not convert string to float: <DV class>
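On affected versions, one possible workaround is to encode the string DVs as numeric codes before fitting and decode the predictions afterwards. The sketch below is an assumption, not something proposed in the original report: it uses OrdinalEncoder from sklearn.preprocessing to encode each output column, and again uses RandomForestClassifier with an arbitrary n_estimators=10.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import OrdinalEncoder

X_train = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-2, 1],
           [-1, 1], [-1, 2], [2, -1], [1, -1], [1, -2]]
y_train = [["red", "blue"], ["red", "blue"], ["red", "blue"],
           ["green", "green"], ["green", "green"], ["green", "green"],
           ["red", "purple"], ["red", "purple"], ["red", "purple"],
           ["green", "yellow"], ["green", "yellow"], ["green", "yellow"]]
X_test = [[-1, -1], [1, 1], [-1, 1], [1, -1]]
y_test = [["red", "blue"], ["green", "green"], ["red", "purple"], ["green", "yellow"]]

# Encode each output column independently as numeric class codes.
encoder = OrdinalEncoder()
y_train_codes = encoder.fit_transform(y_train)

# The forest only ever sees numeric targets, which avoids the string conversion.
est = RandomForestClassifier(n_estimators=10, random_state=0, bootstrap=False)
est.fit(X_train, y_train_codes)

# Map the predicted codes back to the original string labels.
y_pred = encoder.inverse_transform(est.predict(X_test))
print(y_pred)
```

The decoded predictions can then be compared against y_test directly. This only sidesteps the problem in user code; it does not replace a fix in the estimators themselves.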
Versions
I replicated this error using the current master branch of sklearn (0.21.dev0).