[MRG] simplify check_is_fitted to use any fitted attributes#14545
[MRG] simplify check_is_fitted to use any fitted attributes#14545glemaitre merged 23 commits intoscikit-learn:masterfrom
Conversation
sklearn/utils/validation.py
Outdated
|
|
||
|
|
||
| def check_is_fitted(estimator, attributes, msg=None, all_or_any=all): | ||
| def check_is_fitted(estimator, *, msg=None): |
There was a problem hiding this comment.
aren't you changing the signature here? did you mean
check_is_fitted(estimator, *args, msg=None): to preserve backward compatibility?
There was a problem hiding this comment.
I'm changing the signature. We could have args here if we consider this public, which maybe we should?
There was a problem hiding this comment.
ok actually using args doesn't work unless I also add *kwargs. So if we want backward-compatibility we need to just do a usual deprecation cycle.
There was a problem hiding this comment.
Yes I think we consider the validation utils public. But I'm happy to see them go private.
sklearn/utils/validation.py
Outdated
|
|
||
| if not isinstance(attributes, (list, tuple)): | ||
| attributes = [attributes] | ||
| attrs = [v for v in vars(estimator) if v.endswith("_") |
There was a problem hiding this comment.
I think NearedtNeighbors has stored only _fit_X
There was a problem hiding this comment.
It has this:
scikit-learn/sklearn/neighbors/base.py
Lines 166 to 169 in 7c60ead
and it was very recently documented.
There was a problem hiding this comment.
Common tests pass so it must work ;)
|
hm vectorizer were not caught by common tests of course :-/ |
|
And |
|
See #14559, but should be passing now. This is not the cleanest work-around but that's mostly because |
|
I'm ambivalent about adding a whatsnew but I can do it if you think it's worth it. Probably should add a |
I would say that a note in the |
sklearn/utils/validation.py
Outdated
|
|
||
|
|
||
| def check_is_fitted(estimator, attributes, msg=None, all_or_any=all): | ||
| def check_is_fitted(estimator, attributes='deprecated', msg=None): |
There was a problem hiding this comment.
The behavior is already not back-compatible. Since we mention in the documentation that these utils can change from a version to another, I would not bother with a deprecation warning for the attributes parameters knowing that one can have some side-effect with all_or_any.
There was a problem hiding this comment.
How is it not backward compatible? Oh I could deprecate all_or_any as well?
There was a problem hiding this comment.
If somebody is using all_or_any now, nothing would happen and this is not an attribute of the function as well. But as I mentioned, we clearly state in the documentation that utils are not following the deprecation cycle and can change: https://scikit-learn.org/stable/developers/utilities.html
There was a problem hiding this comment.
True, only deprecating one doesn't make sense. But also see the discussion at #6616. Basically, the docs say that but people ignore it and it might not be good if we enforce it and should make things private instead.
There was a problem hiding this comment.
OK I see. I am sure I was one of these people that complain at least once (then @lesteve show me the red box :))
I really feel that having the utils private could help to move quickly sometimes and help third-party project (at the cost of potential breaking if they use them). So deprecation it is :)
There was a problem hiding this comment.
so should I add one for all_or_any then?
|
What is the behaviour expected on from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.utils.validation import check_is_fitted
X, y = load_iris(return_X_y=True)
pipe = make_pipeline(StandardScaler(), LogisticRegression())
pipe.fit(X, y)
check_is_fitted(pipe)---------------------------------------------------------------------------
NotFittedError Traceback (most recent call last)
/tmp/tmp.py in <module>
8 pipe = make_pipeline(StandardScaler(), LogisticRegression())
9 pipe.fit(X, y)
---> 10 check_is_fitted(pipe)
~/Documents/code/toolbox/scikit-learn/sklearn/utils/validation.py in check_is_fitted(estimator, msg)
910
911 if not len(attrs):
--> 912 raise NotFittedError(msg % {'name': type(estimator).__name__})
913
914
NotFittedError: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using this method. |
|
we should probably make a recursive call on each element of the Pipeline instance then? |
Co-Authored-By: Guillaume Lemaitre <g.lemaitre58@gmail.com>
…to anything_fitted
|
Do you still want to make a deprecation? |
|
Since we clearly state that the utils are not guaranteed to be stable, I would prefer not go through a deprecation cycle. |
|
@thomasjpfan I would say that remark's there mostly to limit liability ;) see my remarks above. I think I'll edit to also deprecate the |
glemaitre
left a comment
There was a problem hiding this comment.
If you could add at least the suggestion in the docstring. It could be easier to find it for removal.
Otherwise LGTM
| assert check_is_fitted(ard) is None | ||
| assert check_is_fitted(svr) is None | ||
|
|
||
| assert_warns_message( |
There was a problem hiding this comment.
@pytest.mark.parametrize("params", [{'attributes': ['coefs_']}, {all_or_any=any}]
def test_check_is_fitted_deprecation(params):
# FIXME: to be removed in 0.23
warn_msg = 'Passing {} to check_is_fitted is deprecated'.format(list(params.keys())[0])
with pytest.warns(DeprecationWarning, match=warn_msg):
check_is_fitted(ard, **params)It could be handy to have a separated test function to be removed next version.
We might use pytest (but the test will be removed anyway).
There was a problem hiding this comment.
Is it easier to remove a test than to remove the asserts? A comment might be nice but it will also just fail and so we won't forget ;)
Co-Authored-By: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-Authored-By: Guillaume Lemaitre <g.lemaitre58@gmail.com>
|
Good to go |
|
yay thanks for the reviews :) |
This simplifies
check_is_fittedto error if no fitted attribute is found.This clearly is less strict than what we had before, but I did not need to change any tests, so according to our tests (i.e. the guaranteed functionality), this implementation is as good as the previous one.
The main motivation for this change is to allow us to reduce boiler-plate in the future. If we introduce a validation method as in #13603, we could now include the
check_is_fittedthere.