Document Behavior of optuna.pruners.MedianPruner and optuna.pruners.PatientPruner #6055
Conversation
|
@gen740 @nabenabe0928 @sawa3030 Could you review this PR? |
There was a problem hiding this comment.
According to the implementation here, if any intermediate value in a trial is nan, maybe_prune becomes nan, so pruning will not happen.
optuna/optuna/pruners/_patient.py
Lines 96 to 121 in d01faa2
Note that scores_{before, after}_patience will not be empty due to the if-statement here:
optuna/optuna/pruners/_patient.py
Lines 90 to 91 in d01faa2
There was a problem hiding this comment.
In my understanding, if either scores_before_patience or scores_after_patience consists entirely of nan values, maybe_prune becomes False and pruning will not occur.
There was a problem hiding this comment.
@sawa3030 Is there any way we could test your idea?
There was a problem hiding this comment.
Apologies for the delay. I believe we can test this behavior using the following examples.
1. When scores_before_patience is all NaN
If we set patience=5, and report NaN for the first 5 steps (for trials after the initial one), scores_before_patience becomes all NaN. This causes maybe_prune to return False.
import optuna
import numpy as np
import logging
import sys
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def partial_nan_objective(trial):
"""An objective function where some intermediate values are NaN."""
x = trial.suggest_float('x', -10, 10)
trial_number = trial.number
# Report some intermediate values
for step in range(10):
value = x ** 2 + step
# Make the first 5 steps NaN
if step < 5 and trial_number > 1:
logger.info(f"Trial {trial.number}, Step {step}: Reporting NaN")
trial.report(float('nan'), step)
else:
logger.info(f"Trial {trial.number}, Step {step}: Reporting {value}")
trial.report(value, step)
# Check if trial should be pruned
if trial.should_prune():
logger.info(f"Trial {trial.number} pruned at step {step}")
raise optuna.exceptions.TrialPruned()
return x ** 2
def test_patient_pruner_with_nan_values():
"""Test how PatientPruner handles NaN values in different configurations."""
# Create a MedianPruner to be used as base pruner for PatientPruner
median_pruner = optuna.pruners.MedianPruner(n_startup_trials=0, n_warmup_steps=0)
study = optuna.create_study(
pruner=optuna.pruners.PatientPruner(median_pruner, patience=5),
direction="minimize"
)
study.optimize(partial_nan_objective, n_trials=10)
# Print results
logger.info("Completed trials with intermediate values:")
for trial in study.trials:
logger.info(f"Trial {trial.number}: State={trial.state}, Value={trial.value}")
if trial.intermediate_values:
logger.info(f" Intermediate values: {trial.intermediate_values}")
if __name__ == "__main__":
test_patient_pruner_with_nan_values()
2. When scores_after_patience is all NaN
With patience=0 and returning only NaNs from step 1 onward, scores_after_patience becomes all NaN. This causes maybe_prune to return False.
import optuna
import numpy as np
import logging
import sys
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def partial_nan_objective(trial):
"""An objective function where some intermediate values are NaN."""
x = trial.suggest_float('x', -10, 10)
trial_number = trial.number
# Report some intermediate values
for step in range(10):
value = x ** 2 + step
# Make some values NaN based on certain conditions
if step > 0 and trial_number > 1:
logger.info(f"Trial {trial.number}, Step {step}: Reporting NaN")
trial.report(float('nan'), step)
else:
logger.info(f"Trial {trial.number}, Step {step}: Reporting {value}")
trial.report(value, step)
# Check if trial should be pruned
if trial.should_prune():
logger.info(f"Trial {trial.number} pruned at step {step}")
raise optuna.exceptions.TrialPruned()
return x ** 2
def test_patient_pruner_with_nan_values():
"""Test how PatientPruner handles NaN values in different configurations."""
# Create a MedianPruner to be used as base pruner for PatientPruner
median_pruner = optuna.pruners.MedianPruner(n_startup_trials=0, n_warmup_steps=0)
study = optuna.create_study(
pruner=optuna.pruners.PatientPruner(median_pruner, patience=0),
direction="minimize"
)
study.optimize(partial_nan_objective, n_trials=10)
# Print results
logger.info("Completed trials with intermediate values:")
for trial in study.trials:
logger.info(f"Trial {trial.number}: State={trial.state}, Value={trial.value}")
if trial.intermediate_values:
logger.info(f" Intermediate values: {trial.intermediate_values}")
if __name__ == "__main__":
test_patient_pruner_with_nan_values()
|
I have addressed the comments and updated accordingly. Could you please review? cc: @nabenabe0928 |
Co-authored-by: Shuhei Watanabe <47781922+nabenabe0928@users.noreply.github.com>
|
Could you please review the changes? cc: @HideakiImamura |
nabenabe0928
left a comment
There was a problem hiding this comment.
Thank you for the updates, LGTM!
Co-authored-by: Eri Sawada <125344906+sawa3030@users.noreply.github.com>
|
I've updated the documentation and all the checks are passing. Could you please review the changes ? cc: @sawa3030 |
sawa3030
left a comment
There was a problem hiding this comment.
Thank you for the update. LGTM
Motivation
Refs #5202
I have added documentation of how the Median Pruner and Patient Pruner handles Nan values.
It is important to note that we are using synthetic dataset for testing.
Could you please review?
cc: @HideakiImamura
Description of the changes
I have used the following scripts to test it out:
Median Pruner:
Patient Pruner: