TST Add common check for classifiers / regressors with float32 #16359

Merged: rth merged 3 commits into scikit-learn:master from rth:check-mmap-32bit on Feb 2, 2020
@rth (Member) commented on Feb 1, 2020

Many tree-based estimators did not work with read-only float32 input, because the common tests did not check this case (cf. #15851). The underlying issue will be fixed in #16331; however, it would still be good to have the corresponding common checks, which this PR adds.

For instance, with this PR, running,

pytest sklearn/tests/test_common.py -k "check_classifiers_train or check_regressors_train" -r f

produces the following report,

FAILED sklearn/tests/test_common.py::test_estimators[AdaBoostClassifier()-check_classifiers_train(readonly_memmap=TrueX_dtype=float32)]
FAILED sklearn/tests/test_common.py::test_estimators[AdaBoostRegressor()-check_regressors_train(readonly_memmap=TrueX_dtype=float32)]
FAILED sklearn/tests/test_common.py::test_estimators[DecisionTreeClassifier()-check_classifiers_train(readonly_memmap=TrueX_dtype=float32)]
FAILED sklearn/tests/test_common.py::test_estimators[DecisionTreeRegressor()-check_regressors_train(readonly_memmap=TrueX_dtype=float32)]
FAILED sklearn/tests/test_common.py::test_estimators[ExtraTreeClassifier()-check_classifiers_train(readonly_memmap=TrueX_dtype=float32)]
FAILED sklearn/tests/test_common.py::test_estimators[ExtraTreeRegressor()-check_regressors_train(readonly_memmap=TrueX_dtype=float32)]
FAILED sklearn/tests/test_common.py::test_estimators[ExtraTreesClassifier()-check_classifiers_train(readonly_memmap=TrueX_dtype=float32)]
FAILED sklearn/tests/test_common.py::test_estimators[ExtraTreesRegressor()-check_regressors_train(readonly_memmap=TrueX_dtype=float32)]
FAILED sklearn/tests/test_common.py::test_estimators[GradientBoostingClassifier()-check_classifiers_train(readonly_memmap=TrueX_dtype=float32)]
FAILED sklearn/tests/test_common.py::test_estimators[GradientBoostingRegressor()-check_regressors_train(readonly_memmap=TrueX_dtype=float32)]
FAILED sklearn/tests/test_common.py::test_estimators[RandomForestClassifier()-check_classifiers_train(readonly_memmap=TrueX_dtype=float32)]
FAILED sklearn/tests/test_common.py::test_estimators[RandomForestRegressor()-check_regressors_train(readonly_memmap=TrueX_dtype=float32)]

All of these will pass with #16331 merged.
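
To make the input described above concrete, here is a minimal sketch of a read-only float32 memory-mapped array built with NumPy alone. The helper name `make_readonly_float32_memmap` and the file name are hypothetical, and this is not necessarily how scikit-learn's common checks construct their data; it only illustrates the two properties (float32 dtype, not writeable) that the new check parameters refer to.

```python
import os
import tempfile

import numpy as np


def make_readonly_float32_memmap(X, directory):
    """Hypothetical helper: store X on disk as float32 and reopen it read-only."""
    X32 = np.asarray(X, dtype=np.float32)
    path = os.path.join(directory, "X_float32.dat")  # illustrative file name

    # Write the data through a writable memmap, then release it.
    writable = np.memmap(path, dtype=np.float32, mode="w+", shape=X32.shape)
    writable[:] = X32
    writable.flush()
    del writable

    # mode="r" yields an array whose buffer cannot be modified.
    return np.memmap(path, dtype=np.float32, mode="r", shape=X32.shape)


tmpdir = tempfile.mkdtemp()
X_ro = make_readonly_float32_memmap([[0.0, 1.0], [2.0, 3.0]], tmpdir)

# Any in-place write on the read-only array raises ValueError.
try:
    X_ro[0, 0] = 42.0
    write_failed = False
except ValueError:
    write_failed = True
```

Code that writes into its input, or that requests a writeable buffer, would be expected to fail on an array like `X_ro`, which is the kind of failure the report above lists.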

@ogrisel (Member) left a comment

LGTM.


- kwstring = "".join(["{}={}".format(k, v)
-                     for k, v in obj.keywords.items()])
+ kwstring = ",".join(["{}={}".format(k, v)
Member
Suggested change
- kwstring = ",".join(["{}={}".format(k, v)
+ kwstring = ", ".join(["{}={}".format(k, v)

Member
I have a slight preference for no space here, because this string would be used for selecting tests when running pytest on the command line:

test_estimators[AdaBoostClassifier()-check_classifiers_train(readonly_memmap=True,X_dtype=float32)]

or (with space)

test_estimators[AdaBoostClassifier()-check_classifiers_train(readonly_memmap=True, X_dtype=float32)]

Member Author
I did consider it initially, but I also find that without the space the naming of tests in pytest output is better.

Thanks for the reviews!
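
For readers following the separator discussion, the sketch below reproduces the `kwstring` logic from the diff on a `functools.partial` object and compares the three separators. The function `check_id`, its `sep` parameter, and the placeholder check are illustrative assumptions, not scikit-learn's actual helper.

```python
from functools import partial


def check_classifiers_train(name, estimator, readonly_memmap=False,
                            X_dtype="float64"):
    """Placeholder standing in for the real check; its body is irrelevant here."""


def check_id(obj, sep=","):
    """Hypothetical ID builder mirroring the kwstring logic shown in the diff."""
    if not isinstance(obj, partial):
        return obj.__name__
    if not obj.keywords:
        return obj.func.__name__
    # Join "key=value" pairs in the order they were passed to partial().
    kwstring = sep.join("{}={}".format(k, v) for k, v in obj.keywords.items())
    return "{}({})".format(obj.func.__name__, kwstring)


check = partial(check_classifiers_train, readonly_memmap=True, X_dtype="float32")

id_before = check_id(check, sep="")    # separator before this PR's change
id_after = check_id(check, sep=",")    # separator adopted in this PR
id_spaced = check_id(check, sep=", ")  # the suggested alternative
```

With the empty separator the keyword arguments run together (`readonly_memmap=TrueX_dtype=float32`), as in the failure report in the description. The comma-only form contains no whitespace, whereas an ID containing a space generally has to be quoted when passed as a shell argument.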

@rth changed the title from "TST Add common check for classifiers and regressors with 32bit read-only input" to "TST Add common check for classifiers / regressors with float32" on Feb 2, 2020
@rth merged commit c91f0c9 into scikit-learn:master on Feb 2, 2020
@rth deleted the check-mmap-32bit branch on February 2, 2020 10:52
thomasjpfan pushed a commit to thomasjpfan/scikit-learn that referenced this pull request Feb 22, 2020
panpiort8 pushed a commit to panpiort8/scikit-learn that referenced this pull request Mar 3, 2020