nan handling in Optuna #3205
Description
This is a new version of #2169, adapted to the latest Optuna. It is also related to #3132.
[Note: edited on 2022/01/30]
Expected behavior
Original expectation (superseded by the 2022/01/30 revision): regardless of the type of storage used, mark the corresponding trial as FAILED when the objective function returns nan, when nan is passed in trial.report, or when nan is passed in study.tell.
Revised expectation: regardless of the type of storage used, mark the corresponding trial as FAILED when the objective function returns nan or when nan is passed in study.tell.
On the other hand, when nan is reported with trial.report, accept the reported nan, again regardless of the type of storage used.
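The expected behavior amounts to a small decision rule. The sketch below only illustrates that rule under stated assumptions: TrialState is a hypothetical stand-in for optuna.trial.TrialState, and state_after_final_value and record_intermediate_value are hypothetical helper names, not Optuna functions.

```python
import math
from enum import Enum


class TrialState(Enum):
    # Hypothetical stand-in for optuna.trial.TrialState, limited to the
    # states that appear in this issue.
    RUNNING = "RUNNING"
    COMPLETE = "COMPLETE"
    FAIL = "FAIL"


def state_after_final_value(value: float) -> TrialState:
    """State a trial should end in, for every storage type.

    Covers both a value returned by the objective function and a value
    passed to study.tell: nan fails the trial, anything else completes it.
    """
    if math.isnan(value):
        return TrialState.FAIL
    return TrialState.COMPLETE


def record_intermediate_value(values: dict, step: int, value: float) -> None:
    """trial.report case: store the value as-is, nan included."""
    values[step] = value


# A trial whose first intermediate value is nan still completes,
# as long as its final value is a number.
intermediate = {}
record_intermediate_value(intermediate, 1, float("nan"))
final_state = state_after_final_value(1.0)
```

Keeping the two cases in separate functions mirrors the split in the revised expectation: the final value decides the trial state, while intermediate values are simply recorded.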
We need to fix the following.
- If nan is passed in trial.report, accept the reported nan regardless of the type of storage used. (This replaces the earlier plan of marking the trial as FAILED.) Related PR: Accept nan in trial.report #3348
- Add documentation for a workaround when passing nan in trial.report. If a trial's intermediate value is nan at the beginning of the calculation but becomes a valid value partway through, the user can check that the value is not nan before calling trial.report.
- If nan is passed in study.tell, mark the trial as FAILED regardless of the type of storage used. Related PR: Move validation logic from _run_trial to study.tell #3144
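A minimal sketch of the documented workaround, assuming the objective produces one intermediate value per step: report_if_not_nan is a hypothetical helper (not part of Optuna) that skips nan and otherwise calls trial.report(value, step). RecordingTrial is a hypothetical stand-in so the snippet runs without Optuna; in a real objective you would pass the optuna.trial.Trial object instead.

```python
import math


def report_if_not_nan(trial, value: float, step: int) -> bool:
    """Call trial.report(value, step) unless value is nan.

    Returns True if the value was reported, False if it was skipped.
    """
    # nan compares unequal to everything, itself included, so a check like
    # `value == float("nan")` never matches; math.isnan is the reliable test.
    if math.isnan(value):
        return False
    trial.report(value, step)
    return True


class RecordingTrial:
    """Hypothetical stand-in with the same report(value, step) signature
    as optuna.trial.Trial, used only to exercise the helper."""

    def __init__(self):
        self.reported = []

    def report(self, value, step):
        self.reported.append((step, value))


# nan at the start of the calculation, valid values afterwards.
trial = RecordingTrial()
for step, loss in enumerate([float("nan"), float("nan"), 0.8, 0.5]):
    report_if_not_nan(trial, loss, step)

print(trial.reported)  # [(2, 0.8), (3, 0.5)]
```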
Environment
- Optuna version: a70cbff
- Python version: 3.9.9
- OS: macOS Catalina 10.15.7
- (Optional) Other libraries and their versions:
  - psycopg2: 2.9.2
  - PyMySQL: 1.0.2
Error messages, stack traces, or logs
Running the script nan-objective.py (provided under Reproducible examples) produces the following output. I think this output is ideal and does not need to be modified: the trial is marked as FAILED for all storages.
output
(venv) mamu@HideakinoMacBook-puro test % python nan-objective.py
[I 2021-12-24 13:36:27,998] A new study created in memory with name: no-name-769e5aec-62ec-4719-ac85-6a6fbf375bee
[W 2021-12-24 13:36:28,002] Trial 0 failed, because the objective function returned nan.
In-memory : Trial state is TrialState.FAIL
[I 2021-12-24 13:36:28,086] A new study created in RDB with name: no-name-8d701653-5818-4c28-9a8f-0e38d8f9019c
[W 2021-12-24 13:36:28,135] Trial 0 failed, because the objective function returned nan.
SQLite : Trial state is TrialState.FAIL
[I 2021-12-24 13:36:28,295] A new study created in RDB with name: no-name-4367a62b-fe90-4cd0-89e8-74c18e272f21
[W 2021-12-24 13:36:28,403] Trial 0 failed, because the objective function returned nan.
MySQL : Trial state is TrialState.FAIL
[I 2021-12-24 13:36:28,722] A new study created in RDB with name: no-name-8cc33f98-2f7b-41ef-802e-8d46abb6e266
[W 2021-12-24 13:36:28,850] Trial 0 failed, because the objective function returned nan.
PostgreSQL: Trial state is TrialState.FAIL
On the other hand, running the script nan-report.py (provided under Reproducible examples) produces the following output. The in-memory, SQLite, and PostgreSQL storages accept the reported nan and complete the trial, whereas MySQL raises StorageInternalError and the trial fails. I think this behavior needs to be modified as described under Expected behavior.
output
(venv) mamu@HideakinoMacBook-puro test % python nan-report.py
[I 2021-12-24 13:39:15,478] A new study created in memory with name: no-name-f34999a1-b56e-45d5-894a-a18db92956f6
[I 2021-12-24 13:39:15,479] Trial 0 finished with value: 1.0 and parameters: {}. Best is trial 0 with value: 1.0.
In-memory : Trial state is TrialState.COMPLETE
[I 2021-12-24 13:39:15,603] A new study created in RDB with name: no-name-aa0bdcc0-dd7c-4c19-8bf3-df7451bab2eb
[I 2021-12-24 13:39:15,704] Trial 0 finished with value: 1.0 and parameters: {}. Best is trial 0 with value: 1.0.
SQLite : Trial state is TrialState.COMPLETE
[I 2021-12-24 13:39:15,878] A new study created in RDB with name: no-name-96a37aca-68a4-48d2-83b5-051be816f138
[W 2021-12-24 13:39:16,019] Trial 0 failed because of the following error: StorageInternalError('An exception is raised during the commit. This typically happens due to invalid data in the commit, e.g. exceeding max length. ')
Traceback (most recent call last):
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1802, in _execute_context
self.dialect.do_execute(
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/engine/default.py", line 732, in do_execute
cursor.execute(statement, parameters)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/cursors.py", line 146, in execute
query = self.mogrify(query, args)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/cursors.py", line 125, in mogrify
query = query % self._escape_args(args, conn)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/cursors.py", line 109, in _escape_args
return {key: conn.literal(val) for (key, val) in args.items()}
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/cursors.py", line 109, in <dictcomp>
return {key: conn.literal(val) for (key, val) in args.items()}
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/connections.py", line 517, in literal
return self.escape(obj, self.encoders)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/connections.py", line 510, in escape
return converters.escape_item(obj, self.charset, mapping=mapping)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/converters.py", line 25, in escape_item
val = encoder(val, mapping)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/converters.py", line 60, in escape_float
raise ProgrammingError("%s can not be used with MySQL" % s)
pymysql.err.ProgrammingError: nan can not be used with MySQL
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/optuna/storages/_rdb/storage.py", line 54, in _create_scoped_session
session.commit()
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/orm/session.py", line 1431, in commit
self._transaction.commit(_to_root=self.future)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/orm/session.py", line 829, in commit
self._prepare_impl()
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/orm/session.py", line 808, in _prepare_impl
self.session.flush()
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/orm/session.py", line 3363, in flush
self._flush(objects)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/orm/session.py", line 3503, in _flush
transaction.rollback(_capture_exception=True)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/util/langhelpers.py", line 70, in __exit__
compat.raise_(
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/util/compat.py", line 207, in raise_
raise exception
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/orm/session.py", line 3463, in _flush
flush_context.execute()
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/orm/unitofwork.py", line 456, in execute
rec.execute(self)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/orm/unitofwork.py", line 630, in execute
util.preloaded.orm_persistence.save_obj(
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/orm/persistence.py", line 244, in save_obj
_emit_insert_statements(
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/orm/persistence.py", line 1221, in _emit_insert_statements
result = connection._execute_20(
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1614, in _execute_20
return meth(self, args_10style, kwargs_10style, execution_options)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/sql/elements.py", line 325, in _execute_on_connection
return connection._execute_clauseelement(
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1481, in _execute_clauseelement
ret = self._execute_context(
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1845, in _execute_context
self._handle_dbapi_exception(
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 2026, in _handle_dbapi_exception
util.raise_(
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/util/compat.py", line 207, in raise_
raise exception
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1802, in _execute_context
self.dialect.do_execute(
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/sqlalchemy/engine/default.py", line 732, in do_execute
cursor.execute(statement, parameters)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/cursors.py", line 146, in execute
query = self.mogrify(query, args)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/cursors.py", line 125, in mogrify
query = query % self._escape_args(args, conn)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/cursors.py", line 109, in _escape_args
return {key: conn.literal(val) for (key, val) in args.items()}
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/cursors.py", line 109, in <dictcomp>
return {key: conn.literal(val) for (key, val) in args.items()}
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/connections.py", line 517, in literal
return self.escape(obj, self.encoders)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/connections.py", line 510, in escape
return converters.escape_item(obj, self.charset, mapping=mapping)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/converters.py", line 25, in escape_item
val = encoder(val, mapping)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/pymysql/converters.py", line 60, in escape_float
raise ProgrammingError("%s can not be used with MySQL" % s)
sqlalchemy.exc.ProgrammingError: (pymysql.err.ProgrammingError) nan can not be used with MySQL
[SQL: INSERT INTO trial_intermediate_values (trial_id, step, intermediate_value) VALUES (%(trial_id)s, %(step)s, %(intermediate_value)s)]
[parameters: {'trial_id': 7, 'step': 1, 'intermediate_value': nan}]
(Background on this error at: https://sqlalche.me/e/14/f405)
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/optuna/study/_optimize.py", line 213, in _run_trial
value_or_values = func(trial)
File "/Users/mamu/Dev/pfn/code_snippets_for_optuna/test/nan-report.py", line 5, in objective
trial.report(float(np.nan), 1)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/optuna/trial/_trial.py", line 506, in report
self.storage.set_trial_intermediate_value(self._trial_id, step, value)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/optuna/storages/_cached_storage.py", line 293, in set_trial_intermediate_value
self._backend.set_trial_intermediate_value(trial_id, step, intermediate_value)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/optuna/storages/_rdb/storage.py", line 754, in set_trial_intermediate_value
self._set_trial_intermediate_value_without_commit(
File "/Users/mamu/.pyenv/versions/3.9.9/lib/python3.9/contextlib.py", line 126, in __exit__
next(self.gen)
File "/Users/mamu/Dev/pfn/venv/lib/python3.9/site-packages/optuna/storages/_rdb/storage.py", line 71, in _create_scoped_session
raise optuna.exceptions.StorageInternalError(message) from e
optuna.exceptions.StorageInternalError: An exception is raised during the commit. This typically happens due to invalid data in the commit, e.g. exceeding max length.
An exception is raised during the commit. This typically happens due to invalid data in the commit, e.g. exceeding max length.
MySQL : Trial state is TrialState.FAIL
[I 2021-12-24 13:39:16,341] A new study created in RDB with name: no-name-1d3ad4c8-ae5b-4da8-aafb-5139770192a3
[I 2021-12-24 13:39:16,519] Trial 0 finished with value: 1.0 and parameters: {}. Best is trial 0 with value: 1.0.
PostgreSQL: Trial state is TrialState.COMPLETE
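The traceback shows that the MySQL failure happens on the client side: PyMySQL refuses to turn nan into a SQL literal ("nan can not be used with MySQL"). One storage-agnostic way to accept a reported nan is to persist a nullable float together with a tag saying whether the value is nan. The sketch below illustrates that idea only; ValueType, encode_intermediate_value, and decode_intermediate_value are hypothetical names and are not taken from Optuna's RDB storage.

```python
import math
from enum import Enum
from typing import Optional, Tuple


class ValueType(Enum):
    # Hypothetical tag that would be stored in its own column next to the value.
    NUMBER = 1  # anything that is not nan, stored as-is
    NAN = 2


def encode_intermediate_value(value: float) -> Tuple[Optional[float], ValueType]:
    """Split a float into a (nullable float, tag) pair for storage."""
    if math.isnan(value):
        # Store NULL instead of nan, so the driver never has to escape nan.
        return None, ValueType.NAN
    return value, ValueType.NUMBER


def decode_intermediate_value(stored: Optional[float], tag: ValueType) -> float:
    """Rebuild the original float from the stored pair."""
    if tag is ValueType.NAN:
        return math.nan
    if stored is None:
        raise ValueError("a NUMBER tag requires a stored value")
    return stored


nan_row = encode_intermediate_value(float("nan"))
restored = decode_intermediate_value(*nan_row)
```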
In addition, running the script nan-tell.py (provided under Reproducible examples) produces the following output. study.tell raises an error for every storage, but the trial is left in the RUNNING state. I think this behavior needs to be modified as described under Expected behavior.
output
(venv) mamu@HideakinoMacBook-puro 3205-3206 % python nan-tell.py
[I 2021-12-24 14:57:53,054] A new study created in memory with name: no-name-0dab10fe-1de5-4ec6-8a17-145f92fdf286
Trial 0 failed, because the objective function returned nan.
In-memory : Trial state is TrialState.RUNNING
[I 2021-12-24 14:57:53,138] A new study created in RDB with name: no-name-7757614d-b4bc-43ba-9489-1e691ae1ad3d
Trial 0 failed, because the objective function returned nan.
SQLite : Trial state is TrialState.RUNNING
[I 2021-12-24 14:57:53,308] A new study created in RDB with name: no-name-a7651b7d-0b60-47e6-96c9-2c5563819437
Trial 0 failed, because the objective function returned nan.
MySQL : Trial state is TrialState.RUNNING
[I 2021-12-24 14:57:53,659] A new study created in RDB with name: no-name-831368be-d69a-4aeb-a00b-76818043fac1
Trial 0 failed, because the objective function returned nan.
PostgreSQL: Trial state is TrialState.RUNNING
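Until study.tell itself handles nan, a caller of the ask-and-tell interface can check the value first and finish the trial explicitly as failed. This is a sketch under assumptions: tell_checking_nan is a hypothetical helper, it assumes the installed Optuna's study.tell accepts a state keyword argument (pass optuna.trial.TrialState.FAIL as fail_state, and verify against the study.tell documentation for your version), and RecordingStudy is a hypothetical stand-in so the snippet runs without Optuna.

```python
import math
from typing import Any


def tell_checking_nan(study: Any, trial: Any, value: float, fail_state: Any) -> None:
    """Finish an ask-and-tell trial, failing it explicitly when value is nan.

    fail_state is meant to be optuna.trial.TrialState.FAIL; it is a parameter
    only so that this sketch does not need to import Optuna.
    """
    if math.isnan(value):
        # Do not pass nan; mark the trial as failed instead of leaving it RUNNING.
        study.tell(trial, state=fail_state)
    else:
        study.tell(trial, value)


class RecordingStudy:
    """Hypothetical stand-in that records how tell() was called."""

    def __init__(self):
        self.calls = []

    def tell(self, trial, values=None, state=None):
        self.calls.append((trial, values, state))


study = RecordingStudy()
tell_checking_nan(study, "trial-0", float("nan"), fail_state="FAIL")
tell_checking_nan(study, "trial-1", 0.25, fail_state="FAIL")
print(study.calls)  # [('trial-0', None, 'FAIL'), ('trial-1', 0.25, None)]
```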
Steps to reproduce
- Install Docker
- Set up a MySQL database
docker run --name mysql -e MYSQL_ROOT_PASSWORD=test -p 3306:3306 -p 33060:33060 -d mysql:8.0.22
docker run --network host -it --rm mysql:8.0.22 mysql -h 127.0.0.1 -uroot -ptest -e "create database mysql;"
- Set up a PostgreSQL database
docker run -it --rm --name postgres-test -e POSTGRES_PASSWORD=test -p 15432:5432 -d postgres
- Run nan-objective.py, nan-report.py, and nan-tell.py.
Reproducible examples (optional)
nan-objective.py
import optuna
import numpy as np
def objective(trial):
    return float(np.nan)
# In-memory storage
study = optuna.create_study()
study.optimize(objective, n_trials=1)
print(f"In-memory : Trial state is {study.trials[-1].state}")
# RDB storage (SQLite)
study = optuna.create_study(storage='sqlite:///../sqlite.db')
study.optimize(objective, n_trials=1)
print(f"SQLite : Trial state is {study.trials[-1].state}")
# RDB storage (MySQL)
study = optuna.create_study(storage="mysql+pymysql://root:test@localhost/mysql")
study.optimize(objective, n_trials=1)
print(f"MySQL : Trial state is {study.trials[-1].state}")
# RDB storage (PostgreSQL)
study = optuna.create_study(storage="postgresql+psycopg2://postgres:test@localhost:15432/postgres")
study.optimize(objective, n_trials=1)
print(f"PostgreSQL: Trial state is {study.trials[-1].state}")
nan-report.py
import optuna
import numpy as np
def objective(trial):
    trial.report(float(np.nan), 1)
    return 1
# In-memory storage
study = optuna.create_study()
study.optimize(objective, n_trials=1)
print(f"In-memory : Trial state is {study.trials[-1].state}")
# RDB storage (SQLite)
study = optuna.create_study(storage='sqlite:///../sqlite.db')
study.optimize(objective, n_trials=1)
print(f"SQLite : Trial state is {study.trials[-1].state}")
# RDB storage (MySQL)
study = optuna.create_study(storage="mysql+pymysql://root:test@localhost/mysql")
try:
    study.optimize(objective, n_trials=1)
except Exception as e:
    print(e)
finally:
    print(f"MySQL : Trial state is {study.trials[-1].state}")
# RDB storage (PostgreSQL)
study = optuna.create_study(storage="postgresql+psycopg2://postgres:test@localhost:15432/postgres")
study.optimize(objective, n_trials=1)
print(f"PostgreSQL: Trial state is {study.trials[-1].state}")
nan-tell.py
import optuna
import numpy as np
# In-memory storage
study = optuna.create_study()
trial = study.ask()
try:
    study.tell(trial, float(np.nan))
except Exception as e:
    print(e)
finally:
    print(f"In-memory : Trial state is {study.trials[-1].state}")
# RDB storage (SQLite)
study = optuna.create_study(storage='sqlite:///../sqlite.db')
trial = study.ask()
try:
    study.tell(trial, float(np.nan))
except Exception as e:
    print(e)
finally:
    print(f"SQLite : Trial state is {study.trials[-1].state}")
# RDB storage (MySQL)
study = optuna.create_study(storage="mysql+pymysql://root:test@localhost/mysql")
trial = study.ask()
try:
    study.tell(trial, float(np.nan))
except Exception as e:
    print(e)
finally:
    print(f"MySQL : Trial state is {study.trials[-1].state}")
# RDB storage (PostgreSQL)
study = optuna.create_study(storage="postgresql+psycopg2://postgres:test@localhost:15432/postgres")
trial = study.ask()
try:
    study.tell(trial, float(np.nan))
except Exception as e:
    print(e)
finally:
    print(f"PostgreSQL: Trial state is {study.trials[-1].state}")