Skip to content

Stop using hash function in _get_bracket_id in HyperbandPruner#4131

Merged
not522 merged 3 commits intooptuna:masterfrom
zaburo-ch:stop-using-hash-function-in-get-bracketid-of-hyperband
Nov 6, 2022
Merged

Stop using hash function in _get_bracket_id in HyperbandPruner#4131
not522 merged 3 commits intooptuna:masterfrom
zaburo-ch:stop-using-hash-function-in-get-bracketid-of-hyperband

Conversation

@zaburo-ch
Copy link
Copy Markdown
Contributor

@zaburo-ch zaburo-ch commented Nov 6, 2022

Motivation

Fix the bug reported in #3083 that _get_bracket_id of Hyperband returns different results depending on the process.

Description of the changes

Simply stopped using hash as suggested in #809 (comment).

@toshihikoyanase toshihikoyanase added sprint-20221106 PR from the online/offline hybrid sprint event Nov 06, 2022 bug Issue/PR about behavior that is broken. Not for typos/examples/CI/test but for Optuna itself. labels Nov 6, 2022
@zaburo-ch
Copy link
Copy Markdown
Contributor Author

I am concerned that if n_trials is not sufficiently larger than total_trial_allocation_budget, the bracket_id will be biased towards smaller ones by this change. This happens when max_resource is large and n_trials is small, so I wonder if it would make a difference in results in NN hyperparameter optimization.

from collections import defaultdict

import optuna as ot


def objective(t):
    x = t.suggest_float("x", 0, 1)

    for step in range(100):
        t.report(step=step, value=x * step)
        if t.should_prune():
            raise ot.TrialPruned

    return x


if __name__ == "__main__":
    p = ot.pruners.HyperbandPruner()
    study = ot.create_study(
        study_name="test_hyperband_3083",
        pruner=p,  # You actually don't need to prune, for this demonstration.
        load_if_exists=True,
    )
    study.optimize(objective, n_trials=100)

    # Check bracket splits. Using private APIs, only use for testing.
    brackets_to_trial_numbers = defaultdict(list)
    for t in study.trials:
        brackets_to_trial_numbers[study.pruner._get_bracket_id(study, t)].append(t.number)
    print({k: len(v) for k, v in brackets_to_trial_numbers.items()})
    # {0: 81, 1: 19}

Copy link
Copy Markdown
Member

@c-bata c-bata left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sorry. On second thought, as you said, the bracket_id will be biased due to this change. Let me review again 🙇

@github-actions github-actions bot added the optuna.pruners Related to the `optuna.pruners` submodule. This is automatically labeled by github-actions. label Nov 6, 2022
@c-bata c-bata self-assigned this Nov 6, 2022
@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented Nov 6, 2022

Codecov Report

Merging #4131 (2c1f22d) into master (bccb63d) will decrease coverage by 0.01%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master    #4131      +/-   ##
==========================================
- Coverage   90.10%   90.09%   -0.02%     
==========================================
  Files         161      161              
  Lines       12655    12656       +1     
==========================================
- Hits        11403    11402       -1     
- Misses       1252     1254       +2     
Impacted Files Coverage Δ
optuna/pruners/_hyperband.py 98.82% <100.00%> (+0.01%) ⬆️
optuna/integration/botorch.py 97.37% <0.00%> (-0.88%) ⬇️

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@zaburo-ch
Copy link
Copy Markdown
Contributor Author

To avoid the bias discussed above, I modified the implementation to use crc32, which was mentioned in #3083.

The result of the above code after modification is {3: 9, 1: 21, 0: 58, 4: 2, 2: 10}

Copy link
Copy Markdown
Member

@c-bata c-bata left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update! LGTM.

Copy link
Copy Markdown
Member

@not522 not522 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you remove this note?

.. note::
``HyperbandPruner`` in Optuna randomly computes bracket ID for each trial with a hash
function taking ``study_name`` of :class:`~optuna.study.Study` and
:attr:`~optuna.trial.Trial.number`. Please specify ``study_name`` and
`hash seed <https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED>`_
to make pruning behavior reproducible.

@not522 not522 self-assigned this Nov 6, 2022
@zaburo-ch
Copy link
Copy Markdown
Contributor Author

Thanks for the review. I simply deleted the note, is this ok?

Copy link
Copy Markdown
Member

@not522 not522 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! LGTM!

@not522 not522 enabled auto-merge November 6, 2022 07:45
@not522 not522 added this to the v3.1.0 milestone Nov 6, 2022
@not522 not522 changed the title Stop using hash function in _get_bracket_id in Hyperband Stop using hash function in _get_bracket_id in HyperbandPruner Nov 6, 2022
@not522 not522 merged commit f454359 into optuna:master Nov 6, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Issue/PR about behavior that is broken. Not for typos/examples/CI/test but for Optuna itself. optuna.pruners Related to the `optuna.pruners` submodule. This is automatically labeled by github-actions. sprint-20221106 PR from the online/offline hybrid sprint event Nov 06, 2022

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants