Refactor acquisition function minimally #6166
Conversation
@sawa3030 @contramundum53
Pull Request Overview
A refactor to introduce a class-based abstraction for acquisition functions in GPSampler and remove the old dataclass/API.
- Replace the `AcquisitionFunctionParams` and free-function API with `BaseAcquisitionFunc` subclasses (`LogEI`, `LogPI`, `UCB`, `LCB`, `ConstrainedLogEI`, `LogEHVI`).
- Update all samplers, terminators, tests, and optimization helpers to use the new class-based API.
- Extend `GPRegressor` with internal caching (`_cache_matrix`) and simplify its `kernel`/`posterior` interfaces.
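The new interface might be sketched roughly as follows — a minimal, hypothetical rendering for orientation only (the method name `eval_no_grad` and the callable-posterior stand-in are assumptions, not the exact Optuna API):

```python
from abc import ABC, abstractmethod

import numpy as np


class BaseAcquisitionFunc(ABC):
    """Minimal sketch of a class-based acquisition-function interface."""

    @abstractmethod
    def eval_no_grad(self, x: np.ndarray) -> np.ndarray:
        """Evaluate the acquisition value without gradients."""


class UCB(BaseAcquisitionFunc):
    # Hypothetical posterior: a callable x -> (mean, var); the real class
    # would wrap a fitted GPRegressor instead.
    def __init__(self, posterior, beta: float) -> None:
        self._posterior = posterior
        self._beta = beta

    def eval_no_grad(self, x: np.ndarray) -> np.ndarray:
        mean, var = self._posterior(x)
        return mean + np.sqrt(self._beta * var)


# Usage with a dummy constant posterior: mean 0, variance 1 everywhere.
acqf = UCB(lambda x: (np.zeros_like(x), np.ones_like(x)), beta=4.0)
print(acqf.eval_no_grad(np.array([0.0, 1.0])))  # -> [2. 2.]
```

The design point is that each acquisition function carries its own state (posterior, hyperparameters) instead of threading a params dataclass through free functions.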
Reviewed Changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| tests/samplers_tests/test_gp.py | Updated import and use of LogEI class. |
| tests/gp_tests/test_gp.py | Adjusted GPRegressor instantiation to new signature. |
| tests/gp_tests/test_acqf.py | Switched tests to the new BaseAcquisitionFunc API. |
| optuna/terminator/improvement/evaluator.py | Replaced create_acqf_params with class-based calls. |
| optuna/terminator/improvement/emmr.py | Removed old acqf_params code; use gpr.posterior. |
| optuna/samplers/_gp/sampler.py | Refactored sampler methods to accept BaseAcquisitionFunc. |
| optuna/_gp/optim_sample.py | Deleted obsolete optimize_acqf_sample. |
| optuna/_gp/optim_mixed.py | Updated helpers to use BaseAcquisitionFunc. |
| optuna/_gp/gp.py | Added caching, revised kernel & posterior API. |
| optuna/_gp/acqf.py | Implemented BaseAcquisitionFunc and subclasses. |
Comments suppressed due to low confidence (2)
optuna/samplers/_gp/sampler.py:232
- Newly constructed `GPRegressor` instances in `constraints_gprs` lack a call to `_cache_matrix()`, so their `posterior` will assert on missing cached matrices. Call `gpr._cache_matrix()` right after instantiation before appending.
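The suggested fix follows a "construct, then warm the cache" pattern. A standalone toy sketch of that pattern (the `GPRegressor`/`_cache_matrix`/`constraints_gprs` names mirror the PR, but the method bodies here are illustrative stand-ins, not Optuna's implementation):

```python
import numpy as np


class GPRegressor:
    """Toy stand-in: posterior requires a precomputed matrix cache."""

    def __init__(self, X: np.ndarray, y: np.ndarray) -> None:
        self._X, self._y = X, y
        self._cov_inv_y = None  # filled in by _cache_matrix()

    def _cache_matrix(self) -> None:
        # Cache C^{-1} y once so repeated posterior calls stay cheap.
        cov = self._X @ self._X.T + 1e-6 * np.eye(len(self._X))
        self._cov_inv_y = np.linalg.solve(cov, self._y)

    def posterior_mean(self, x: np.ndarray) -> float:
        assert self._cov_inv_y is not None, "Call _cache_matrix before posterior."
        return float((x @ self._X.T) @ self._cov_inv_y)


constraints_gprs = []
for _ in range(2):
    gpr = GPRegressor(np.eye(3), np.arange(3.0))
    gpr._cache_matrix()  # warm the cache right after instantiation
    constraints_gprs.append(gpr)
```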
optuna/_gp/acqf.py:231
- The `stabilizing_noise` annotation appears in the function body instead of the `__init__` signature, causing a syntax error. Move it into the parameter list and remove the misplaced closing parenthesis.

```python
stabilizing_noise: float = 1e-12,
```
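For reference, a hedged sketch of the corrected shape — the class name and everything except the position of the `stabilizing_noise` annotation are placeholders:

```python
class ConstrainedLogEI:  # placeholder name; other parameters elided
    def __init__(
        self,
        stabilizing_noise: float = 1e-12,  # annotation lives in the signature
    ) -> None:
        self._stabilizing_noise = stabilizing_noise
```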
```python
import optuna


def objective(trial: optuna.Trial) -> float:
    x = trial.suggest_float("x", -5, 5)
    y = trial.suggest_float("y", -5, 5)
    return x**2 + y**2


def multi_objective(trial: optuna.Trial) -> tuple[float, float]:
    x = trial.suggest_float("x", -5, 5)
    y = trial.suggest_float("y", -5, 5)
    return x**2 + y**2, (x - 2)**2 + (y - 2)**2


def constraints(trial: optuna.trial.FrozenTrial) -> tuple[float, float]:
    x = trial.params["x"]
    y = trial.params["y"]
    return (x - 2, y - 2)


mode = ["single", "multi", "constr"][2]
if mode == "single":
    study = optuna.create_study(sampler=optuna.samplers.GPSampler(seed=0))
    obj_func = objective
elif mode == "multi":
    study = optuna.create_study(
        sampler=optuna.samplers.GPSampler(seed=0), directions=["minimize"] * 2
    )
    obj_func = multi_objective
elif mode == "constr":
    study = optuna.create_study(
        sampler=optuna.samplers.GPSampler(seed=0, constraints_func=constraints)
    )
    obj_func = objective

study.optimize(obj_func, n_trials=20)
trials = study.trials
print((trials[-1].datetime_complete - trials[0].datetime_start).total_seconds())
```

The benchmarking code to check the reproducibility.
```python
def logpi(mean: torch.Tensor, var: torch.Tensor, f0: float) -> torch.Tensor:
    # Return the integral of N(mean, var) from -inf to f0.
    # This is identical to the integral of N(0, 1) from -inf to (f0 - mean) / sigma.
    # Return E_{y ~ N(mean, var)}[bool(y <= f0)].
    sigma = torch.sqrt(var)
    return torch.special.log_ndtr((f0 - mean) / sigma)


def ucb(mean: torch.Tensor, var: torch.Tensor, beta: float) -> torch.Tensor:
    return mean + torch.sqrt(beta * var)


def lcb(mean: torch.Tensor, var: torch.Tensor, beta: float) -> torch.Tensor:
    return mean - torch.sqrt(beta * var)
```
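A quick numeric sanity check of the identity in those comments, using only the Python standard library (the `erfc`-based `log_ndtr` below is a naive stand-in for `torch.special.log_ndtr`; unlike the real one, it is not log-stable in the far tails):

```python
import math


def log_ndtr(z: float) -> float:
    # log of the standard normal CDF, via the complementary error function:
    # ndtr(z) = 0.5 * erfc(-z / sqrt(2))
    return math.log(0.5 * math.erfc(-z / math.sqrt(2.0)))


def logpi(mean: float, var: float, f0: float) -> float:
    # log P[y <= f0] for y ~ N(mean, var), mirroring the torch version above.
    sigma = math.sqrt(var)
    return log_ndtr((f0 - mean) / sigma)


# Integrating N(0, 1) up to its mean gives exactly 1/2.
assert math.isclose(logpi(0.0, 1.0, 0.0), math.log(0.5))
```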
[note] Reviewed this part
sawa3030 left a comment
These are my notes to help with understanding and to keep track of what to review.
```python
# TODO(contramundum53): consider abstraction for acquisition functions.
# NOTE: Acquisition function is not class on purpose to integrate numba in the future.
class AcquisitionFunctionType(IntEnum):
    LOG_EI = 0
    UCB = 1
    LCB = 2
    LOG_PI = 3
    LOG_EHVI = 4
```
[note]: Reviewed this part
```python
assert (
    self._cov_Y_Y_inv is not None and self._cov_Y_Y_inv_Y is not None
), "Call cache_matrix before calling posterior."
cov_fx_fX = self.kernel(x[..., None, :], self._X_train)[..., 0, :]
```
[TODO]: Check if X equals self._X_train
```diff
  # We apply the cholesky decomposition to efficiently compute log(|C|) and C^-1.
- cov_fX_fX = self.kernel(is_categorical, X, X)
+ cov_fX_fX = self.kernel(self._X_train, self._X_train)
```
[TODO]: Check if X equals self._X_train
```python
    :, 0
]
cov_Y_Y_chol_inv_Y = torch.linalg.solve_triangular(
    cov_Y_Y_chol, self._y_train[:, None], upper=False
```
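The Cholesky-based pattern quoted above (factor once, then triangular solves instead of an explicit inverse) can be checked in isolation; a small PyTorch sketch of the same idea (variable names here are illustrative, not the PR's):

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64)
C = A @ A.T + 4 * torch.eye(4, dtype=torch.float64)  # SPD "covariance"
y = torch.randn(4, dtype=torch.float64)

L = torch.linalg.cholesky(C)  # C = L @ L.T with L lower triangular
# log|C| from the factor: log|C| = 2 * sum(log(diag(L)))
logdet = 2.0 * torch.log(torch.diagonal(L)).sum()
# C^{-1} y via two triangular solves -- no explicit inverse is formed.
z = torch.linalg.solve_triangular(L, y[:, None], upper=False)
x = torch.linalg.solve_triangular(L.T, z, upper=True)[:, 0]

assert torch.allclose(logdet, torch.logdet(C))
assert torch.allclose(x, torch.linalg.solve(C, y))
```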
[TODO]: Check if Y equals self._y_train
```python
    beta=acqf_params.beta,
    acqf_stabilizing_noise=acqf_params.acqf_stabilizing_noise,
    acqf_params_for_constraints=acqf_params_for_constraints,
class BaseAcquisitionFunc(ABC):
```
[TODO]: Verify that each member function in every AcquisitionFunction subclass corresponds to the appropriate function in the previous implementation.
Checked LogEI, UCB, LCB, and LogEHVI
Checked ConstrainedLogEI and LogPI
```python
    return f_val


def eval_acqf_no_grad(acqf_params: AcquisitionFunctionParams, x: np.ndarray) -> np.ndarray:
```
[TODO]: Review the implementation of eval_acqf_no_grad in this PR.
```python
    return eval_acqf(acqf_params, torch.from_numpy(x)).detach().numpy()


def eval_acqf_with_grad(
```
[TODO]: Review the implementation of eval_acqf_with_grad in this PR.
note: A slight difference in results between the master branch and this PR was observed during multi-objective optimization. The discrepancy stems from minor numerical differences between NumPy and PyTorch; it is negligible and can be safely ignored in this context.
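This kind of framework-level discrepancy is easy to reproduce independently; a small illustration (unrelated to the PR code) comparing a float64 NumPy reduction with a float32 PyTorch one:

```python
import numpy as np
import torch

x = np.linspace(0.0, 1.0, 1001)
np_sum = float(np.exp(x).sum())  # accumulated in float64
torch_sum = float(torch.exp(torch.tensor(x, dtype=torch.float32)).sum())  # float32

rel_err = abs(np_sum - torch_sum) / np_sum
assert rel_err < 1e-5  # only float32-rounding-level disagreement
```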
sawa3030 left a comment
Thank you for your support in reviewing this PR. LGTM
@sawa3030 @contramundum53
Motivation
This PR introduces the abstraction of the acquisition function for `GPSampler`.
Description of the changes