Skip to content

[MRG] Successive halving#4

Merged
amueller merged 21 commits intodabl:masterfrom
NicolasHug:successivehalving
Mar 27, 2019
Merged

[MRG] Successive halving#4
amueller merged 21 commits intodabl:masterfrom
NicolasHug:successivehalving

Conversation

@NicolasHug
Copy link
Copy Markdown
Collaborator

@NicolasHug NicolasHug commented Feb 7, 2019

Still very WIP, need feedback on the interface :)

I'm now overwriting _run_search, taking advantage of the proposed changes in BaseGridSearch from scikit-learn/scikit-learn#13145

Note that this forces to remove the refit parameter, which has to be set to a custom callable.

cv_results_ looks good and it's still useful. For e.g. 4 candidates it's going to have 4 (first iter) + 2 (second and last iter) rows.

If you're OK with this new design I'll start to write tests :)

@amueller
Copy link
Copy Markdown
Collaborator

amueller commented Mar 7, 2019

missing

from collections import defaultdict
from itertools import product

I think?

@amueller
Copy link
Copy Markdown
Collaborator

amueller commented Mar 7, 2019

in the random search it probably shouldn't be n_iter but n_candidates?

Also there should be some minimum size for the data for the first iteration.

@amueller
Copy link
Copy Markdown
Collaborator

amueller commented Mar 7, 2019

I wonder if we should be using a bigger validation set in the cross-validation... hm... I guess large validation sets could also slow things down. Have you checked how other implementations do this?

@amueller
Copy link
Copy Markdown
Collaborator

amueller commented Mar 7, 2019

res = pd.DataFrame(sh.cv_results_)
res['params_str'] = res.params.apply(str)
reshape = res.pivot(index='iter', columns='params_str', values='mean_test_score')
reshape.plot(legend=False, alpha=.4, c='k')

image

fml/search.py Outdated
n_candidates = len(candidate_params)
n_samples_iter = floor(n_samples_total /
(n_candidates * n_iterations))
indices = rng.choice(n_samples_total, n_samples_iter,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
indices = rng.choice(n_samples_total, n_samples_iter,
# this could be outside the for-loop
# 2 is a magic number. I found 10 too slow and 2 seems to work fine?
# basically lower bound on the test set size
cv = check_cv(self.cv, y, classifier=is_classifier(self.estimator))
min_n_samples = cv.get_n_splits(X, y) * n_classes * 2
if is_classifier(self.estimator):
n_samples_iter = max(n_samples_iter, min_n_samples)
indices = rng.choice(n_samples_total, n_samples_iter,

this makes this more robust for small datasets.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok the is_classifier here is useless and my indentation is wrong... don't commit this lol.

fml/search.py Outdated

candidate_params = list(self._generate_candidate_params())
n_iterations = int(ceil(log2(len(candidate_params))))
n_samples_total = X.shape[0]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok so you basically set the budget to n_samples, right? where did you get that from?

@amueller
Copy link
Copy Markdown
Collaborator

tests are passing now on master :)

@amueller
Copy link
Copy Markdown
Collaborator

btw if you can make this green I'll merge it and we can iterate. That way I can play with it on my flight more easily.

@NicolasHug
Copy link
Copy Markdown
Collaborator Author

CI is queued but it should be green now (obligatory it works on my machine ;))

@NicolasHug NicolasHug mentioned this pull request Mar 27, 2019
4 tasks
@NicolasHug NicolasHug changed the title [WIP] Successive halving [MRG] Successive halving Mar 27, 2019
@amueller amueller merged commit 4b22737 into dabl:master Mar 27, 2019
@amueller
Copy link
Copy Markdown
Collaborator

thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants