Merged
Conversation
Contributor
Author
|
@kAIto47802 @sawa3030 Could you review this PR? |
Contributor
Author
not522
approved these changes
Aug 5, 2025
Member
not522
left a comment
There was a problem hiding this comment.
LGTM! I confirmed the speed and the precision are improved.
====
import optuna
def objective(trial):
x = trial.suggest_float("x", -10, 10)
y = trial.suggest_float("y", -10, 10)
z = trial.suggest_float("z", -10, 10, step=0.5)
return x ** 2 + y ** 2 + z ** 2
sampler = optuna.samplers.TPESampler(seed=42)
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=2000)master: 7.81s
PR: 7.49s
====
import math
import numpy as np
import matplotlib.pyplot as plt
import mpmath
import optuna
_norm_pdf_C = math.sqrt(2 * math.pi)
_norm_pdf_logC = math.log(_norm_pdf_C)
_ndtri_exp_approx_C = math.sqrt(3) / math.pi
_log_2 = math.log(2)
def _ndtri_exp(y: np.ndarray) -> np.ndarray:
# Flip the sign of y close to zero for better numerical stability and flip back the sign later.
flipped = y > -1e-2
z = y.copy()
z[flipped] = np.log(-np.expm1(y[flipped]))
x = np.empty_like(y)
if (small_inds := np.nonzero(z < -5))[0].size:
x[small_inds] = -np.sqrt(-2.0 * (z[small_inds] + _norm_pdf_logC))
if (moderate_inds := np.nonzero(z >= -5))[0].size:
x[moderate_inds] = -_ndtri_exp_approx_C * np.log(np.expm1(-z[moderate_inds]))
for _ in range(100):
log_ndtr_x = optuna.samplers._tpe._truncnorm._log_ndtr(x)
log_norm_pdf_x = -0.5 * x**2 - _norm_pdf_logC
# NOTE(nabenabe): Use exp(log_ndtr_x - log_norm_pdf_x) instead of ndtr_x / norm_pdf_x for
# numerical stability.
dx = (log_ndtr_x - z) * np.exp(log_ndtr_x - log_norm_pdf_x)
x -= dx
if np.all(np.abs(dx) < 1e-8 * np.abs(x)): # NOTE: rtol controls the precision.
# Equivalent to np.isclose with atol=0.0 and rtol=1e-8.
break
x[flipped] *= -1
# NOTE(nabe): x[y == 0.0] = np.inf, x[np.isneginf(y)] = -np.inf are necessary for the accurate
# computation, but we omit them as the ppf applies clipping, removing the need for them.
return x
def _ndtri_exp_single_mp(y):
a = -1e9
b = +1e9
for _ in range(1000):
m = (a + b) / 2
if mpmath.log(mpmath.ncdf(m)) < y:
a = m
else:
b = m
return (a + b) / 2
mpmath.mp.dps = 100
y = np.asarray([-(10**i) for i in np.arange(-50, 10, 0.1)])
x_master = np.asarray([optuna.samplers._tpe._truncnorm._ndtri_exp_single(yi) for yi in y])
x_pr = _ndtri_exp(y)
x_mp = np.asarray([_ndtri_exp_single_mp(yi) for yi in y])
plt.plot(-y, x_master - x_mp, label="master")
plt.plot(-y, x_pr - x_mp, label="PR")
plt.xscale("log")
plt.legend()
plt.savefig("6229.png")
Contributor
|
This pull request has not seen any recent activity. |
kAIto47802
approved these changes
Aug 13, 2025
Collaborator
kAIto47802
left a comment
There was a problem hiding this comment.
Thank you for the PR! I followed the equation transformations and their implementation, confirming the correctness.
One minor suggestion is to use np.flatnonzero() instead of np.nonzero()[0].
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
This PR vectorizes
ndtri_expfor the future speedup.In principle, we can further speed up
TPESamplerby vectorizing_truncnorm.rvs.To do so, we need to vectorize
ndtri_exp.Another change includes the enhancement in the numerical stability for a large
y.Description of the changes
mathwithnumpyinndtri_exp-xfirst ifxis positive, and then flip the sign later