Conversation
1bbf155 to
d97f324
Compare
Codecov Report
@@ Coverage Diff @@
## master #785 +/- ##
==========================================
- Coverage 90.15% 90.02% -0.13%
==========================================
Files 106 108 +2
Lines 8769 8906 +137
==========================================
+ Hits 7906 8018 +112
- Misses 863 888 +25
Continue to review full report at Codecov.
|
|
|
||
| return '' | ||
|
|
||
| def should_filter_trials(self): |
There was a problem hiding this comment.
This method is called during Study construction and tells the study whether its sampler can use all the trials or not.
|
|
||
| self.sampler = sampler or samplers.TPESampler() | ||
| self.pruner = pruner or pruners.MedianPruner() | ||
| self._should_filter_trials = self.pruner.should_filter_trials() |
There was a problem hiding this comment.
| trial_pruner_metadata = self.pruner.__class__.__name__ | ||
| pruner_auxiliary_data = self.pruner.get_trial_pruner_auxiliary_data( | ||
| self._study_id, trial.number) | ||
| if pruner_auxiliary_data: | ||
| trial_pruner_metadata += pruner_auxiliary_data | ||
| trial.set_user_attr('pruner_metadata', trial_pruner_metadata) |
There was a problem hiding this comment.
Every Trial stores which pruner is used as a user_attr, so that trials can be filtered by pruner information.
| self._friend_trials = None # type: Optional[List[FrozenTrial]] | ||
|
|
||
| def _set_friend_trials(self, friend_trials): | ||
| # type: (List[FrozenTrial]) -> None | ||
|
|
||
| self._friend_trials = friend_trials | ||
|
|
||
| def _clear_friend_trials(self): | ||
| # type: () -> None | ||
|
|
||
| self._friend_trials = None | ||
|
|
||
| @property | ||
| def friend_trials(self): | ||
| # type: () -> Optional[List[FrozenTrial]] | ||
|
|
||
| return self._friend_trials |
There was a problem hiding this comment.
I don't think the naming of friend_trials is cool 😅
|
As I have finished the overall implementation and would like to get feedback, I am marking this as ready for review. |
crcrpar
left a comment
There was a problem hiding this comment.
Do I have to avoid the deprecated study._study_id and use study.study_name instead as done in these comments?
optuna/pruners/base.py
Outdated
| raise NotImplementedError | ||
|
|
||
| @abc.abstractmethod | ||
| def get_trial_pruner_auxiliary_data(self, study_id, trial_number): |
There was a problem hiding this comment.
As study_id is deprecated, should this be study_name as follows?
| def get_trial_pruner_auxiliary_data(self, study_id, trial_number): | |
| def get_trial_pruner_auxiliary_data(self, study_name, trial_number): |
optuna/pruners/base.py
Outdated
|
|
||
| @abc.abstractmethod | ||
| def get_trial_pruner_auxiliary_data(self, study_id, trial_number): | ||
| # type: (int, int) -> str |
There was a problem hiding this comment.
Reflecting the above change.
| # type: (int, int) -> str | |
| # type: (str, int) -> str |
optuna/pruners/hyperband.py
Outdated
| budget += n / 2 | ||
| return budget | ||
|
|
||
| def get_bracket_id(self, study_id, trial_number): |
There was a problem hiding this comment.
As mentioned above, study_id is deprecated.
| def get_bracket_id(self, study_id, trial_number): | |
| def get_bracket_id(self, study_name, trial_number): |
optuna/pruners/hyperband.py
Outdated
| return budget | ||
|
|
||
| def get_bracket_id(self, study_id, trial_number): | ||
| # type: (int, int) -> int |
There was a problem hiding this comment.
ditto
| # type: (int, int) -> int | |
| # type: (str, int) -> int |
optuna/pruners/hyperband.py
Outdated
| # type: (int, int) -> int | ||
| """Computes the id of bracket for a trial of `trial_number`.""" | ||
|
|
||
| n = hash('{}_{}'.format(study_id, trial_number)) % self._resource_badget |
There was a problem hiding this comment.
ditto
| n = hash('{}_{}'.format(study_id, trial_number)) % self._resource_badget | |
| n = hash('{}_{}'.format(study_name, trial_number)) % self._resource_badget |
optuna/pruners/percentile.py
Outdated
| return best_intermediate_result > p | ||
|
|
||
| def get_trial_pruner_auxiliary_data(self, study_id, trial_number): | ||
| # type: (int, int) -> str |
There was a problem hiding this comment.
| # type: (int, int) -> str | |
| # type: (str, int) -> str |
optuna/pruners/successive_halving.py
Outdated
|
|
||
| rung += 1 | ||
|
|
||
| def get_trial_pruner_auxiliary_data(self, study_id, trial_number): |
There was a problem hiding this comment.
| def get_trial_pruner_auxiliary_data(self, study_id, trial_number): | |
| def get_trial_pruner_auxiliary_data(self, study_name, trial_number): |
optuna/pruners/successive_halving.py
Outdated
| rung += 1 | ||
|
|
||
| def get_trial_pruner_auxiliary_data(self, study_id, trial_number): | ||
| # type: (int, int) -> str |
There was a problem hiding this comment.
| # type: (int, int) -> str | |
| # type: (str, int) -> str |
optuna/testing/integration.py
Outdated
|
|
||
| return self.is_pruning | ||
|
|
||
| def get_trial_pruner_auxiliary_data(self, study_id, trial_number): |
There was a problem hiding this comment.
| def get_trial_pruner_auxiliary_data(self, study_id, trial_number): | |
| def get_trial_pruner_auxiliary_data(self, study_name, trial_number): |
optuna/testing/integration.py
Outdated
| return self.is_pruning | ||
|
|
||
| def get_trial_pruner_auxiliary_data(self, study_id, trial_number): | ||
| # type: (int, int) -> str |
There was a problem hiding this comment.
| # type: (int, int) -> str | |
| # type: (str, int) -> str |
| if len(study.trials) > 1: | ||
| raise RuntimeError("`FirstTrialOnlyRandomSampler` only works on the first trial.") | ||
|
|
||
| return super(FirstTrialOnlyRandomSampler, self).sample_relative(study, trial, search_space) |
There was a problem hiding this comment.
It looks like the trials argument should be propagated.
| return super(FirstTrialOnlyRandomSampler, self).sample_relative(study, trial, search_space) | |
| return super(FirstTrialOnlyRandomSampler, self).sample_relative(study, trial, search_space, trials=trials) |
The same applies to the sample_independent method (L80).
There was a problem hiding this comment.
Good catch, thank you!
optuna/pruners/hyperband.py
Outdated
| `Hyperband paper <http://www.jmlr.org/papers/volume18/16-558/16-558.pdf>`_. | ||
| """ | ||
|
|
||
| n = hash('{}_{}'.format(study_id, trial_number)) % self._resource_badget |
There was a problem hiding this comment.
I'm not confident, but it looks like trial_number % self._resource_budget is enough, right?
| n = hash('{}_{}'.format(study_id, trial_number)) % self._resource_badget | |
| n = trial_number % self._resource_budget |
|
Thank you for your review!
I think your question is equivalent to why new argument I hope this helps you. |
|
Thank you! Probably, I understand.
|
update other samplers
update tests
|
This PR consists of two major changes
I think neither change is trivial, so I'd like to separate this into two PRs. |
Nice idea! |
d3f749a to
4195bea
Compare
UPDATE
The original version is separated into #805 and this PR.
This PR is based on #301.
Intuitively, Hyperband (HB) eliminates the dependency on the parameters of SuccessiveHalving (SH) by internally executing multiple SHs with different configurations.
Design
Study with HyperbandPruner runs in the same way as with other pruners. The HyperbandPruner class maintains some number of SuccessiveHalvingPruners (= brackets) and selects a pruner for each Trial. So, the algorithm would be different from the paper to some extent. There are two challenges: 1) different trials of the same Study have to be pruned by different brackets, and 2) when sampling for a new trial, Sampler can only use the trials of the same bracket.
Major Changes
Study collects the list of trials (= friend_trials in the code) and sets it as a Trial's attribute user_attr to Trial and uses it as a filter.
Trial passes its friend_trials to study.sampler.
Sampler's sampling methods accept the list of trials as their argument.
An alternative design, a new class that manages multiple Studys, is implemented in https://github.com/crcrpar/optuna/tree/dev/study-manager.