[MRG] Fixes tree and forest classification for non-numeric multi-target #11458
adrinjalali merged 5 commits into scikit-learn:master
Any update on this PR?
@pytest.mark.parametrize('name', FOREST_CLASSIFIERS)
@pytest.mark.parametrize('oob_score', (True, False))
def test_multi_target(name, oob_score):
    check_multi_target(name, oob_score)
any reason for not having the body of the check_multi_target function directly here?
sklearn/tree/tests/test_tree.py
@pytest.mark.parametrize('name', CLF_TREES)
def test_multi_target(name):
    check_multi_target(name)
sklearn/tree/tree.py
axis=0))

- return predictions
+ return np.array(predictions).T
This would try to infer the dtype of the array, right? How much slower is it than the status quo?
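To make the dtype question concrete, here is a minimal sketch comparing the two ways of assembling the prediction array. The `per_output` columns are made-up stand-ins for per-target predictions, and taking the dtype from the first column is an assumption for illustration, not necessarily what the PR does.

```python
import numpy as np

# Hypothetical per-target predictions (one 1-D array per output).
per_output = [
    np.array(["foo", "bar", "foo"]),
    np.array(["a", "b", "a"]),
]

# Variant under review: stack the columns and let NumPy infer a common dtype.
inferred = np.array(per_output).T

# Alternative: preallocate with an explicit dtype and fill column by column.
# Using the first column's dtype is an assumption made for this sketch.
n_samples = per_output[0].shape[0]
preallocated = np.zeros((n_samples, len(per_output)), dtype=per_output[0].dtype)
for k, column in enumerate(per_output):
    preallocated[:, k] = column

print(inferred.dtype, preallocated.dtype)
```

Both variants give the same values here. With a fixed-width string dtype taken from one column, longer strings in another column would be silently truncated on assignment, so the choice of dtype matters.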
You also need to rebase/merge master, you've got conflicts. Other than that, I'm really not sure if this is a good idea. How many other estimators do we have that support string outputs? I suppose the recommended way is to convert the values before feeding them to estimators. I may be wrong.
We support string targets where they are 1d (i.e. single target). I'm not entirely against supporting string labels in multi output, but it should be done by making sure that all estimators with multi output multiclass support, and any metrics, support this case. Let alone the case of mixed numeric and string data. At the moment I can't see that we test multi output multiclass in common tests at all.
jnothman left a comment
What does your implementation do with a mix of string and numeric targets?
e215890 to b26e943
5c5ecff to fcd597a
I've refactored the tests and the code a bit. The PR is also rebased onto master.
Returns:
Doing the same with a regressor would fail, as targets need to be numerical. The training target array is upcast to a single dtype, and we get an array with the same dtype back.
A way to support a real mix of dtypes would be a structured array, but I don't know if we really want to do that.
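As a rough illustration of the upcasting described above, the sketch below builds a mixed string/numeric target both as a plain 2-D array and as a structured array. The `rows` data and the field names `label` and `flag` are hypothetical.

```python
import numpy as np

# Hypothetical mixed targets: a string label and a numeric label per sample.
rows = [("foo", 1), ("bar", 0), ("foo", 1)]

# A plain 2-D array holds a single dtype, so the numeric column becomes strings.
upcast = np.array(rows)

# A structured array keeps a separate dtype per field (field names are made up).
structured = np.array(rows, dtype=[("label", "U3"), ("flag", "i8")])

print(upcast.dtype.kind)         # 'U': everything is stored as text
print(structured["flag"].dtype)  # the numeric field stays an integer
```

The structured array is 1-D with named fields rather than 2-D, so it would not be a drop-in replacement for the `(n_samples, n_outputs)` target layout.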
# Make multi-target.
ys = np.hstack([y, y])

# Try to fit and predict.

y = np.array(['foo' if v else 'bar' for v in y]).reshape((y.shape[0], 1))

# Make multi-target.
ys = np.hstack([y, y])
try with a string and a numerical column just to be on the safe side in the test?
This looks good to me. Since Joel mentioned it, could you please try adding the same test to the common tests?
Thanks for the quick response, @adrinjalali!
jnothman left a comment
I would support doing common tests in a separate PR, ideally remembering to remove these tests as redundant.
Thanks @adrinjalali, @mitar and @jnothman.
…scikit-learn#11458)
* Fixes tree and forest classification for non-numeric multi-target. Fixes scikit-learn#11451.
* Renaming test functions, adding dtype to predictions array in tree.py.
* Fixing flake8 issue.
* Adding ignore warning to test_forest.py.
* Switching to iris data for tests.
…i-target (scikit-learn#11458)" This reverts commit f95ffe6.
Fixes #11451.
This fixes the issue where tree and forest classifiers can fit, but cannot predict, non-numeric targets when there are multiple targets.
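A minimal sketch of the behaviour this PR is after, assuming a scikit-learn version that includes the fix. The toy `X` and `y` below are made up for illustration; only `DecisionTreeClassifier` is shown, and forests are not exercised here.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Toy data (hypothetical): four samples, two string-valued targets each.
X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([
    ["foo", "a"],
    ["foo", "b"],
    ["bar", "a"],
    ["bar", "b"],
])

clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# With the fix, predict returns a (n_samples, n_outputs) array of strings.
pred = clf.predict(X)
print(pred)
```

Because every sample has a distinct feature value, an unpruned tree can reproduce the training targets exactly, which makes the result easy to check by eye.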