[MRG] Allow for refit=callable in *SearchCV to add flexibility in identifying the best estimator #11269 #11354
jnothman merged 55 commits into scikit-learn:master from wenhaoz-fengcai:master
Conversation
sklearn/model_selection/_search.py
Outdated
| "refit should be set to False " | ||
| "explicitly. %r was passed" | ||
| % self.refit) | ||
| refit_metric = scorer_key |
I don't get why you need this. You don't use refit_metric if refit is callable below. I also think making inferences from the name of the function is inappropriate.
I think refit_metric is needed to compute self.best_score_ as shown here in the original code base.
Ah, I see. I would just disable best_score_ when refit is callable. Please test and document that behaviour.
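The behaviour being requested here can be sketched with a hypothetical helper (this is not the actual scikit-learn code; `choose_best` and its `None` convention are illustrative assumptions):

```python
import numpy as np

def choose_best(cv_results, refit, scorer_key="score"):
    # Hypothetical helper, not the actual scikit-learn implementation:
    # when refit is a callable it picks best_index_ itself, and
    # best_score_ is disabled (represented as None here).
    if callable(refit):
        best_index = int(refit(cv_results))
        best_score = None  # best_score_ unavailable for callable refit
    else:
        scores = np.asarray(cv_results["mean_test_%s" % scorer_key])
        best_index = int(scores.argmax())
        best_score = float(scores[best_index])
    return best_index, best_score

results = {"mean_test_score": [0.8, 0.9, 0.7]}
print(choose_best(results, refit=True))         # (1, 0.9)
print(choose_best(results, refit=lambda r: 2))  # (2, None)
```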
sklearn/model_selection/_search.py
Outdated
Where there are considerations other than maximum model performance in
choosing a best estimator, ``refit`` can be set to a function which returns
the selected ``best_index_`` given ``cv_results_``.
sklearn/model_selection/_search.py
Outdated
scorer is used to find the best parameters for refitting the estimator
at the end.

Where there are considerations other than maximum model performance in
sklearn/model_selection/_search.py
Outdated
scorer that would be used to find the best parameters for refitting
the estimator at the end.

Where there are considerations other than maximum model performance in
Please use the same text and formatting in both places.
sklearn/model_selection/_search.py
Outdated
``best_score_`` and ``best_parameters_`` will only be available if
``refit`` is set and all of them will be determined w.r.t this specific
scorer.
scorer. If a callable is passed to parameter refit, the function's name
This is an unnecessary and unhelpful condition.
For multi-metric evaluation, the name of refit callable function must
end with a scorer key (`_<scorer_name>`).
"""
def refit_prec(cv_results):
We should have a realistic example in examples/model_selection/ rather than here.
As a simple example, I would consider using maximising score while minimising the number of selected features or PCA components.
Here we should merely be testing interface, and a dummy function (for instance, one that always chooses the lowest-score model) is sufficient / most appropriate, as it is then easy for us to be sure what correct behaviour is.
@jnothman Would you say a dummy function like below is good enough to test our interface?
def refit_callable(cv_results):
    return cv_results['mean_test_score'].argmin()

It seems that you're suggesting two things here :(
Yes, that looks good. I might add to that an assertion that all the keys we expect to be in results are in there.
Yes, I am indeed suggesting a second thing here. An example in examples/model_selection will hugely increase the visibility and practical usability of this feature. The example gallery is how we advise users how to use the features described in technical detail in the docstrings (and before StackOverflow has all the answers).
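Combining the two suggestions above, the dummy callable could look like the following sketch (the exact set of keys asserted here is an assumption for illustration; `cv_results_` contains more keys in practice):

```python
import numpy as np

def refit_callable(cv_results):
    # Dummy refit callable for interface testing, as suggested above:
    # assert that expected cv_results_ keys are present, then
    # deliberately pick the *lowest*-scoring candidate so the correct
    # behaviour is unambiguous in the test.
    expected = {"mean_test_score", "std_test_score",
                "rank_test_score", "params"}
    missing = expected - set(cv_results)
    assert not missing, "missing keys in cv_results: %r" % missing
    return int(np.argmin(cv_results["mean_test_score"]))

cv_results = {"mean_test_score": np.array([0.7, 0.9, 0.5]),
              "std_test_score": np.array([0.1, 0.1, 0.1]),
              "rank_test_score": np.array([2, 1, 3]),
              "params": [{}, {}, {}]}
print(refit_callable(cv_results))  # 2 (index of the lowest score)
```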
Thanks! I'm adding an example from examples/model_selection for this feature to the docstring.
@jnothman is it appropriate to add one more example for refit=callable in the docstring under GridSearchCV class after this one?
scikit-learn/sklearn/model_selection/_search.py
Lines 931 to 958 in 3b5abf7
I think a meaningful example is too large, and too much of a power-user feature, to be in the docstring.
@jnothman It seems that we don't need to write test cases for our example under the examples directory, right? ;)
Feel free to use GitHub's todo list feature in the PR description.

@jnothman Thanks for your input! I'll improve my implementation based on your feedback.
    enumerate(cv_results['mean_test_prec'])}
# Select models which have test precisions within 1 standard deviation
# of the best 'mean_test_prec'
candidates = dict(filter(lambda i: (i[1] >= test_prec_lower
btw, a dict comprehension is much easier to read than this
So is test_prec_upper > i[1] >= test_prec_lower
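A sketch of the rewrite being suggested, with illustrative stand-in values (the scores and the assumed standard deviation are made up for this example):

```python
import numpy as np

# Illustrative stand-ins for the values in the snippet under review
cv_results = {"mean_test_prec": np.array([0.70, 0.86, 0.90, 0.60])}
best = cv_results["mean_test_prec"].max()
std = 0.05  # assumed std of the best score, for illustration only
test_prec_lower, test_prec_upper = best - std, best + std

# Dict comprehension with a chained comparison, as suggested above
candidates = {i: s for i, s in enumerate(cv_results["mean_test_prec"])
              if test_prec_upper > s >= test_prec_lower}
print(sorted(candidates))  # [1, 2]
```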
    enumerate(cv_results['mean_fit_time'])}
fit_time_rank = sorted(fit_time)
for i in fit_time_rank:
    if fit_time[i] in candidates:
This isn't working in AppVeyor. The function is returning None there.
Yes, I'm replacing these two test cases with simpler ones.
jnothman left a comment
Circle CI should fail if the example does
jnothman left a comment
Please reference the example from doc/modules/grid_search.rst. You should probably put the motivation / use case there more than in the example.
Documentation is rendered at https://26300-843222-gh.circle-artifacts.com/0/doc/_changed.html
    }
]

grid = GridSearchCV(pipe, cv=3, n_jobs=1, param_grid=param_grid,
I don't think we should be encouraging users to calculate a standard deviation over 3 samples. Make cv=10.
interface can also be used in multiple metrics evaluation.

This example balances model complexity and cross-validated score by
finding a decent accuracy within 1 standard deviation of the best accuracy
You might want to say that this is a rule of thumb for insignificant difference.
We could determine insignificant difference in a more proper way, such as with a Wilcoxon rank-sum test.
@@ -0,0 +1,125 @@
"""
=======================================================================
Balance model complexity and cross-validated score using refit=callable
upper/lower bounds within 1 standard deviation of the
best `mean_test_scores`.
"""
std_test_score = np.std(scores)
Should be using std_test_score: you want standard deviation across cv splits, not across parameter candidates
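The fix being asked for can be sketched as follows (a hypothetical helper under assumed data, not the final example code; it assumes candidates are ordered from simplest to most complex):

```python
import numpy as np

def best_low_complexity(cv_results):
    # Sketch of the fix requested above: the 1-std threshold must come
    # from std_test_score (variation across CV splits for the best
    # candidate), not np.std over the candidates' mean scores.
    best = int(np.argmax(cv_results["mean_test_score"]))
    threshold = (cv_results["mean_test_score"][best]
                 - cv_results["std_test_score"][best])
    within = np.flatnonzero(cv_results["mean_test_score"] >= threshold)
    # assume candidates are ordered from simplest to most complex
    return int(within.min())

cv_results = {"mean_test_score": np.array([0.88, 0.92, 0.93]),
              "std_test_score": np.array([0.02, 0.02, 0.02])}
print(best_low_complexity(cv_results))  # 1
```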
@jnothman @adrinjalali Probably need your help to fix travis-ci issue... :-/
adrinjalali left a comment
Thanks @jiaowoshabi , LGTM!
Awesome, @jiaowoshabi! Please add an entry to the change log at
doc/whats_new/v0.21.rst
Outdated
:func:`~model_selection.validation_curve` only the latter is required.
:issue:`12613` and :issue:`12669` by :user:`Marc Torrellas <marctorrellas>`.

- |Enhancement| :class:`~model_selection.BaseSearchCV` now allows for
BaseSearchCV is not listed in doc/modules/classes.rst so this link won't work. Ordinarily we'd reference GridSearchCV and RandomizedSearchCV. You could also consider referencing the user guide rather than the example?
sklearn/model_selection/_search.py
Outdated
See ``scoring`` parameter to know more about multiple metric
evaluation.

.. versionadded:: 0.20
I think versionchanged may be more appropriate, since the parameter was not added.
sklearn/model_selection/_search.py
Outdated
evaluation.

.. versionadded:: 0.20
    GridSearchCV supports ``refit`` = callable to add flexibility in
Don't mention GridSearchCV here. Simply say "Support for callable added." the rest is documented above.
Thanks @jiaowoshabi!!
self.best_index_ = self.refit(results)
if not isinstance(self.best_index_, (int, np.integer)):
    raise TypeError('best_index_ returned is not an integer')
if self.best_index_ < 0 or self.best_index_ >= len(results):
Pretty sure this is a bug: results is a dictionary of things, and each value is an array the size of the grid.
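A sketch of the corrected check (a hypothetical standalone helper; the real code does this inline, and using the `'params'` key to count candidates is one reasonable choice since `cv_results_` always contains it):

```python
import numbers

def validate_best_index(best_index, results):
    # results is a dict of arrays, so len(results) counts its *keys*,
    # not the candidates. Bound the index by the number of candidates
    # instead, here via the length of the 'params' list.
    if not isinstance(best_index, numbers.Integral):
        raise TypeError("best_index_ returned is not an integer")
    n_candidates = len(results["params"])
    if not 0 <= best_index < n_candidates:
        raise IndexError("best_index_ index out of range")
    return int(best_index)

results = {"params": [{}, {}, {}],
           "mean_test_score": [0.1, 0.2, 0.3],
           "rank_test_score": [3, 2, 1]}
print(validate_best_index(2, results))  # 2
```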
…ng the best estimator (scikit-learn#11354)" This reverts commit b4f76cf.
Reference Issues/PRs
Fixes #11269. Fixes #12865. See also #9499
What does this implement/fix? Explain your changes.
Allow a callable to be passed to refit in *SearchCV to balance score and model complexity. This interface adds flexibility in identifying the "best" estimator. The function passed to the refit parameter incorporates the choice of which metric to optimise, so users can use multi-metric evaluation with this interface.
Any other comments?
Checklist:

- Add an example (plot_grid_search_refit_callable.py) demonstrating the usage of this interface under examples/model_selection/
- Test refit=callable using a simple dummy refit function
- Test refit=callable using a similar example in multi-metric evaluation settings
- Update _search.py to pass the above tests