[MRG + 1] FIX Validate and convert X, y and groups to ndarray before splitting #7593
jnothman merged 18 commits into scikit-learn:master from
sklearn/model_selection/_split.py
Outdated
        return '%s(%s)' % (class_name, _pprint(params, offset=len(class_name)))


    def _check_X_y_groups(X, y, groups):
Should this reside inside utils.validation?
probably. Is the same applicable for sample_weights? What do we usually do with sample_weights?
We might just write check_X_y and then do a check_consistent_length(X, sample_weights) and check_array(sample_weights).
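The order of operations suggested here (validate the main inputs first, then check that the weights line up with them) can be sketched in plain Python. `validate_with_weights` and its error messages are hypothetical stand-ins written for illustration; they are not scikit-learn's `check_X_y`, `check_consistent_length` or `check_array`.

```python
def validate_with_weights(X, y, sample_weights=None):
    """Hypothetical stand-in: check that X, y and optional weights agree in length."""
    X, y = list(X), list(y)
    if len(X) != len(y):
        raise ValueError("X and y have inconsistent lengths: %d != %d"
                         % (len(X), len(y)))
    if sample_weights is not None:
        # Convert the weights only after X and y are known to be consistent.
        sample_weights = [float(w) for w in sample_weights]
        if len(sample_weights) != len(X):
            raise ValueError("sample_weights has %d entries, expected %d"
                             % (len(sample_weights), len(X)))
    return X, y, sample_weights


X, y, w = validate_with_weights([[0.0], [1.0]], [0, 1], sample_weights=(1, 2))
```

Whether the real helper should live in `utils.validation`, and how loose the checks on the weights should be, is the open question in the comment above; the sketch takes no position on it.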
d45b75c to ff5f379
sklearn/model_selection/_split.py
Outdated
                        allow_nd=True)
        check_consistent_length(X, y)
        if groups is not None:
            groups = check_array(groups, accept_sparse=['coo', 'csr', 'csc'],
groups can be infinite? and sparse? and nd? Is that tested? ;)
sklearn/model_selection/_split.py
Outdated
                        dtype=None, force_all_finite=False, ensure_2d=False,
                        allow_nd=True)
        if y is not None:
            y = check_array(y, accept_sparse=['coo', 'csr', 'csc'],
Same for y. Are these tested? Should they be? I guess we should be as loose as possible with the test as long as the cross-validation classes work.
There is a test for And we could do the
76027d5 to 578442b
    def test_shufflesplit_list_input():
        # Check that when y is a list / list of string labels, it works.
        ss = ShuffleSplit(random_state=42)
shouldn't that be StratifiedShuffleSplit?
    def _iter_indices(self, X, y, groups):
        if groups is None:
            raise ValueError("The groups parameter should not be None")
        groups = check_array(groups, ensure_2d=False, dtype=None)
How about GroupKFold, LeaveOneGroupOut, LeavePGroupsOut?
Fixed... Thanks for the catch!!
Argh. There seemed to have been no tests for
f117a07 to 13f1e95
        if groups is None:
            raise ValueError("The groups parameter should not be None")
        X, y, groups = indexable(X, y, groups)
        groups = check_array(groups, ensure_2d=False, dtype=None)
I'd do it the other way around, I think.
check_array followed by indexable?
amueller left a comment
Looks good apart from some nitpicks.
    for j, (cv, p_groups_out) in enumerate(((logo, 1), (lpgo_1, 1),
                                            (lpgo_2, 2))):
        groups = (np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
do we want these to be file-level constants?
    logo = LeaveOneGroupOut()
    lpgo_1 = LeavePGroupsOut(n_groups=1)
    lpgo_2 = LeavePGroupsOut(n_groups=2)
    lpgo_3 = LeavePGroupsOut(n_groups=3)
for this one you only test the repr, right?
              [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
              ['1', '1', '1', '1', '2', '2', '2', '3', '3', '3', '3', '3'])

    all_n_splits = np.array([[3, 3, 3],
why do you hard-code it like this? that seems hard to validate. It's just scipy.misc.comb(len(np.unique(groups_i)), p_groups_out) right?
> It's just scipy.misc.comb(len(np.unique(groups_i)), p_groups_out) right

That is the implementation in _split.py. I thought it would be better to compare it against hand-calculated values?
Hm. The correctness of your "hand calculated values" is not immediately obvious to me.
How about

    n_groups = len(np.unique(groups_i))
    n_splits = n_groups if p_groups_out == 1 else n_groups * (n_groups - 1) / 2

? But I'm also fine leaving it like it is.
Why is all_n_splits of length 7 when groups is of length 6? (or github shows me a weird diff)
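For reference, the two formulations discussed above (a binomial coefficient versus the closed form for one or two held-out groups) agree on the group labels used in the test. The sketch below uses only the standard library: `math.comb` (Python 3.8+) stands in for `scipy.misc.comb`, and `expected_n_splits` is a hypothetical helper name, not part of the PR.

```python
from math import comb  # Python 3.8+; stands in for scipy.misc.comb here


def expected_n_splits(groups, p_groups_out):
    """Number of ways to choose p_groups_out distinct groups to hold out."""
    n_groups = len(set(groups))
    return comb(n_groups, p_groups_out)


groups_int = [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]
groups_str = ['1', '1', '1', '1', '2', '2', '2', '3', '3', '3', '3', '3']

for groups in (groups_int, groups_str):
    n_groups = len(set(groups))  # 3 distinct labels in both cases
    # Closed form suggested in the review, for p_groups_out in (1, 2).
    assert expected_n_splits(groups, 1) == n_groups
    assert expected_n_splits(groups, 2) == n_groups * (n_groups - 1) // 2
```

With three distinct groups, both `p_groups_out = 1` and `p_groups_out = 2` give 3 splits, which matches the `[3, 3, 3]` row visible in the hard-coded `all_n_splits`.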
    # First test: no train group is in the test set and vice versa
    grps_train_unique = np.unique(groups_arr[train])
    grps_test_unique = np.unique(groups_arr[test])
    assert_false(np.any(np.in1d(groups_arr[train],
why not test the intersection is empty?
assert_equal(set(groups_arr[train]).intersection(groups_arr[test]), set())
(or intersect1d if you prefer)
Wait that is already done in the next 2 lines...
The third test checks whether the indices are disjoint; my code checks whether the groups are disjoint.
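The distinction matters because disjoint indices do not imply disjoint groups. A minimal standard-library sketch follows; `leave_p_groups_out` is a hypothetical pure-Python generator written for illustration, not the `LeavePGroupsOut` class under review.

```python
from itertools import combinations


def leave_p_groups_out(groups, p):
    """Yield (train, test) index lists, holding out every combination of p groups."""
    unique = sorted(set(groups))
    for held_out in combinations(unique, p):
        test = [i for i, g in enumerate(groups) if g in held_out]
        train = [i for i, g in enumerate(groups) if g not in held_out]
        yield train, test


groups = [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]
splits = list(leave_p_groups_out(groups, 2))

for train, test in splits:
    train_groups = {groups[i] for i in train}
    test_groups = {groups[i] for i in test}
    assert train_groups.isdisjoint(test_groups)  # group-level check
    assert set(train).isdisjoint(test)           # index-level check (weaker)

# Indices can be disjoint while the groups still overlap:
bad_train, bad_test = [0, 1], [2, 3]  # all four samples belong to group 1
indices_disjoint = set(bad_train).isdisjoint(bad_test)
groups_disjoint = {groups[i] for i in bad_train}.isdisjoint(
    {groups[i] for i in bad_test})
```

For the bad split, `indices_disjoint` is true while `groups_disjoint` is false, so only the group-level assertion would catch a splitter that leaks a group into both sides.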
                                grps_train_unique)))

    # Second test: train and test add up to all the data
    assert_equal(groups_arr[train].size +
len(train) + len(test) = len(groups)?
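A size check of this form is cheap, but on its own it only shows that the counts match. A small sketch of why it is usually paired with a disjointness check; `sizes_add_up` and `is_partition` are hypothetical helper names used only here.

```python
def sizes_add_up(train, test, n_samples):
    """The check suggested above: the counts account for every sample."""
    return len(train) + len(test) == n_samples


def is_partition(train, test, n_samples):
    """Stricter: the counts match, nothing is shared, and every index is covered."""
    return (sizes_add_up(train, test, n_samples)
            and set(train).isdisjoint(test)
            and set(train) | set(test) == set(range(n_samples)))


n_samples = 4
good = ([0, 1], [2, 3])
# Sizes sum to 4, but index 2 appears on both sides and index 3 is missing.
overlapping = ([0, 1, 2], [2])
```

Here `sizes_add_up(*overlapping, n_samples)` passes while `is_partition(*overlapping, n_samples)` does not, which is why the size test complements the disjointness tests rather than replacing them.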
lgtm apart from minor comments
13f1e95 to fce36af
b5d1fe3 to 44f6db6
travis fails?
Sorry about that. Should be fixed now...
0516776 to 1ca13d1
        np.testing.assert_equal(y_train2, y_train3)
        np.testing.assert_equal(X_test1, X_test3)
        np.testing.assert_equal(y_test3, y_test2)
    for stratify in ((y1, y2, y3), (None, None, None)):

Apologies for the delay! Have rebased and added the test... Could you check if it's okay?
    for stratify in ((y1, y2, y3), (None, None, None)):
        X_train1, X_test1, y_train1, y_test1 = train_test_split(
            X, y1, stratify=stratify[0], random_state=0)
I think stratify=y1 if stratify else None would be more readable (where stratify in (True, False) is iterated)
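The suggested loop shape, iterating over a boolean flag and selecting the argument inline, might look like the sketch below. `fake_split` is a hypothetical stand-in that only reports how it was called; it is not `train_test_split`, and the sample labels are made up for illustration.

```python
def fake_split(y, stratify=None):
    """Hypothetical stand-in for train_test_split; reports the mode it was called in."""
    return "stratified" if stratify is not None else "unstratified"


y1 = [0, 1, 0, 1]  # illustrative labels, not from the PR

results = []
for stratify in (True, False):
    # The flag decides whether the labels are passed as the stratify argument.
    results.append(fake_split(y1, stratify=y1 if stratify else None))
```

Compared with iterating over tuples such as `(y1, y2, y3)` and `(None, None, None)`, the flag makes it obvious at the call site which labels are used for stratification.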
(Maybe we should allow
Thanks for the patient review and merge!
needs a whatsnew maybe?
Fixes #7582 and #7126
At sklearn 0.18.0
That is fixed after this PR.
This PR also cleans up some docstrings and adds tests for LeavePGroupsOut and LeaveOneGroupOut...
@jnothman @amueller @lesteve Reviews please :)
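One way list input can go wrong is sketched below in plain Python. It assumes a splitter derives its test mask by comparing `groups` with a label using `==`; that is an assumption made for illustration, not something this thread confirms. `mask_for_label` is a hypothetical helper.

```python
groups_list = [1, 1, 2, 2, 3]

# A plain list compares as a whole object, so this yields one bool, not a mask.
whole_object_result = (groups_list == 1)


def mask_for_label(groups, label):
    """Elementwise comparison, which an ndarray performs natively with ==."""
    return [g == label for g in groups]


mask = mask_for_label(groups_list, 1)
```

Converting `groups` to an ndarray before splitting, as the PR title describes, gives the elementwise behaviour without each splitter having to handle lists itself.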