Fix a bug in distributed optimization using NSGA-II/III#6066
Fix a bug in distributed optimization using NSGA-II/III#6066y0z merged 13 commits intooptuna:masterfrom
Conversation
The cached parent stores as trial id but retrieves by index, created a quick work around
|
@leevers |
import optuna
from optuna.samplers import NSGAIISampler
storage_path = 'sqlite:///optuna_test_6066.db'
def objective(trial):
test_param_1 = trial.suggest_categorical(
"test_param_1",
[False, True]
)
value = 100 * (1 if test_param_1 else 0)
return value
# Do a study, this will fill the first 10 initial trial_ids in the database
study_1 = optuna.create_study(
direction='maximize',
load_if_exists=True,
storage=storage_path,
study_name='optuna_test_6066_study_1')
study_1.optimize(objective, n_trials=10, catch=(RuntimeError, ValueError, AssertionError))
# Run another study, this way the trial_id does not match the index
study_2 = optuna.create_study(
direction='maximize',
sampler=NSGAIISampler(
population_size=4,
mutation_prob=0.1,
crossover_prob=0.9,
swapping_prob=0.5),
load_if_exists=True,
storage=storage_path,
study_name='optuna_test_6066_study_2')
study_2.optimize(objective, n_trials=10, catch=(RuntimeError, ValueError, AssertionError)) |
|
This could be related to #6065 as this doesn't occur when RDBStorage is not used |
y0z
left a comment
There was a problem hiding this comment.
A CI error exists.
https://github.com/optuna/optuna/actions/runs/14723994781/job/41322869127?pr=6066
FAILED tests/samplers_tests/test_base_gasampler.py::test_get_parent_population[args0] - KeyError: 0
Please update the corresponding test as follows.
def test_get_parent_population(args: dict[str, Any]) -> None:
test_sampler = BaseGASamplerTestSampler(population_size=3)
mock_study = MagicMock()
mock_study._storage.get_study_system_attrs.return_value = args["study_system_attrs"]
if args["cache"]:
mock_study._get_trials.return_value = [
optuna.trial.FrozenTrial(
number=i,
trial_id=i,
state=optuna.trial.TrialState.WAITING,
value=None,
datetime_start=None,
datetime_complete=None,
params={},
distributions={},
user_attrs={},
system_attrs={},
intermediate_values={},
values=None,
) for i in args["study_system_attrs"][BaseGASamplerTestSampler._get_parent_cache_key_prefix() + "1"]
]
with patch.object(
BaseGASamplerTestSampler,
"select_parent",
return_value=[Mock(_trial_id=i) for i in args["parent_population"]],
) as mock_select_parent:
return_value = test_sampler.get_parent_population(mock_study, args["generation"])
if args["generation"] == 0:
assert mock_select_parent.call_count == 0
assert mock_study._storage.get_study_system_attrs.call_count == 0
assert mock_study._get_trials.call_count == 0
assert return_value == []
return
mock_study._storage.get_study_system_attrs.assert_called_once_with(mock_study._study_id)
if not args["cache"]:
mock_select_parent.assert_called_once_with(mock_study, args["generation"])
mock_study._storage.set_study_system_attr.assert_called_once_with(
mock_study._study_id,
BaseGASamplerTestSampler._get_parent_cache_key_prefix() + str(args["generation"]),
[i._trial_id for i in mock_select_parent.return_value],
)
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #6066 +/- ##
==========================================
- Coverage 88.43% 88.32% -0.11%
==========================================
Files 206 207 +1
Lines 13910 13978 +68
==========================================
+ Hits 12301 12346 +45
- Misses 1609 1632 +23 ☔ View full report in Codecov by Sentry. |
|
I have updated the test and got it to a passing state |
| optuna.trial.FrozenTrial( | ||
| number=i, | ||
| trial_id=i, | ||
| state=optuna.trial.TrialState.WAITING, | ||
| value=None, | ||
| datetime_start=None, | ||
| datetime_complete=None, | ||
| params={}, | ||
| distributions={}, | ||
| user_attrs={}, | ||
| system_attrs={}, | ||
| intermediate_values={}, | ||
| values=None, | ||
| ) |
There was a problem hiding this comment.
[NIT]
There’s a test utility function for creating FrozenTrial objects:
optuna/optuna/testing/trials.py
Lines 13 to 34 in 022f97d
There was a problem hiding this comment.
to test this properly we actually need trials that have a different trial_id to their index in the trials array
There was a problem hiding this comment.
Could you add a new test case to test this condition?
There was a problem hiding this comment.
i don't know enough about mocking etc to know how
There was a problem hiding this comment.
How about changing the first test case to
{
"study_system_attrs": {
BaseGASamplerTestSampler._get_parent_cache_key_prefix() + "2": [3, 4, 5]
},
"parent_population": [3, 4, 5],
"generation": 3,
"cache": True,
},and mock_study._get_trials to the below?
if args["cache"]:
mock_study._get_trials.return_value = [
optuna.trial.FrozenTrial(
number=i,
trial_id=j,
state=optuna.trial.TrialState.WAITING,
value=None,
datetime_start=None,
datetime_complete=None,
params={},
distributions={},
user_attrs={},
system_attrs={},
intermediate_values={},
values=None,
)
for i, j in enumerate(args["study_system_attrs"][
BaseGASamplerTestSampler._get_parent_cache_key_prefix() + "2"
])
]Then, the return value has trials with trial_number != trial_id.
[FrozenTrial(number=0, state=4, values=None, datetime_start=None, datetime_complete=None, params={}, user_attrs={}, system_attrs={}, int
ermediate_values={}, distributions={}, trial_id=3, value=None), FrozenTrial(number=1, state=4, values=None, datetime_start=None, datetime_complete=None, params={}, user_attrs={}, s
ystem_attrs={}, intermediate_values={}, distributions={}, trial_id=4, value=None), FrozenTrial(number=2, state=4, values=None, datetime_start=None, datetime_complete=None, params={
}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={}, trial_id=5, value=None)There was a problem hiding this comment.
Since get_parent_population() only accesses the _trial_id attribute, creating full FrozenTrial instances is unnecessary. Simple mocks with just _trial_id are sufficient for this test:
if args["cache"]:
mock_study._get_trials.return_value = [
Mock(_trial_id=i)
for i in args["study_system_attrs"][
BaseGASamplerTestSampler._get_parent_cache_key_prefix() + "1"
]
]It’s a bit more complicated, but for additional safety, it’s useful to enforce that only _trial_id is accessed. This can be done with a strict mock:
if args["cache"]:
class StrictTrialMock(MagicMock):
def __getattr__(self, name):
if name != "_trial_id":
raise AttributeError(
f"Access to attribute '{name}' is not allowed."
)
return super().__getattr__(name)
mock_study._get_trials.return_value = [
StrictTrialMock(_trial_id=i)
for i in args["study_system_attrs"][
BaseGASamplerTestSampler._get_parent_cache_key_prefix() + "1"
]
]This setup ensures that get_parent_population() strictly depends only on _trial_id, making the test independent of the trial number or any other attributes.
|
@sawa3030 Could you review this PR? |
update based on suggestions from optuna#6066 (comment)
Co-authored-by: Shuhei Watanabe <47781922+nabenabe0928@users.noreply.github.com>
Co-authored-by: Eri Sawada <125344906+sawa3030@users.noreply.github.com>
black suggestions
mypy tests/samplers_tests/test_base_gasampler.py:220: error: Function is missing a type annotation [no-untyped-def]
tests/samplers_tests/test_base_gasampler.py:220: error: Function is missing a return type annotation [no-untyped-def]
| mock_study._get_trials.assert_has_calls( | ||
| [call(deepcopy=False)] + [call().__getitem__(i) for i in args["parent_population"]] | ||
| ) | ||
| assert return_value == mock_study._get_trials.return_value |
There was a problem hiding this comment.
Adding a test to confirm that the cache is used instead of calling select_parent is helpful.
| assert return_value == mock_study._get_trials.return_value | |
| assert mock_select_parent.call_count == 0 | |
| assert return_value == mock_study._get_trials.return_value |
| class StrictTrialMock(MagicMock): | ||
| def __getattr__(self, name: str) -> Any: | ||
| if name != "_trial_id": | ||
| raise AttributeError(f"Access to attribute '{name}' is not allowed.") | ||
| return super().__getattr__(name) |
There was a problem hiding this comment.
Perhaps, Mock does not support overriding __getattr__.
So, how about simply using MagicMock here to avoid a tricky test?
Actually, I've confirmed a curious behavior, i.e., the name passed to the __getattr__ is replaced with _mock_methods , in the code below, and the current implementation does not work as expected.
class StrictTrialMock(MagicMock):
def __getattr__(self, name: str) -> Any:
print(name)
if name != "_trial_id":
raise AttributeError(f"Access to attribute '{name}' is not allowed.")
return super().__getattr__(name)pytest -s tests/samplers_tests/test_base_gasampler.py::test_get_parent_population
=================================================================================== test session starts ====================================================================================
platform darwin -- Python 3.11.7, pytest-7.4.4, pluggy-1.3.0
rootdir: /Users/yozaki/work/review/optuna
configfile: pyproject.toml
plugins: jaxtyping-0.2.30, typeguard-2.13.3
collected 4 items
tests/samplers_tests/test_base_gasampler.py _mock_methods
_mock_methods
_mock_methods|
This pull request has not seen any recent activity. |
y0z
left a comment
There was a problem hiding this comment.
I think changing tests can be done as a follow-up task.
The logic fix is appropriate, so I merge this PR.
LGTM.
|
Let me work on a follow-up PR to update the tests. |
The cached parent stores as trial id but retrieves by index, created a quick work around
Motivation
I am using optuna with the NSGAIISampler and this is causing a crash
Description of the changes
Dirty hack so that it can continue