Alternative implementation to hide the interface so that all samplers can use HyperbandPruner.#1196
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #1196 +/- ##
==========================================
+ Coverage 85.93% 85.97% +0.03%
==========================================
Files 92 92
Lines 6747 6751 +4
==========================================
+ Hits 5798 5804 +6
+ Misses 949 947 -2 ☔ View full report in Codecov by Sentry. |
hvy
left a comment
There was a problem hiding this comment.
Thanks for the PR. I left some comments. Also, I think we can test this logic, it might not be the cleanest approach but you could e.g. patch _get_bracket_id and count the number of trials.
|
How about testing that filtering works for each sampler? @pytest.mark.parametrize(
"sampler_cls,sampler_kwargs",
[
(optuna.samplers.RandomSampler, {}),
(optuna.samplers.TPESampler, {"n_startup_trials": 1}),
(
optuna.samplers.GridSampler,
{"search_space": {"value": numpy.linspace(0.0, 1.0, 10, endpoint=False).tolist()}},
),
(optuna.samplers.CmaEsSampler, {"n_startup_trials": 1}),
],
)
def test_hyperband_filter_study(sampler_cls, sampler_kwargs):
# type: () -> None
def objective(trial: optuna.trial.Trial) -> float:
return trial.suggest_uniform("value", 0.0, 1.0)
n_trials = 10
n_brackets = 2
expected_n_trials_per_bracket = n_trials // n_brackets
with mock.patch(
"optuna.pruners.HyperbandPruner._get_bracket_id",
new=mock.Mock(side_effect=lambda study, trial: trial.number % n_brackets),
):
for method_name in [
"infer_relative_search_space",
"sample_relative",
"sample_independent",
]:
sampler = sampler_cls(**sampler_kwargs)
pruner = optuna.pruners.HyperbandPruner(
min_resource=MIN_RESOURCE,
max_resource=MAX_RESOURCE,
reduction_factor=REDUCTION_FACTOR,
)
with mock.patch(
"optuna.samplers.{}.{}".format(sampler_cls.__name__, method_name),
wraps=getattr(sampler, method_name),
) as method_mock:
study = optuna.study.create_study(sampler=sampler, pruner=pruner)
study.optimize(objective, n_trials=n_trials)
args = method_mock.call_args[0]
study = args[0]
trials = study.get_trials()
assert len(trials) == expected_n_trials_per_bracket |
toshihikoyanase
left a comment
There was a problem hiding this comment.
Thank you for your update. I have two comments.
[Note] #1210 also addresses the refactoring of TrialState. It may change the module of TrialState.
|
@toshihikoyanase Thank you for your insightful comments. I updated codes according to your comments.
Thank you for the information. I will carefully merge this PR and #1210. |
|
@HideakiImamura Thank you for reflecting my suggestions! I have an additional comment. I think it is worth mentioning it somewhere in the code for future development. For instance, how about adding a test case to confirm that |
hvy
left a comment
There was a problem hiding this comment.
Thanks for the fixes! Changes basically LGTM but I left some nitpicks.
|
Could you fix the conflict as well? |
|
@hvy @toshihikoyanase Thank you for your reviews. I updated according to your comments. |
toshihikoyanase
left a comment
There was a problem hiding this comment.
Thank you for your update. LGTM!
|
Ouch, another conflict. |
HyperbandPrunerHyperbandPruner.
|
Thanks for the long running effort. The CI was all green so I merged your changes. |
HyperbandPruner.HyperbandPruner.
Motivation
The current Hyperband implementation in Optuna imposes users who want to implement a new sampler to implement the specific logic. This annoys us because it is exactly the same logic for all samplers. To hide the interface so that all samplers can use HyperbandPruner without any additional efforts, PR #1168 which is already opened needs to fix the interface of
study.sampler. This PR aims to provide an alternative choice to implement the functionality without any changes in interfaces.Description of the changes
optuna/samplers/tpe/sampler.pytopruners/__init__.pyaspruners.filter_study()function.optuna/trial.py.optuna/_trial_state.pyand moveTrialStateclass fromoptuna/trial.pytooptuna/_trial_state.py.Note
This PR opposite PR #1168. If PR #1168 is merged, this PR should not be merged, and vice versa.