
StratifiedShuffleSplit generates overlapping train and test indices for multilabel data #9037

@ehuijzer

Description

StratifiedShuffleSplit generates overlapping train and test indices when given multilabel targets.

Steps/Code to Reproduce

from sklearn.model_selection import StratifiedShuffleSplit

# Multilabel targets: each row of y holds two labels.
X = [['A', 'B'], ['C', 'B'], ['A', 'A'], ['A', 'B']]
y = [['0', '1'], ['1', '1'], ['1', '1'], ['0', '1']]

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=3)

for train_index, test_index in sss.split(X, y):
    # Train and test indices should be disjoint; a plain assert avoids
    # depending on the test helpers in sklearn.utils.testing.
    overlap = set(train_index) & set(test_index)
    assert overlap == set(), overlap
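A plausible explanation for the overlap (an assumption, not confirmed against the 0.18.1 source): if the 2-D label array is flattened when classes are counted, every individual label value becomes its own "class", and each row is looked up once per matching cell. The pure-Python sketch below mimics that bookkeeping for the `y` above; `flat_counts` and `rows_per_class` are illustrative names, not scikit-learn internals.

```python
from collections import Counter

y = [['0', '1'], ['1', '1'], ['1', '1'], ['0', '1']]

# Assumed behaviour: flattening the 2-D target turns each label value into a class.
flat_counts = Counter(value for row in y for value in row)

# For each such "class", record the row index once per matching cell,
# as a row lookup on a 2-D equality mask would.
rows_per_class = {
    value: [i for i, row in enumerate(y) for cell in row if cell == value]
    for value in sorted(flat_counts)
}

print(rows_per_class)  # {'0': [0, 3], '1': [0, 1, 1, 2, 2, 3]}
print(sum(flat_counts.values()), "class slots for", len(y), "samples")  # 8 class slots for 4 samples
```

Under this assumption, rows 0 and 3 belong to both classes and rows 1 and 2 appear twice within class '1', so drawing train and test samples per class can place the same row in both sets.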

Expected Results

StratifiedShuffleSplit should either handle multilabel data correctly, producing non-overlapping train and test indices, or raise an error stating that multilabel targets are not supported.
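For the first option, one way to stratify multilabel data is to treat each distinct label row as a single class (a label-powerset approach). The sketch below is a simplified, hypothetical helper (`powerset_stratified_split` is not a scikit-learn function): it shuffles each group and rounds the test share per group, so its split sizes can differ slightly from scikit-learn's own allocation. Its point is that every index lands in exactly one set.

```python
import random
from collections import defaultdict


def powerset_stratified_split(y, test_size, seed=None):
    """Hypothetical helper: split row indices into disjoint train/test lists,
    stratifying on each row's full label combination (label powerset)."""
    rng = random.Random(seed)
    groups = defaultdict(list)
    for idx, labels in enumerate(y):
        groups[tuple(labels)].append(idx)  # one group per distinct label row

    train, test = [], []
    for key in sorted(groups):
        members = groups[key][:]
        rng.shuffle(members)
        n_test = round(len(members) * test_size)  # per-group rounding (simplification)
        test.extend(members[:n_test])
        train.extend(members[n_test:])
    return sorted(train), sorted(test)


y = [['0', '1'], ['1', '1'], ['1', '1'], ['0', '1']]
train, test = powerset_stratified_split(y, test_size=0.5, seed=3)
print(train, test)
```

Because each row belongs to exactly one group and each group's members are partitioned once, train and test cannot share an index.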

Actual Results

With multilabel data, StratifiedShuffleSplit generates overlapping train and test indices without emitting any warning or error message.
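For the second option, a guard that rejects 2-D targets up front would turn the silent overlap into an explicit error. The sketch below is a hypothetical helper (`reject_multilabel` is not part of scikit-learn), written in plain Python; it assumes `y` is a sequence whose rows are either single labels (scalars or strings) or sequences of labels.

```python
def reject_multilabel(y):
    """Hypothetical guard: raise ValueError if any row of y holds several labels."""
    for i, row in enumerate(y):
        # Strings are sequences too, but a string counts as a single label here.
        is_sequence = hasattr(row, "__len__") and not isinstance(row, (str, bytes))
        if is_sequence and len(row) > 1:
            raise ValueError(
                "Multilabel targets are not supported for stratified splitting "
                "(row %d has %d labels: %r)" % (i, len(row), row)
            )
    return y


reject_multilabel(['0', '1', '1', '0'])  # single-label target: passes

try:
    reject_multilabel([['0', '1'], ['1', '1'], ['1', '1'], ['0', '1']])
except ValueError as exc:
    print(exc)  # Multilabel targets are not supported for stratified splitting (row 0 has 2 labels: ['0', '1'])
```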

Versions

Windows-10-10.0.14393-SP0
Python 3.5.2 |Continuum Analytics, Inc.| (default, Jul 5 2016, 11:41:13) [MSC v.1900 64 bit (AMD64)]
NumPy 1.11.3
SciPy 0.18.1
Scikit-Learn 0.18.1
