Document Behavior of optuna.pruners.MedianPruner and optuna.pruners.PatientPruner #6055

Merged
gen740 merged 9 commits into optuna:master from ParagEkbote:Document-Pruner-Nan-Values-Behaviour
May 9, 2025

@ParagEkbote
Contributor

Motivation

Refs #5202

I have added documentation of how the MedianPruner and PatientPruner handle NaN values.

Note that the tests below use a synthetic objective rather than a real dataset.

Could you please review?

cc: @HideakiImamura

Description of the changes

I have used the following scripts to test it out:

Median Pruner:

import optuna
import numpy as np
import logging
import sys

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def nan_objective(trial):
    """An objective function that returns NaN for every 3rd trial."""
    trial_number = trial.number
    
    # Every 3rd trial returns NaN
    if trial_number % 3 == 0:
        logger.info(f"Trial {trial_number}: Returning NaN")
        return float('nan')
    else:
        value = trial.suggest_float('x', -10, 10)
        result = value ** 2  # Simple quadratic function
        logger.info(f"Trial {trial_number}: Returning {result}")
        return result

def partial_nan_objective(trial):
    """An objective function where some intermediate values are NaN."""
    x = trial.suggest_float('x', -10, 10)
    
    # Report some intermediate values
    for step in range(10):
        value = x ** 2 + step
        
        # Make every 3rd step return NaN
        if step % 3 == 0:
            logger.info(f"Trial {trial.number}, Step {step}: Reporting NaN")
            trial.report(float('nan'), step)
        else:
            logger.info(f"Trial {trial.number}, Step {step}: Reporting {value}")
            trial.report(value, step)
            
        # Check if trial should be pruned
        if trial.should_prune():
            logger.info(f"Trial {trial.number} pruned at step {step}")
            raise optuna.exceptions.TrialPruned()
    
    return x ** 2

def test_median_pruner_with_nan_values():
    """Test how MedianPruner handles NaN values in different configurations."""
    
    # Test 1: Basic study with NaN return values
    logger.info("\n=== Test 1: Basic study with NaN final values ===")
    study1 = optuna.create_study(
        pruner=optuna.pruners.MedianPruner(n_startup_trials=0, n_warmup_steps=0),
        direction="minimize"
    )
    study1.optimize(nan_objective, n_trials=10)
    
    # Print results
    logger.info("Completed trials:")
    for trial in study1.trials:
        logger.info(f"Trial {trial.number}: State={trial.state}, Value={trial.value}")
    
    # Test 2: Study with NaN intermediate values
    logger.info("\n=== Test 2: Study with NaN intermediate values ===")
    study2 = optuna.create_study(
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=2),
        direction="minimize"
    )
    study2.optimize(partial_nan_objective, n_trials=10)
    
    # Print results
    logger.info("Completed trials with intermediate values:")
    for trial in study2.trials:
        logger.info(f"Trial {trial.number}: State={trial.state}, Value={trial.value}")
        if trial.intermediate_values:
            logger.info(f"  Intermediate values: {trial.intermediate_values}")
    
    # Test 3: More comprehensive pruning with NaN values
    logger.info("\n=== Test 3: Comprehensive pruning with NaN values ===")
    study3 = optuna.create_study(
        pruner=optuna.pruners.MedianPruner(n_startup_trials=0, n_warmup_steps=0, interval_steps=1),
        direction="minimize"
    )
    study3.optimize(partial_nan_objective, n_trials=10)
    
    # Print results
    logger.info("Completed trials with aggressive pruning:")
    for trial in study3.trials:
        logger.info(f"Trial {trial.number}: State={trial.state}, Value={trial.value}")
        if trial.intermediate_values:
            logger.info(f"  Intermediate values: {trial.intermediate_values}")

if __name__ == "__main__":
    test_median_pruner_with_nan_values()

Patient Pruner:

import optuna
import numpy as np
import logging
import sys

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def nan_objective(trial):
    """An objective function that returns NaN for every 3rd trial."""
    trial_number = trial.number
    
    # Every 3rd trial returns NaN
    if trial_number % 3 == 0:
        logger.info(f"Trial {trial_number}: Returning NaN")
        return float('nan')
    else:
        value = trial.suggest_float('x', -10, 10)
        result = value ** 2  # Simple quadratic function
        logger.info(f"Trial {trial_number}: Returning {result}")
        return result

def partial_nan_objective(trial):
    """An objective function where some intermediate values are NaN."""
    x = trial.suggest_float('x', -10, 10)
    
    # Report some intermediate values
    for step in range(10):
        value = x ** 2 + step
        
        # Make every 3rd step return NaN
        if step % 3 == 0:
            logger.info(f"Trial {trial.number}, Step {step}: Reporting NaN")
            trial.report(float('nan'), step)
        else:
            logger.info(f"Trial {trial.number}, Step {step}: Reporting {value}")
            trial.report(value, step)
            
        # Check if trial should be pruned
        if trial.should_prune():
            logger.info(f"Trial {trial.number} pruned at step {step}")
            raise optuna.exceptions.TrialPruned()
    
    return x ** 2

def test_patient_pruner_with_nan_values():
    """Test how PatientPruner handles NaN values in different configurations."""
    
    # Create a MedianPruner to be used as base pruner for PatientPruner
    median_pruner = optuna.pruners.MedianPruner(n_startup_trials=0, n_warmup_steps=0)
    
    # Test 1: Basic study with NaN return values
    logger.info("\n=== Test 1: Basic study with NaN final values ===")
    study1 = optuna.create_study(
        pruner=optuna.pruners.PatientPruner(median_pruner, patience=2),
        direction="minimize"
    )
    study1.optimize(nan_objective, n_trials=10)
    
    # Print results
    logger.info("Completed trials:")
    for trial in study1.trials:
        logger.info(f"Trial {trial.number}: State={trial.state}, Value={trial.value}")
    
    # Test 2: Study with NaN intermediate values
    logger.info("\n=== Test 2: Study with NaN intermediate values ===")
    study2 = optuna.create_study(
        pruner=optuna.pruners.PatientPruner(median_pruner, patience=2),
        direction="minimize"
    )
    study2.optimize(partial_nan_objective, n_trials=10)
    
    # Print results
    logger.info("Completed trials with intermediate values:")
    for trial in study2.trials:
        logger.info(f"Trial {trial.number}: State={trial.state}, Value={trial.value}")
        if trial.intermediate_values:
            logger.info(f"  Intermediate values: {trial.intermediate_values}")
    
    # Test 3: Zero patience with NaN values
    logger.info("\n=== Test 3: Zero patience with NaN values ===")
    study3 = optuna.create_study(
        pruner=optuna.pruners.PatientPruner(median_pruner, patience=0),
        direction="minimize"
    )
    study3.optimize(partial_nan_objective, n_trials=10)
    
    # Print results
    logger.info("Completed trials with zero patience:")
    for trial in study3.trials:
        logger.info(f"Trial {trial.number}: State={trial.state}, Value={trial.value}")
        if trial.intermediate_values:
            logger.info(f"  Intermediate values: {trial.intermediate_values}")

if __name__ == "__main__":
    test_patient_pruner_with_nan_values()

@HideakiImamura
Member

@gen740 @nabenabe0928 @sawa3030 Could you review this PR?

Contributor

According to the implementation here, if any intermediate value in a trial is nan, maybe_prune becomes nan, so pruning will not happen.

scores_before_patience = np.asarray(
    list(intermediate_values[step] for step in steps_before_patience)
)
# And these are the scores after that
steps_after_patience = steps[-self._patience - 1 :]
scores_after_patience = np.asarray(
    list(intermediate_values[step] for step in steps_after_patience)
)
direction = study.direction
if direction == StudyDirection.MINIMIZE:
    maybe_prune = np.nanmin(scores_before_patience) + self._min_delta < np.nanmin(
        scores_after_patience
    )
else:
    maybe_prune = np.nanmax(scores_before_patience) - self._min_delta > np.nanmax(
        scores_after_patience
    )
if maybe_prune:
    if self._wrapped_pruner is not None:
        return self._wrapped_pruner.prune(study, trial)
    else:
        return True
else:
    return False

Note that scores_{before, after}_patience will not be empty due to the if-statement here:

if steps.size <= self._patience + 1:
    return False
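
To make the guard concrete, here is a minimal sketch of how the reported steps could be split around the patience window. The helper name `split_steps` is hypothetical, and the `before` slice is an assumption chosen to complement the `steps_after_patience` slice quoted above. It is not copied from Optuna.

```python
import numpy as np


def split_steps(steps, patience):
    """Hypothetical helper: split sorted steps around the patience window."""
    steps = np.sort(np.asarray(steps))
    # Same guard as the quoted if-statement: too few reports, nothing to compare.
    if steps.size <= patience + 1:
        return None
    # Assumed complement of the quoted `steps[-patience - 1 :]` slice.
    before = steps[: -patience - 1]
    after = steps[-patience - 1 :]
    return before, after


print(split_steps(range(6), patience=5))  # None: the guard triggers
print(split_steps(range(7), patience=5))  # one step before, six after
print(split_steps(range(3), patience=0))  # two steps before, one after
```

Because the guard requires more than `patience + 1` reported steps, both slices contain at least one step whenever the comparison runs.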

Collaborator

In my understanding, if either scores_before_patience or scores_after_patience consists entirely of nan values, maybe_prune becomes False and pruning will not occur.
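
This can be checked with NumPy alone. The sketch below mirrors only the MINIMIZE branch of the comparison quoted earlier; `maybe_prune_minimize` is a hypothetical name, not an Optuna function. `np.nanmin` skips NaN entries but returns NaN (with a RuntimeWarning) when every entry is NaN, and any ordering comparison against NaN is False.

```python
import warnings

import numpy as np


def maybe_prune_minimize(before, after, min_delta=0.0):
    """Hypothetical stand-in for the MINIMIZE comparison in PatientPruner."""
    with warnings.catch_warnings():
        # np.nanmin emits a RuntimeWarning when all entries are NaN; silence it here.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        best_before = np.nanmin(np.asarray(before, dtype=float))
        best_after = np.nanmin(np.asarray(after, dtype=float))
    # Any ordering comparison involving NaN evaluates to False.
    return bool(best_before + min_delta < best_after)


nan = float("nan")
print(maybe_prune_minimize([nan, nan], [1.0, 2.0]))  # False: "before" is all NaN
print(maybe_prune_minimize([1.0, 2.0], [nan]))       # False: "after" is all NaN
print(maybe_prune_minimize([1.0, nan], [nan, 2.0]))  # True: NaN entries are skipped
```

A partially NaN window therefore still takes part in the comparison; only a window that is entirely NaN disables pruning.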

Contributor Author

@ParagEkbote ParagEkbote Apr 25, 2025

@sawa3030 Is there any way we could test your idea?

Collaborator

Apologies for the delay. I believe we can test this behavior using the following examples.

1. When scores_before_patience is all NaN

If we set patience=5 and report NaN for the first 5 steps (in every trial with trial.number > 1), scores_before_patience is all NaN whenever the comparison runs, so maybe_prune evaluates to False.

import optuna
import numpy as np
import logging
import sys

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def partial_nan_objective(trial):
    """An objective function where some intermediate values are NaN."""
    x = trial.suggest_float('x', -10, 10)
    trial_number = trial.number
    
    # Report some intermediate values
    for step in range(10):
        value = x ** 2 + step
        
        # Make the first 5 steps NaN
        if step < 5 and trial_number > 1:
            logger.info(f"Trial {trial.number}, Step {step}: Reporting NaN")
            trial.report(float('nan'), step)
        else:
            logger.info(f"Trial {trial.number}, Step {step}: Reporting {value}")
            trial.report(value, step)
            
        # Check if trial should be pruned
        if trial.should_prune():
            logger.info(f"Trial {trial.number} pruned at step {step}")
            raise optuna.exceptions.TrialPruned()
    
    return x ** 2

def test_patient_pruner_with_nan_values():
    """Test how PatientPruner handles NaN values in different configurations."""
    
    # Create a MedianPruner to be used as base pruner for PatientPruner
    median_pruner = optuna.pruners.MedianPruner(n_startup_trials=0, n_warmup_steps=0)
    
    study = optuna.create_study(
        pruner=optuna.pruners.PatientPruner(median_pruner, patience=5),
        direction="minimize"
    )
    study.optimize(partial_nan_objective, n_trials=10)
    
    # Print results
    logger.info("Completed trials with intermediate values:")
    for trial in study.trials:
        logger.info(f"Trial {trial.number}: State={trial.state}, Value={trial.value}")
        if trial.intermediate_values:
            logger.info(f"  Intermediate values: {trial.intermediate_values}")
    
if __name__ == "__main__":
    test_patient_pruner_with_nan_values()

2. When scores_after_patience is all NaN

With patience=0 and only NaN reported from step 1 onward (again in every trial with trial.number > 1), scores_after_patience is all NaN, so maybe_prune evaluates to False.

import optuna
import numpy as np
import logging
import sys

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def partial_nan_objective(trial):
    """An objective function where some intermediate values are NaN."""
    x = trial.suggest_float('x', -10, 10)
    trial_number = trial.number
    
    # Report some intermediate values
    for step in range(10):
        value = x ** 2 + step
        
        # Report NaN from step 1 onward (for trials with trial.number > 1)
        if step > 0 and trial_number > 1:
            logger.info(f"Trial {trial.number}, Step {step}: Reporting NaN")
            trial.report(float('nan'), step)
        else:
            logger.info(f"Trial {trial.number}, Step {step}: Reporting {value}")
            trial.report(value, step)
            
        # Check if trial should be pruned
        if trial.should_prune():
            logger.info(f"Trial {trial.number} pruned at step {step}")
            raise optuna.exceptions.TrialPruned()
    
    return x ** 2

def test_patient_pruner_with_nan_values():
    """Test how PatientPruner handles NaN values in different configurations."""
    
    # Create a MedianPruner to be used as base pruner for PatientPruner
    median_pruner = optuna.pruners.MedianPruner(n_startup_trials=0, n_warmup_steps=0)

    study = optuna.create_study(
        pruner=optuna.pruners.PatientPruner(median_pruner, patience=0),
        direction="minimize"
    )
    study.optimize(partial_nan_objective, n_trials=10)
    
    # Print results
    logger.info("Completed trials with intermediate values:")
    for trial in study.trials:
        logger.info(f"Trial {trial.number}: State={trial.state}, Value={trial.value}")
        if trial.intermediate_values:
            logger.info(f"  Intermediate values: {trial.intermediate_values}")
    
if __name__ == "__main__":
    test_patient_pruner_with_nan_values()

@nabenabe0928 nabenabe0928 added the document Documentation related. label Apr 23, 2025
@ParagEkbote
Contributor Author

I have addressed the comments and updated accordingly. Could you please review?

cc: @nabenabe0928

Co-authored-by: Shuhei Watanabe <47781922+nabenabe0928@users.noreply.github.com>
@ParagEkbote
Contributor Author

Could you please review the changes?

cc: @HideakiImamura

@ParagEkbote ParagEkbote requested a review from nabenabe0928 May 1, 2025 17:16
Contributor

@nabenabe0928 nabenabe0928 left a comment

Thank you for the updates, LGTM!

@nabenabe0928 nabenabe0928 removed their assignment May 2, 2025
ParagEkbote and others added 3 commits May 2, 2025 18:11
Co-authored-by: Eri Sawada <125344906+sawa3030@users.noreply.github.com>
@ParagEkbote
Contributor Author

ParagEkbote commented May 2, 2025

I've updated the documentation and all the checks are passing.

Could you please review the changes?

cc: @sawa3030

Collaborator

@sawa3030 sawa3030 left a comment

Thank you for the update. LGTM

Member

@gen740 gen740 left a comment

LGTM!

@gen740 gen740 added this to the v4.4.0 milestone May 9, 2025
@gen740 gen740 merged commit 2832e92 into optuna:master May 9, 2025
14 checks passed
@ParagEkbote ParagEkbote deleted the Document-Pruner-Nan-Values-Behaviour branch May 10, 2025 18:12
@gen740 gen740 removed their assignment May 30, 2025