Stratified train_test_split with binarized labels mixes up train and test data! #8670
Description
Hi scikit-learners,
I have a problem using the train_test_split with binarized labels in a multilabel setting. Specifically, I tried to use the stratify parameter to even out the data between the splits. I know (now) that stratified sampling in the multilabel setting is a tricky issue.
In my code I had encoded the labels with the MultiLabelBinarizer before calling train_test_split, and it didn't complain but returned the two sets of features and labels. However, it turns out that it had mixed up the samples, so that data points present in the training set are also present in the test set. It had also duplicated data points within each set.
Training a classifier on this data obviously gave far too optimistic scores.
I have made a simple example that illustrates the problem:
from sklearn.model_selection import train_test_split

xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
ys_bin = [0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0]
ys_multiclass = [0, 0, 1, 1, 2, 2, 3, 3, 1, 0, 0, 0, 1, 1, 2, 2, 3, 3, 1, 0]
ys_multiclass_bin = [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1],
                     [0, 1, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0],
                     [0, 0, 1], [0, 1, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0],
                     [0, 1, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0]]
ys_multilabel = [[2], [1, 2], [0, 2], [2], [1, 2],
                 [1], [0, 1], [0], [0, 2], [1],
                 [2], [1, 2], [0, 2], [2], [1, 2],
                 [1], [0, 1], [0], [0, 2], [1]]
ys_multilabel_bin = [[0, 0, 1], [0, 1, 1], [1, 0, 1], [0, 0, 1], [0, 1, 1],
                     [0, 1, 0], [1, 1, 0], [1, 0, 0], [1, 0, 1], [0, 1, 0],
                     [0, 0, 1], [0, 1, 1], [1, 0, 1], [0, 0, 1], [0, 1, 1],
                     [0, 1, 0], [1, 1, 0], [1, 0, 0], [1, 0, 1], [0, 1, 0]]

# binary
x_train, x_test, y_train, y_test = train_test_split(xs, ys_bin, train_size=0.5, stratify=ys_bin)
print("Binary stratification:")
print("training:")
print(x_train)
print(y_train)
print("test:")
print(x_test)
print(y_test)
print("overlapping:")
print(set(x_train).intersection(x_test))
print()
# multiclass
x_train, x_test, y_train, y_test = train_test_split(xs, ys_multiclass, train_size=0.5,
                                                    stratify=ys_multiclass)
print("Multiclass stratification:")
print("training:")
print(x_train)
print(y_train)
print("test:")
print(x_test)
print(y_test)
print("overlapping:")
print(set(x_train).intersection(x_test))
print()
# multiclass binary
x_train, x_test, y_train, y_test = train_test_split(xs, ys_multiclass_bin, train_size=0.5,
                                                    stratify=ys_multiclass_bin)
print("Multiclass binary stratification:")
print("training:")
print(x_train)
print(y_train)
print("test:")
print(x_test)
print(y_test)
print("overlapping:")
print(set(x_train).intersection(x_test))
print()
# multilabel binary
x_train, x_test, y_train, y_test = train_test_split(xs, ys_multilabel_bin, train_size=0.5,
                                                    stratify=ys_multilabel_bin)
print("Multilabel binary stratification:")
print("training:")
print(x_train)
print(y_train)
print("test:")
print(x_test)
print(y_test)
print("overlapping:")
print(set(x_train).intersection(x_test))
print()

And here is an example of the output:
Binary stratification:
training:
[15, 9, 3, 8, 11, 18, 5, 20, 7, 10]
[0, 1, 0, 1, 0, 1, 0, 0, 1, 0]
test:
[1, 14, 2, 19, 17, 6, 12, 4, 13, 16]
[0, 0, 0, 1, 1, 1, 0, 0, 0, 1]
overlapping:
set()
Multiclass stratification:
training:
[2, 9, 11, 14, 19, 18, 10, 6, 8, 15]
[0, 1, 0, 1, 1, 3, 0, 2, 3, 2]
test:
[12, 4, 7, 16, 20, 13, 3, 17, 1, 5]
[0, 1, 3, 2, 0, 1, 1, 3, 0, 2]
overlapping:
set()
Multiclass binary stratification:
training:
[6, 10, 1, 5, 9, 9, 4, 4, 3, 3]
[[0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0], [1, 0, 0]]
test:
[20, 8, 3, 5, 18, 2, 7, 16, 20, 11]
[[0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1]]
overlapping:
{3, 5}
Multilabel binary stratification:
training:
[13, 1, 5, 17, 20, 6, 7, 16, 14, 15]
[[1, 0, 1], [0, 0, 1], [0, 1, 1], [1, 1, 0], [0, 1, 0], [0, 1, 0], [1, 1, 0], [0, 1, 0], [0, 0, 1], [0, 1, 1]]
test:
[9, 2, 16, 20, 1, 13, 5, 18, 4, 8]
[[1, 0, 1], [0, 1, 1], [0, 1, 0], [0, 1, 0], [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 0, 0], [0, 0, 1], [1, 0, 0]]
overlapping:
{16, 1, 5, 20, 13}
One could argue that I should have checked whether stratification works in the multilabel setting, but I think it would be better if train_test_split raised an error or otherwise warned me that I am doing stupid things... :)
And it definitely shouldn't mix up training and test data in any case.
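Until that is the case, a user-side guard along these lines can catch the problem right after splitting. This is only a sketch: `check_split` is a name made up for illustration and is not part of scikit-learn, and it assumes the samples (or their ids) are hashable. The two calls reuse the binary and the multiclass binary splits from the output above.

```python
def check_split(train_ids, test_ids):
    """Raise ValueError if a split leaks or duplicates samples.

    Hypothetical helper for illustration, not a scikit-learn function.
    """
    train_ids, test_ids = list(train_ids), list(test_ids)
    overlap = set(train_ids) & set(test_ids)
    if overlap:
        raise ValueError(
            f"samples appear in both train and test: {sorted(overlap)}")
    for name, ids in (("train", train_ids), ("test", test_ids)):
        if len(set(ids)) != len(ids):
            raise ValueError(f"duplicated samples within the {name} set")


# Binary split from the output above: passes silently.
check_split([15, 9, 3, 8, 11, 18, 5, 20, 7, 10],
            [1, 14, 2, 19, 17, 6, 12, 4, 13, 16])

# Multiclass binary split from the output above: rejected.
try:
    check_split([6, 10, 1, 5, 9, 9, 4, 4, 3, 3],
                [20, 8, 3, 5, 18, 2, 7, 16, 20, 11])
except ValueError as err:
    print(err)  # samples appear in both train and test: [3, 5]
```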
As you can see, the problem also occurs with one-hot encoded multiclass labels, so the error can arise if you happen to use the LabelBinarizer before splitting or if your data is already one-hot encoded before splitting it up.
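A possible workaround, sketched here under a few assumptions, is to collapse each row of the indicator matrix into a single key and stratify on that 1D array instead of on the 2D labels. `rows_to_strata` is a hypothetical helper name, and the fixed `random_state` is only there to make the run repeatable. For one-hot rows the key is equivalent to the class. For multilabel rows it stratifies on each distinct label combination, which is just one of several ways to define strata in that setting.

```python
import numpy as np
from sklearn.model_selection import train_test_split


def rows_to_strata(y_indicator):
    """Map each row of a 2D 0/1 label matrix to one string key.

    Hypothetical helper for illustration, not a scikit-learn function.
    """
    y_indicator = np.asarray(y_indicator)
    return np.array([" ".join(map(str, row)) for row in y_indicator])


xs = list(range(1, 21))
# Same values as ys_multilabel_bin in the example (rows 11-20 repeat rows 1-10).
ys_multilabel_bin = [[0, 0, 1], [0, 1, 1], [1, 0, 1], [0, 0, 1], [0, 1, 1],
                     [0, 1, 0], [1, 1, 0], [1, 0, 0], [1, 0, 1], [0, 1, 0]] * 2

strata = rows_to_strata(ys_multilabel_bin)  # 1D array, one key per sample
x_train, x_test, y_train, y_test = train_test_split(
    xs, ys_multilabel_bin, train_size=0.5, stratify=strata, random_state=0)

print("overlapping:")
print(set(x_train).intersection(x_test))
```

Note that stratifying on combinations requires every combination to occur at least twice. With rare label combinations, train_test_split raises an error because the least populated stratum is too small, so this approach does not scale to sparse multilabel data.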
Best regards,
Andreas