Reduce SELECT statements of _CachedStorage.get_all_trials by fixing filtering conditions#5704
Conversation
|
I assigned @porink0424 as an additional reviewer, considering his recent contributions to RDBStorage. I'll proceed with reviewing the changes after @porink0424 has approved this PR. |
porink0424
left a comment
There was a problem hiding this comment.
I've left a few initial comments for now👍 I may add more comments later.
c-bata
left a comment
There was a problem hiding this comment.
Let me leave some early feedback comments.
Do you think we should update the unfinished_trial_ids and last_finished_trial_id in _CachedStorage.set_trial_state_values?
https://github.com/optuna/optuna/pull/5704/files#diff-8ce5c3176c6b5fa3a21ed11da3f19d07c7c0291b7b7b6b67221b057855b7ac50R187
|
I executed a micro-benchmark and confirmed that this PR makes Benchmark Scriptimport optuna
import time
import numpy as np
from concurrent.futures import ThreadPoolExecutor
n_trials = 10000
# Disable calling storage.get_best_trial for logging
optuna.logging.set_verbosity(optuna.logging.WARNING)
def objective(trial: optuna.Trial) -> float:
s = 0.0
for i in range(18):
trial.set_user_attr(f"attr{i}", "attr value")
for i in range(8):
s += trial.suggest_float(f"x{i}", -10, 10) ** 2
return s
storage = optuna.storages.RDBStorage("mysql+pymysql://optuna:password@127.0.0.1:3306/optuna")
start = time.time()
study = optuna.create_study(storage=storage, sampler=optuna.samplers.RandomSampler())
study.optimize(objective, n_trials, n_jobs=10)
print(f"study.optimize elapsed: {time.time() - start}")
start = time.time()
with ThreadPoolExecutor(max_workers=10) as pool:
for i in range(20):
pool.submit(storage.get_all_trials, study._study_id)
print(f"storage.get_all_trials: {time.time() - start}")Benchmark Results
|
| n_trials | master | this PR | diff |
|---|---|---|---|
| 1000 | 76.90139293670654 | 74.6462049484253 | -2.93% |
| 10000 | 1038.7614076137543 | 762.2547173500061 | -26.61% |
| 50000 | 13536.754137277603 | 4220.0592222213745 | -68.82% |
storage.get_all_trials x 20 (10 threads)
| n_trials | master | this PR | diff |
|---|---|---|---|
| 1000 | 24.590099096298218 | 24.788296461105347 | +0.80% |
| 10000 | 237.29632568359375 | 242.50016379356384 | +2.19% |
|
@porink0424 @c-bata Thanks for the review comments. I applied your suggestions. Please take a look.
I don't have a strong opinion, but I think we don't have to update them in (1) The existing |
c-bata
left a comment
There was a problem hiding this comment.
Thank you for the pull request. Changes look almost good to me. I left one minor suggestion though.
| .all() | ||
| ) | ||
| elif trial_id_greater_than > -1: | ||
| _query = query.filter(models.TrialModel.trial_id > trial_id_greater_than) |
There was a problem hiding this comment.
Can we reuse the query variable and remove the else block?
| _query = query.filter(models.TrialModel.trial_id > trial_id_greater_than) | |
| query = query.filter(models.TrialModel.trial_id > trial_id_greater_than) |
There was a problem hiding this comment.
The variable query is also used in the except block in which we don't add the filtering to the SQL query, so I use another name of variable here.
There was a problem hiding this comment.
I see. I noticed another issue and have one concern:
- Currently,
selectinloadis used even in theexceptblock, which is likely unintentional. - Although this isn’t related to your changes, the logic in the
exceptblock isn't tested at all.
What are your thoughts on these?
There was a problem hiding this comment.
Thanks for the comments.
Currently, selectinload is used even in the except block, which is likely unintentional.
I think we have already used selectinload in this except block (ref). Do you think we should avoid this? The processing enters the except block due to the specification of sqlite that the number of maximum allowed variables using IN is 999 as noted in the above comments in codes and the document, so I think it is reasonable to use selectinload here as we do in the other cases.
Although this isn’t related to your changes, the logic in the except block isn't tested at all.
In my understanding, this except block is already tested here. However, this test case is not obvious to understand what it tests.
There was a problem hiding this comment.
I think we have already used selectinload in this except block (ref).
Ah, you are absolutely correct!
In my understanding, this except block is already tested here. However, this test case is not obvious to understand what it tests.
As for unit tests, it seems that the except block isn’t actually tested in the master branch.
$ git co master
$ python3
>>> import optuna
>>> storage = optuna.storages.RDBStorage("sqlite:///test.db")
>>> study_id = storage.create_new_study(directions=[optuna.study.StudyDirection.MINIMIZE])
[I 2024-10-29 18:29:53,124] A new study created in RDB with name: no-name-56f4ff56-e9e8-488a-8c93-d88e312ad97b
>>> storage.create_new_trial(study_id)
1
>>> trials = storage._get_trials(study_id, states=None, excluded_trial_ids=set(range(500000)))
>>>
That said, this PR also addresses the test case, so my concern is effectively resolved. Thank you for your great work!
$ gh pr checkout 5704
$ python3
>>> import optuna
>>> storage = optuna.storages.RDBStorage("sqlite:///test.db")
>>> study_id = storage.create_new_study(directions=[optuna.study.StudyDirection.MINIMIZE])
[I 2024-10-29 18:29:53,124] A new study created in RDB with name: no-name-56f4ff56-e9e8-488a-8c93-d88e312ad97b
>>> trial_id = storage.create_new_trial(study_id)
>>> trial_id_greater_than = trial_id + 500000
>>> trials = storage._get_trials(
... study_id,
... states=None,
... included_trial_ids=set(range(500000)),
... trial_id_greater_than=trial_id_greater_than,
... )
:
(Background on this error at: https://sqlalche.me/e/20/e3q8). Falling back to a slower alternative.
porink0424
left a comment
There was a problem hiding this comment.
LGTM, as long as the comment above from @c-bata is addressed.
Motivation
This PR aims to reduce the number of
SELECTstatements of_CachedStorage.get_all_trials. The current filtering conditions are removing excluded trials inRDBStorage._get_trials, but it can be simplified by using "included" trials.Description of the changes
trial_id_cursor