Skip to content

Implement Hyperband pruner#785

Closed
crcrpar wants to merge 6 commits intooptuna:masterfrom
crcrpar:dev/hyperband
Closed

Implement Hyperband pruner#785
crcrpar wants to merge 6 commits intooptuna:masterfrom
crcrpar:dev/hyperband

Conversation

@crcrpar
Copy link
Copy Markdown
Contributor

@crcrpar crcrpar commented Dec 11, 2019

UPDATE
The original version is separated into #805 and this PR.


This PR is based on #301.

Intuitively, Hyperband (HB) eliminates the dependency on parameters of SuccessiveHalving (SH) by internally executes multiple SHs with different configurations.

Design

Study with HyperbandPruner runs in the same way as with other pruners. HyperbandPruner class maintains some number of SuccessiveHalvingPruners (= brackets) and selects a pruner for each Trial. So, the algorithm would be different from the paper to some extent. There're two challenges, 1) different trials of the same Study have to be pruned by different brackets, and 2) When sampling for a new trial, Sampler can only use the trials of the same bracket.

Major Changes

  • Study collects the list of trials (= friend_trials in the code) and set it as a Trial's attribute
    • To filter trials with some metadata, study sets the information of pruner as user_attr to Trial and uses it as a filter.
  • Trial passes its friend_trial to study.sampler
  • Sampler's sampling methods accept the list of trials as their argument

An alternative design, a new class that manages multiple Studys is implemented in https://github.com/crcrpar/optuna/tree/dev/study-manager.

@codecov-io
Copy link
Copy Markdown

codecov-io commented Dec 11, 2019

Codecov Report

Merging #785 into master will decrease coverage by 0.12%.
The diff coverage is 85.54%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #785      +/-   ##
==========================================
- Coverage   90.15%   90.02%   -0.13%     
==========================================
  Files         106      108       +2     
  Lines        8769     8906     +137     
==========================================
+ Hits         7906     8018     +112     
- Misses        863      888      +25
Impacted Files Coverage Δ
optuna/pruners/__init__.py 100% <100%> (ø) ⬆️
optuna/samplers/tpe/sampler.py 87.54% <100%> (ø) ⬆️
optuna/pruners/percentile.py 95.71% <100%> (+0.25%) ⬆️
tests/test_study.py 97.9% <100%> (ø) ⬆️
optuna/study.py 93.82% <100%> (+0.3%) ⬆️
optuna/integration/cma.py 94.03% <100%> (+0.08%) ⬆️
optuna/integration/skopt.py 88.42% <100%> (ø) ⬆️
optuna/pruners/successive_halving.py 95.23% <100%> (+0.41%) ⬆️
optuna/testing/integration.py 100% <100%> (ø) ⬆️
optuna/integration/lightgbm_tuner/optimize.py 76.03% <100%> (ø) ⬆️
... and 12 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 0389962...4195bea. Read the comment docs.


return ''

def should_filter_trials(self):
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is called when Study's construction and tells the study whether the study's sampler can use all the trials or not.


self.sampler = sampler or samplers.TPESampler()
self.pruner = pruner or pruners.MedianPruner()
self._should_filter_trials = self.pruner.should_filter_trials()
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines +529 to +572
trial_pruner_metadata = self.pruner.__class__.__name__
pruner_auxiliary_data = self.pruner.get_trial_pruner_auxiliary_data(
self._study_id, trial.number)
if pruner_auxiliary_data:
trial_pruner_metadata += pruner_auxiliary_data
trial.set_user_attr('pruner_metadata', trial_pruner_metadata)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All Trial has an attribute of what pruner is used as a user_attr to allow for filtering by pruner information.

Comment on lines +141 to +167
self._friend_trials = None # type: Optional[List[FrozenTrial]]

def _set_friend_trials(self, friend_trials):
# type: (List[FrozenTrial]) -> None

self._friend_trials = friend_trials

def _clear_friend_trials(self):
# type: () -> None

self._friend_trials = None

@property
def friend_trials(self):
# type: () -> Optional[List[FrozenTrial]]

return self._friend_trials
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the naming of friend_trials is cool 😅

@crcrpar
Copy link
Copy Markdown
Contributor Author

crcrpar commented Dec 11, 2019

As I have done in general and I want to get any feedback, so I mark this as ready for review.

@crcrpar crcrpar marked this pull request as ready for review December 11, 2019 14:33
Copy link
Copy Markdown
Contributor Author

@crcrpar crcrpar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I have to avoid the deprecated study._study_id and use study.study_name instead as done in these comments?

raise NotImplementedError

@abc.abstractmethod
def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As study_id is deprecated, should this be study_name as follows?

Suggested change
def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
def get_trial_pruner_auxiliary_data(self, study_name, trial_number):


@abc.abstractmethod
def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
# type: (int, int) -> str
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reflecting the above change.

Suggested change
# type: (int, int) -> str
# type: (str, int) -> str

budget += n / 2
return budget

def get_bracket_id(self, study_id, trial_number):
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned above, study_id is deprecated.

Suggested change
def get_bracket_id(self, study_id, trial_number):
def get_bracket_id(self, study_name, trial_number):

return budget

def get_bracket_id(self, study_id, trial_number):
# type: (int, int) -> int
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Suggested change
# type: (int, int) -> int
# type: (str, int) -> int

# type: (int, int) -> int
"""Computes the id of bracket for a trial of `trial_number`."""

n = hash('{}_{}'.format(study_id, trial_number)) % self._resource_badget
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Suggested change
n = hash('{}_{}'.format(study_id, trial_number)) % self._resource_badget
n = hash('{}_{}'.format(study_name, trial_number)) % self._resource_badget

return best_intermediate_result > p

def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
# type: (int, int) -> str
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# type: (int, int) -> str
# type: (str, int) -> str


rung += 1

def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
def get_trial_pruner_auxiliary_data(self, study_name, trial_number):

rung += 1

def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
# type: (int, int) -> str
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# type: (int, int) -> str
# type: (str, int) -> str


return self.is_pruning

def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
def get_trial_pruner_auxiliary_data(self, study_name, trial_number):

return self.is_pruning

def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
# type: (int, int) -> str
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# type: (int, int) -> str
# type: (str, int) -> str

Copy link
Copy Markdown
Member

@c-bata c-bata left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good job @crcrpar!
My code review is still work in progress (Actually, I still don't understand the reason why this PR includes the change of sampler interface, and what friend_trials means.). For now, I put some minor comments.

if len(study.trials) > 1:
raise RuntimeError("`FirstTrialOnlyRandomSampler` only works on the first trial.")

return super(FirstTrialOnlyRandomSampler, self).sample_relative(study, trial, search_space)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks trials argument should be propagated.

Suggested change
return super(FirstTrialOnlyRandomSampler, self).sample_relative(study, trial, search_space)
return super(FirstTrialOnlyRandomSampler, self).sample_relative(study, trial, search_space, trials=trials)

sample_independent method (L80) is also.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, thank you!

`Hyperband paper <http://www.jmlr.org/papers/volume18/16-558/16-558.pdf>`_.
"""

n = hash('{}_{}'.format(study_id, trial_number)) % self._resource_badget
Copy link
Copy Markdown
Member

@c-bata c-bata Dec 12, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not confident but It looks trial_number % self._resource_budget is enough, right?

Suggested change
n = hash('{}_{}'.format(study_id, trial_number)) % self._resource_badget
n = trial_number % self._resource_budget

@crcrpar
Copy link
Copy Markdown
Contributor Author

crcrpar commented Dec 13, 2019

@c-bata

Thank you for your review!
Being not complete is not a problem.
I really appreciate your response. 😀

why this PR includes the change of sampler interface, and what friend_trials means

I think your question is equivalent to why new argument trials: List[FrozenTrial].
A sampler, especially TPE sampler, must only reflect the trials that have the same SuccessiveHalvingPruner of HyperbandPruner as a trial that is currently being initialized.
To realize this, I added an attribute of pruner_metadata to Trial via set_uesr_atttr inside Study. Also, make Study collects appropriate trials and set them as friend_trials to the trial tentatively for the ease of trial selection.

I hope this helps you.

@c-bata
Copy link
Copy Markdown
Member

c-bata commented Dec 13, 2019

Thank you! Probably, I understand.

  1. When using SuccessiveHalvingPruner, it has no problem to use the last intermediate score of pruned trials.
  2. But it seems that HyperbandPruner is not.
  3. So you labeled trials as pruned_metadata (this is a string representation of bracket_id in HyperbandPruner) by get_trial_pruner_auxiliary_data() method.
  4. friend_trials returns the trials which has the same pruned_metadata (bracket_id).
  5. So you passes friend_trials to samplers.

@crcrpar crcrpar mentioned this pull request Dec 18, 2019
@crcrpar
Copy link
Copy Markdown
Contributor Author

crcrpar commented Dec 18, 2019

This PR consists of two major changes

  • new argument of trials to samplers sample methods
  • hyperband

Both changes are not trivial I think, thus I'd like to separate this into two PRs.

@sile
Copy link
Copy Markdown
Member

sile commented Dec 18, 2019

Both changes are not trivial I think, thus I'd like to separate this into two PRs.

Nice idea!

@sile
Copy link
Copy Markdown
Member

sile commented Dec 27, 2019

I think that this PR was taken over by #809. Could we close this? > @crcrpar

@crcrpar
Copy link
Copy Markdown
Contributor Author

crcrpar commented Dec 27, 2019

I think that this PR was taken over by #809. Could we close this? > @crcrpar

thank you for your reminding, of course.

@crcrpar crcrpar closed this Dec 27, 2019
@crcrpar crcrpar deleted the dev/hyperband branch January 22, 2020 12:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants