Fix a bug in distributed optimization using NSGA-II/III#6066

Merged
y0z merged 13 commits into optuna:master from leevers:patch-1 on May 30, 2025

@leevers
Contributor

@leevers leevers commented Apr 29, 2025

The cached parent population is stored by trial ID but retrieved by list index, so I created a quick workaround.

Motivation

I am using Optuna with the NSGAIISampler, and this bug causes a crash.

Description of the changes

A dirty hack so that the optimization can continue.
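The mismatch described above can be illustrated with a minimal, self-contained sketch. `FakeTrial`, `parents_by_index`, and `parents_by_trial_id` are hypothetical names used only for illustration; they are not Optuna's classes and not necessarily the exact change made in this PR. The sketch assumes a study whose trial IDs start at 11 because an earlier study in the same storage already used IDs 1 to 10.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTrial:
    """Hypothetical stand-in for a trial; not Optuna's FrozenTrial."""

    number: int  # position of the trial within its own study
    trial_id: int  # identifier that is unique across the whole storage


# Assumption: an earlier study already used trial IDs 1..10 in the same storage.
trials = [FakeTrial(number=n, trial_id=11 + n) for n in range(4)]
cached_parent_ids = [11, 13]  # the cache holds trial IDs, not list positions


def parents_by_index(trials, cached_ids):
    # Mirrors the reported bug: cached trial IDs are treated as list indices.
    return [trials[i] for i in cached_ids]


def parents_by_trial_id(trials, cached_ids):
    # Looks trials up by ID, so their list positions no longer matter.
    by_id = {t.trial_id: t for t in trials}
    return [by_id[i] for i in cached_ids]


try:
    parents_by_index(trials, cached_parent_ids)
    index_lookup_failed = False
except IndexError:
    index_lookup_failed = True

parents = parents_by_trial_id(trials, cached_parent_ids)
```

With only four trials in the list, indexing by ID 11 or 13 runs past the end of the list, while the ID-based lookup returns the trials numbered 0 and 2.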

@nabenabe0928
Contributor

@leevers
Could you give us the code to reproduce your error?

@leevers
Contributor Author

leevers commented Apr 30, 2025

import optuna
from optuna.samplers import NSGAIISampler

storage_path = "sqlite:///optuna_test_6066.db"


def objective(trial):
    test_param_1 = trial.suggest_categorical("test_param_1", [False, True])
    value = 100 * (1 if test_param_1 else 0)
    return value


# Run a first study; this fills the first 10 trial_ids in the database.
study_1 = optuna.create_study(
    direction="maximize",
    load_if_exists=True,
    storage=storage_path,
    study_name="optuna_test_6066_study_1",
)

study_1.optimize(objective, n_trials=10, catch=(RuntimeError, ValueError, AssertionError))

# Run a second study in the same storage, so that its trial_ids do not match the list indices.
study_2 = optuna.create_study(
    direction="maximize",
    sampler=NSGAIISampler(
        population_size=4,
        mutation_prob=0.1,
        crossover_prob=0.9,
        swapping_prob=0.5,
    ),
    load_if_exists=True,
    storage=storage_path,
    study_name="optuna_test_6066_study_2",
)

study_2.optimize(objective, n_trials=10, catch=(RuntimeError, ValueError, AssertionError))
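Why the second study's trial IDs differ from its trial numbers can be sketched with the standard-library `sqlite3` module. The table below is a deliberately simplified, hypothetical schema, not Optuna's actual RDB schema; it only assumes that all studies in one database draw their trial IDs from a single auto-incrementing column, while trial numbers restart at 0 for each study.

```python
import sqlite3

# In-memory database standing in for the shared storage (hypothetical schema).
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE trials ("
    "trial_id INTEGER PRIMARY KEY AUTOINCREMENT, "
    "study_name TEXT NOT NULL, "
    "number INTEGER NOT NULL)"
)


def add_trials(study_name: str, n_trials: int) -> None:
    # Trial numbers restart at 0 per study; trial_id keeps counting globally.
    for number in range(n_trials):
        conn.execute(
            "INSERT INTO trials (study_name, number) VALUES (?, ?)",
            (study_name, number),
        )


add_trials("study_1", 10)
add_trials("study_2", 4)

rows = conn.execute(
    "SELECT number, trial_id FROM trials WHERE study_name = ? ORDER BY number",
    ("study_2",),
).fetchall()
conn.close()
```

In this sketch every row of the second study has a `trial_id` that is larger than its `number`, so using the ID as an index into that study's own trial list cannot work.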

@leevers
Contributor Author

leevers commented Apr 30, 2025

This could be related to #6065, since the error doesn't occur when RDBStorage is not used.

@nabenabe0928
Contributor

nabenabe0928 commented May 1, 2025

@y0z @gen740 Could you review this PR when you have time?

Member

@y0z y0z left a comment

A CI error exists.
https://github.com/optuna/optuna/actions/runs/14723994781/job/41322869127?pr=6066

FAILED tests/samplers_tests/test_base_gasampler.py::test_get_parent_population[args0] - KeyError: 0

Please update the corresponding test as follows.

def test_get_parent_population(args: dict[str, Any]) -> None:
    test_sampler = BaseGASamplerTestSampler(population_size=3)

    mock_study = MagicMock()
    mock_study._storage.get_study_system_attrs.return_value = args["study_system_attrs"]
    if args["cache"]:
        mock_study._get_trials.return_value = [
            optuna.trial.FrozenTrial(
                number=i,
                trial_id=i,
                state=optuna.trial.TrialState.WAITING,
                value=None,
                datetime_start=None,
                datetime_complete=None,
                params={},
                distributions={},
                user_attrs={},
                system_attrs={},
                intermediate_values={},
                values=None,
            )
            for i in args["study_system_attrs"][
                BaseGASamplerTestSampler._get_parent_cache_key_prefix() + "1"
            ]
        ]

    with patch.object(
        BaseGASamplerTestSampler,
        "select_parent",
        return_value=[Mock(_trial_id=i) for i in args["parent_population"]],
    ) as mock_select_parent:
        return_value = test_sampler.get_parent_population(mock_study, args["generation"])

    if args["generation"] == 0:
        assert mock_select_parent.call_count == 0
        assert mock_study._storage.get_study_system_attrs.call_count == 0
        assert mock_study._get_trials.call_count == 0
        assert return_value == []
        return

    mock_study._storage.get_study_system_attrs.assert_called_once_with(mock_study._study_id)

    if not args["cache"]:
        mock_select_parent.assert_called_once_with(mock_study, args["generation"])
        mock_study._storage.set_study_system_attr.assert_called_once_with(
            mock_study._study_id,
            BaseGASamplerTestSampler._get_parent_cache_key_prefix() + str(args["generation"]),
            [i._trial_id for i in mock_select_parent.return_value],
        )

@codecov

codecov bot commented May 9, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 88.32%. Comparing base (7272fb7) to head (d8bf456).
Report is 182 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #6066      +/-   ##
==========================================
- Coverage   88.43%   88.32%   -0.11%     
==========================================
  Files         206      207       +1     
  Lines       13910    13978      +68     
==========================================
+ Hits        12301    12346      +45     
- Misses       1609     1632      +23     


@leevers
Contributor Author

leevers commented May 9, 2025

I have updated the test, and it now passes.

Comment on lines +219 to +232
optuna.trial.FrozenTrial(
    number=i,
    trial_id=i,
    state=optuna.trial.TrialState.WAITING,
    value=None,
    datetime_start=None,
    datetime_complete=None,
    params={},
    distributions={},
    user_attrs={},
    system_attrs={},
    intermediate_values={},
    values=None,
)
Member

[NIT]
There’s a test utility function for creating FrozenTrial objects:

def _create_frozen_trial(
    number: int = 0,
    values: Sequence[float] | None = None,
    constraints: Sequence[float] | None = None,
    params: dict[str, Any] | None = None,
    param_distributions: dict[str, BaseDistribution] | None = None,
    state: TrialState = TrialState.COMPLETE,
) -> optuna.trial.FrozenTrial:
    return FrozenTrial(
        number=number,
        value=1.0 if values is None else None,
        values=values,
        state=state,
        user_attrs={},
        system_attrs={} if constraints is None else {_CONSTRAINTS_KEY: list(constraints)},
        params=params or {},
        distributions=param_distributions or {},
        intermediate_values={},
        datetime_start=None,
        datetime_complete=None,
        trial_id=number,
    )

Contributor Author

To test this properly, we actually need trials whose trial_id differs from their index in the trials list.

Member

Could you add a new test case to test this condition?

Contributor Author

I don't know enough about mocking, etc., to know how.

Member

How about changing the first test case to

        {
            "study_system_attrs": {
                BaseGASamplerTestSampler._get_parent_cache_key_prefix() + "2": [3, 4, 5]
            },
            "parent_population": [3, 4, 5],
            "generation": 3,
            "cache": True,
        },

and mock_study._get_trials to the below?

    if args["cache"]:
        mock_study._get_trials.return_value = [
            optuna.trial.FrozenTrial(
                number=i,
                trial_id=j,
                state=optuna.trial.TrialState.WAITING,
                value=None,
                datetime_start=None,
                datetime_complete=None,
                params={},
                distributions={},
                user_attrs={},
                system_attrs={},
                intermediate_values={},
                values=None,
            )
            for i, j in enumerate(args["study_system_attrs"][
                BaseGASamplerTestSampler._get_parent_cache_key_prefix() + "2"
            ])
        ]

Then, the return value has trials with trial_number != trial_id.

[
    FrozenTrial(number=0, state=4, values=None, datetime_start=None, datetime_complete=None, params={}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={}, trial_id=3, value=None),
    FrozenTrial(number=1, state=4, values=None, datetime_start=None, datetime_complete=None, params={}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={}, trial_id=4, value=None),
    FrozenTrial(number=2, state=4, values=None, datetime_start=None, datetime_complete=None, params={}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={}, trial_id=5, value=None),
]

Member

Since get_parent_population() only accesses the _trial_id attribute, creating full FrozenTrial instances is unnecessary. Simple mocks with just _trial_id are sufficient for this test:

if args["cache"]:
    mock_study._get_trials.return_value = [
        Mock(_trial_id=i)
        for i in args["study_system_attrs"][
            BaseGASamplerTestSampler._get_parent_cache_key_prefix() + "1"
        ]
    ]

It’s a bit more complicated, but for additional safety, it’s useful to enforce that only _trial_id is accessed. This can be done with a strict mock:

if args["cache"]:

    class StrictTrialMock(MagicMock):
        def __getattr__(self, name):
            if name != "_trial_id":
                raise AttributeError(
                    f"Access to attribute '{name}' is not allowed."
                )
            return super().__getattr__(name)

    mock_study._get_trials.return_value = [
        StrictTrialMock(_trial_id=i)
        for i in args["study_system_attrs"][
            BaseGASamplerTestSampler._get_parent_cache_key_prefix() + "1"
        ]
    ]

This setup ensures that get_parent_population() strictly depends only on _trial_id, making the test independent of the trial number or any other attributes.

@HideakiImamura
Member

@sawa3030 Could you review this PR?

@nabenabe0928 nabenabe0928 changed the title from "Update _base.py" to "Fix a bug in distributed optimization using NSGA-II/III" May 19, 2025
@nabenabe0928 nabenabe0928 added the bug label May 19, 2025
@nabenabe0928 nabenabe0928 modified the milestone: v4.4.0 May 19, 2025
leevers and others added 5 commits May 19, 2025 15:39
update based on suggestions from optuna#6066 (comment)
Co-authored-by: Shuhei Watanabe <47781922+nabenabe0928@users.noreply.github.com>
Co-authored-by: Eri Sawada <125344906+sawa3030@users.noreply.github.com>
leevers added 2 commits May 19, 2025 15:55
mypy
tests/samplers_tests/test_base_gasampler.py:220: error: Function is missing a type annotation  [no-untyped-def]
tests/samplers_tests/test_base_gasampler.py:220: error: Function is missing a return type annotation  [no-untyped-def]
mock_study._get_trials.assert_has_calls(
    [call(deepcopy=False)] + [call().__getitem__(i) for i in args["parent_population"]]
)
assert return_value == mock_study._get_trials.return_value
Member

Adding a test to confirm that the cache is used instead of calling select_parent is helpful.

Suggested change
- assert return_value == mock_study._get_trials.return_value
+ assert mock_select_parent.call_count == 0
+ assert return_value == mock_study._get_trials.return_value
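The idea behind this suggestion, checking that a cached result is returned without calling select_parent, can be sketched in isolation with `unittest.mock.patch.object`. `CachingSelector` is a hypothetical toy class written for this illustration only; it is not Optuna's sampler, and its dictionary cache merely stands in for the study system attributes.

```python
from unittest.mock import patch


class CachingSelector:
    """Hypothetical toy class; the dict cache stands in for study system attrs."""

    def __init__(self) -> None:
        self._cache: dict[int, list[int]] = {}

    def select_parent(self, generation: int) -> list[int]:
        # Placeholder for an expensive selection step.
        return [generation * 10 + i for i in range(3)]

    def get_parent_population(self, generation: int) -> list[int]:
        # Only fall back to select_parent when nothing is cached.
        if generation not in self._cache:
            self._cache[generation] = self.select_parent(generation)
        return self._cache[generation]


selector = CachingSelector()
selector._cache[1] = [3, 4, 5]  # pretend generation 1 was cached earlier

with patch.object(
    CachingSelector, "select_parent", return_value=[0, 1, 2]
) as mock_select_parent:
    cached_result = selector.get_parent_population(1)
    calls_after_cached = mock_select_parent.call_count

    uncached_result = selector.get_parent_population(2)
    calls_after_uncached = mock_select_parent.call_count
```

Asserting on the call count right after the cached lookup distinguishes "the right value came back" from "the right value came back without recomputation", which is what the extra assertion in the suggestion is after.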

Comment on lines +219 to +223
class StrictTrialMock(MagicMock):
    def __getattr__(self, name: str) -> Any:
        if name != "_trial_id":
            raise AttributeError(f"Access to attribute '{name}' is not allowed.")
        return super().__getattr__(name)
Member

@y0z y0z May 22, 2025

@gen740 @leevers

Perhaps Mock does not support overriding __getattr__.

So how about simply using MagicMock here to avoid a tricky test?

In fact, I've confirmed some curious behavior with the code below: the name passed to __getattr__ is _mock_methods instead of the attribute being accessed, so the current implementation does not work as expected.

        class StrictTrialMock(MagicMock):
            def __getattr__(self, name: str) -> Any:
                print(name)
                if name != "_trial_id":
                    raise AttributeError(f"Access to attribute '{name}' is not allowed.")
                return super().__getattr__(name)
pytest -s tests/samplers_tests/test_base_gasampler.py::test_get_parent_population
============================= test session starts ==============================
platform darwin -- Python 3.11.7, pytest-7.4.4, pluggy-1.3.0
rootdir: /Users/yozaki/work/review/optuna
configfile: pyproject.toml
plugins: jaxtyping-0.2.30, typeguard-2.13.3
collected 4 items

tests/samplers_tests/test_base_gasampler.py _mock_methods
_mock_methods
_mock_methods

@github-actions
Contributor

This pull request has not seen any recent activity.

@github-actions github-actions bot added the stale label May 29, 2025
Member

@gen740 gen740 left a comment

LGTM!

@gen740 gen740 removed the stale label May 30, 2025
Member

@y0z y0z left a comment

I think the test changes can be done as a follow-up task.
The logic fix is appropriate, so I will merge this PR.

LGTM.

@y0z y0z merged commit 8059c74 into optuna:master May 30, 2025
14 checks passed
@y0z y0z added this to the v4.4.0 milestone May 30, 2025
@sawa3030
Collaborator

Let me work on a follow-up PR to update the tests.

Labels

bug Issue/PR about behavior that is broken. Not for typos/examples/CI/test but for Optuna itself.

6 participants