Stratified train_test_split with binarized labels mixes up train and test data! #8670

@andreas-hjortgaard

Hi scikit-learners,

I have a problem using train_test_split with binarized labels in a multilabel setting. Specifically, I tried to use the stratify parameter to balance the label distribution between the splits. I know (now) that stratified sampling in the multilabel setting is a tricky issue.

In my code I had encoded the labels with the MultiLabelBinarizer before calling train_test_split, and it didn't complain but returned the two sets of features and labels. However, it turns out that it has mixed up the samples: data points present in the training set are also present in the test set, and some data points are duplicated within each set.

Training a classifier on this data obviously gave overly optimistic figures, since part of the test data had also been seen during training.
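For comparison, one simple (and limited) way to stratify binarized multilabel data is to treat each distinct combination of labels as its own class. This is a sketch under my own assumptions, not something train_test_split is documented to do for you, and the helper names `row_keys` and `check_min_class_size` are hypothetical. It turns every binarized row into one string key and checks that each key occurs at least twice, because stratified splitting needs at least two samples per class. With many labels most combinations occur only once, so this only works for small label spaces.

```python
from collections import Counter


def row_keys(y):
    """Map each binarized label row to one string key, e.g. [0, 1, 1] -> "0 1 1".

    Rows with the same combination of labels get the same key, so the keys
    form an ordinary 1-D (multiclass) label vector.
    """
    return [" ".join(str(v) for v in row) for row in y]


def check_min_class_size(keys, minimum=2):
    """Return the key counts; raise ValueError if any key is too rare."""
    counts = Counter(keys)
    rare = {k: n for k, n in counts.items() if n < minimum}
    if rare:
        raise ValueError(f"label combinations with too few samples: {rare}")
    return counts


# ys_multilabel_bin from the example in this issue
ys_multilabel_bin = [[0, 0, 1], [0, 1, 1], [1, 0, 1], [0, 0, 1], [0, 1, 1],
                     [0, 1, 0], [1, 1, 0], [1, 0, 0], [1, 0, 1], [0, 1, 0],
                     [0, 0, 1], [0, 1, 1], [1, 0, 1], [0, 0, 1], [0, 1, 1],
                     [0, 1, 0], [1, 1, 0], [1, 0, 0], [1, 0, 1], [0, 1, 0]]

keys = row_keys(ys_multilabel_bin)
counts = check_min_class_size(keys)
print(counts)
# The 1-D keys could then be used for stratification, e.g.:
# train_test_split(xs, ys_multilabel_bin, train_size=0.5, stratify=keys)
```

For these labels, all six distinct combinations occur two or four times, so the size check passes.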

I have made a simple example that illustrates the problem:

from sklearn.model_selection import train_test_split

# Sample identifiers: each value occurs exactly once, so any value that shows
# up in both the training and the test split is leaked data.
xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
# Binary labels (one label per sample)
ys_bin = [0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0]
# Multiclass labels (one of the classes 0-3 per sample)
ys_multiclass = [0, 0, 1, 1, 2, 2, 3, 3, 1, 0, 0, 0, 1, 1, 2, 2, 3, 3, 1, 0]
# One-hot encoded multiclass labels (three classes, exactly one 1 per row)
ys_multiclass_bin = [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1],
                     [0, 1, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0],
                     [0, 0, 1], [0, 1, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0],
                     [0, 1, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0]]

# Multilabel labels: a list of labels per sample
ys_multilabel = [[2], [1, 2], [0, 2], [2], [1, 2],
                 [1], [0, 1], [0], [0, 2], [1],
                 [2], [1, 2], [0, 2], [2], [1, 2],
                 [1], [0, 1], [0], [0, 2], [1]]

# ys_multilabel encoded as by MultiLabelBinarizer (one column per label)
ys_multilabel_bin = [[0, 0, 1], [0, 1, 1], [1, 0, 1], [0, 0, 1], [0, 1, 1],
                     [0, 1, 0], [1, 1, 0], [1, 0, 0], [1, 0, 1], [0, 1, 0],
                     [0, 0, 1], [0, 1, 1], [1, 0, 1], [0, 0, 1], [0, 1, 1],
                     [0, 1, 0], [1, 1, 0], [1, 0, 0], [1, 0, 1], [0, 1, 0]]

# binary
x_train, x_test, y_train, y_test = train_test_split(xs, ys_bin, train_size=0.5, stratify=ys_bin)
print("Binary stratification:")
print("training:")
print(x_train)
print(y_train)
print("test:")
print(x_test)
print(y_test)
print("overlapping:")
print(set(x_train).intersection(x_test))
print()

# multiclass
x_train, x_test, y_train, y_test = train_test_split(xs, ys_multiclass, train_size=0.5,
                                                    stratify=ys_multiclass)
print("Multiclass stratification:")
print("training:")
print(x_train)
print(y_train)
print("test:")
print(x_test)
print(y_test)
print("overlapping:")
print(set(x_train).intersection(x_test))
print()

# multiclass binary
x_train, x_test, y_train, y_test = train_test_split(xs, ys_multiclass_bin, train_size=0.5,
                                                    stratify=ys_multiclass_bin)
print("Multiclass binary stratification:")
print("training:")
print(x_train)
print(y_train)
print("test:")
print(x_test)
print(y_test)
print("overlapping:")
print(set(x_train).intersection(x_test))
print()

# multilabel binary
x_train, x_test, y_train, y_test = train_test_split(xs, ys_multilabel_bin, train_size=0.5,
                                                    stratify=ys_multilabel_bin)
print("Multilabel binary stratification:")
print("training:")
print(x_train)
print(y_train)
print("test:")
print(x_test)
print(y_test)
print("overlapping:")
print(set(x_train).intersection(x_test))
print()

And here is an example of the output:

Binary stratification:
training:
[15, 9, 3, 8, 11, 18, 5, 20, 7, 10]
[0, 1, 0, 1, 0, 1, 0, 0, 1, 0]
test:
[1, 14, 2, 19, 17, 6, 12, 4, 13, 16]
[0, 0, 0, 1, 1, 1, 0, 0, 0, 1]
overlapping:
set()

Multiclass stratification:
training:
[2, 9, 11, 14, 19, 18, 10, 6, 8, 15]
[0, 1, 0, 1, 1, 3, 0, 2, 3, 2]
test:
[12, 4, 7, 16, 20, 13, 3, 17, 1, 5]
[0, 1, 3, 2, 0, 1, 1, 3, 0, 2]
overlapping:
set()

Multiclass binary stratification:
training:
[6, 10, 1, 5, 9, 9, 4, 4, 3, 3]
[[0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0], [1, 0, 0]]
test:
[20, 8, 3, 5, 18, 2, 7, 16, 20, 11]
[[0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1]]
overlapping:
{3, 5}

Multilabel binary stratification:
training:
[13, 1, 5, 17, 20, 6, 7, 16, 14, 15]
[[1, 0, 1], [0, 0, 1], [0, 1, 1], [1, 1, 0], [0, 1, 0], [0, 1, 0], [1, 1, 0], [0, 1, 0], [0, 0, 1], [0, 1, 1]]
test:
[9, 2, 16, 20, 1, 13, 5, 18, 4, 8]
[[1, 0, 1], [0, 1, 1], [0, 1, 0], [0, 1, 0], [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 0, 0], [0, 0, 1], [1, 0, 0]]
overlapping:
{16, 1, 5, 20, 13}
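The script above only reports the overlap between the two sets. A small check like the following also catches the duplicates within each set and can be run on any split before training. It is a minimal sketch in plain Python; `audit_split` is a hypothetical helper name, and it assumes every sample has a unique, hashable identifier like the integers in `xs`.

```python
from collections import Counter


def audit_split(train, test):
    """Report leakage between two splits of sample identifiers.

    Returns (overlap, train_dups, test_dups):
      overlap:    identifiers that occur in both train and test
      train_dups: identifiers that occur more than once in train
      test_dups:  identifiers that occur more than once in test
    A correct split of distinct samples gives three empty sets.
    """
    overlap = set(train) & set(test)
    train_dups = {x for x, n in Counter(train).items() if n > 1}
    test_dups = {x for x, n in Counter(test).items() if n > 1}
    return overlap, train_dups, test_dups


# The "Multiclass binary stratification" split from the output above
x_train = [6, 10, 1, 5, 9, 9, 4, 4, 3, 3]
x_test = [20, 8, 3, 5, 18, 2, 7, 16, 20, 11]
overlap, train_dups, test_dups = audit_split(x_train, x_test)
print("overlap:", overlap)
print("duplicated in train:", train_dups)
print("duplicated in test:", test_dups)
```

For this split it finds the overlap {3, 5}, plus 3, 4 and 9 duplicated in the training set and 20 duplicated in the test set.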

One could argue that I should have checked whether stratification would work in the multilabel setting, but I think it would be better if train_test_split failed or at least warned me that I am doing stupid things... :)
In any case, it definitely shouldn't mix up training and test data.
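Until train_test_split refuses such input on its own, a defensive check can be run on the stratify labels before splitting. A minimal sketch, assuming the labels come as a Python list or a NumPy array; `require_1d_labels` is a hypothetical helper, not part of scikit-learn.

```python
def require_1d_labels(y):
    """Return `y` unchanged if it holds exactly one label per sample.

    Raises ValueError for 2-D input such as the rows produced by
    LabelBinarizer or MultiLabelBinarizer. Strings count as single labels.
    """
    ndim = getattr(y, "ndim", None)  # NumPy arrays expose .ndim; lists do not
    if ndim is not None and ndim != 1:
        raise ValueError(f"stratify labels must be 1-D, got ndim={ndim}")
    for i, label in enumerate(y):
        # A label that is itself a sequence (list, tuple, array row) means 2-D input
        if hasattr(label, "__len__") and not isinstance(label, (str, bytes)):
            raise ValueError(
                f"stratify labels must be 1-D, but sample {i} has label {label!r}"
            )
    return y


require_1d_labels([0, 0, 1, 1, 2])  # fine: one label per sample
try:
    require_1d_labels([[0, 0, 1], [0, 1, 1]])  # binarized rows are rejected
except ValueError as err:
    print(err)
```

Wrapping the labels as `stratify=require_1d_labels(ys_multilabel_bin)` would then raise a ValueError up front instead of silently producing overlapping splits.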

As you can see, the problem is also present with one-hot encoded multiclass labels, so the error can also occur if you happen to use the LabelBinarizer before splitting, or if your labels are already one-hot encoded.
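For one-hot encoded multiclass labels, a simple way around the problem is to recover one class index per sample and stratify on that instead. A minimal sketch in plain Python, assuming every row contains exactly one 1; `onehot_to_class` is a hypothetical helper. With NumPy, `np.argmax(y, axis=1)` gives the same indices for valid one-hot rows, but does not reject multilabel rows.

```python
def onehot_to_class(y):
    """Convert one-hot rows such as [0, 0, 1] into class indices such as 2.

    Raises ValueError for any row that is not one-hot (no 1, several 1s, or
    values other than 0 and 1), e.g. multilabel rows like [1, 0, 1].
    """
    classes = []
    for i, row in enumerate(y):
        row = list(row)
        if row.count(1) != 1 or row.count(0) != len(row) - 1:
            raise ValueError(f"row {i} is not one-hot: {row!r}")
        classes.append(row.index(1))
    return classes


# First five rows of ys_multiclass_bin from the example
print(onehot_to_class([[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1]]))
# [2, 1, 0, 2, 2]
```

The resulting indices are ordinary 1-D multiclass labels, so `stratify=onehot_to_class(ys_multiclass_bin)` should put the split back in the multiclass case, which did not show any overlap in the output above.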

Best regards,
Andreas
