Skip to content

Vectorize ndtri_exp#6229

Merged
nabenabe0928 merged 17 commits intooptuna:masterfrom
nabenabe0928:enhance/vectorize-ndtri-exp
Aug 13, 2025
Merged

Vectorize ndtri_exp#6229
nabenabe0928 merged 17 commits intooptuna:masterfrom
nabenabe0928:enhance/vectorize-ndtri-exp

Conversation

@nabenabe0928
Copy link
Copy Markdown
Contributor

Motivation

This PR vectorizes ndtri_exp for the future speedup.
In principle, we can further speed up TPESampler by vectorizing _truncnorm.rvs.
To do so, we need to vectorize ndtri_exp.
Another change includes the enhancement in the numerical stability for a large y.

Description of the changes

  • Replace math with numpy in ndtri_exp
  • Calculate -x first if x is positive, and then flip the sign later

@nabenabe0928 nabenabe0928 added the enhancement Change that does not break compatibility and not affect public interfaces, but improves performance. label Aug 1, 2025
@nabenabe0928
Copy link
Copy Markdown
Contributor Author

@kAIto47802 @sawa3030 Could you review this PR?

@nabenabe0928 nabenabe0928 assigned not522 and unassigned sawa3030 Aug 4, 2025
@nabenabe0928
Copy link
Copy Markdown
Contributor Author

Let me re-assign from @sawa3030 to @not522 !

Copy link
Copy Markdown
Member

@not522 not522 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! I confirmed the speed and the precision are improved.

====

import optuna

def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    y = trial.suggest_float("y", -10, 10)
    z = trial.suggest_float("z", -10, 10, step=0.5)
    return x ** 2 + y ** 2 + z ** 2

sampler = optuna.samplers.TPESampler(seed=42)
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=2000)

master: 7.81s
PR: 7.49s

====

import math
import numpy as np
import matplotlib.pyplot as plt
import mpmath
import optuna


_norm_pdf_C = math.sqrt(2 * math.pi)
_norm_pdf_logC = math.log(_norm_pdf_C)
_ndtri_exp_approx_C = math.sqrt(3) / math.pi
_log_2 = math.log(2)


def _ndtri_exp(y: np.ndarray) -> np.ndarray:
    # Flip the sign of y close to zero for better numerical stability and flip back the sign later.
    flipped = y > -1e-2
    z = y.copy()
    z[flipped] = np.log(-np.expm1(y[flipped]))
    x = np.empty_like(y)
    if (small_inds := np.nonzero(z < -5))[0].size:
        x[small_inds] = -np.sqrt(-2.0 * (z[small_inds] + _norm_pdf_logC))
    if (moderate_inds := np.nonzero(z >= -5))[0].size:
        x[moderate_inds] = -_ndtri_exp_approx_C * np.log(np.expm1(-z[moderate_inds]))

    for _ in range(100):
        log_ndtr_x = optuna.samplers._tpe._truncnorm._log_ndtr(x)
        log_norm_pdf_x = -0.5 * x**2 - _norm_pdf_logC
        # NOTE(nabenabe): Use exp(log_ndtr_x - log_norm_pdf_x) instead of ndtr_x / norm_pdf_x for
        # numerical stability.
        dx = (log_ndtr_x - z) * np.exp(log_ndtr_x - log_norm_pdf_x)
        x -= dx
        if np.all(np.abs(dx) < 1e-8 * np.abs(x)):  # NOTE: rtol controls the precision.
            # Equivalent to np.isclose with atol=0.0 and rtol=1e-8.
            break
    x[flipped] *= -1
    # NOTE(nabe): x[y == 0.0] = np.inf, x[np.isneginf(y)] = -np.inf are necessary for the accurate
    # computation, but we omit them as the ppf applies clipping, removing the need for them.
    return x


def _ndtri_exp_single_mp(y):
    a = -1e9
    b = +1e9
    for _ in range(1000):
        m = (a + b) / 2
        if mpmath.log(mpmath.ncdf(m)) < y:
            a = m
        else:
            b = m
    return (a + b) / 2

mpmath.mp.dps = 100

y = np.asarray([-(10**i) for i in np.arange(-50, 10, 0.1)])
x_master = np.asarray([optuna.samplers._tpe._truncnorm._ndtri_exp_single(yi) for yi in y])
x_pr = _ndtri_exp(y)
x_mp = np.asarray([_ndtri_exp_single_mp(yi) for yi in y])
plt.plot(-y, x_master - x_mp, label="master")
plt.plot(-y, x_pr - x_mp, label="PR")
plt.xscale("log")
plt.legend()
plt.savefig("6229.png")
6229

@not522 not522 removed their assignment Aug 5, 2025
@github-actions
Copy link
Copy Markdown
Contributor

This pull request has not seen any recent activity.

@github-actions github-actions bot added the stale Exempt from stale bot labeling. label Aug 12, 2025
@nabenabe0928 nabenabe0928 removed the stale Exempt from stale bot labeling. label Aug 12, 2025
Copy link
Copy Markdown
Collaborator

@kAIto47802 kAIto47802 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR! I followed the equation transformations and their implementation, confirming the correctness.

One minor suggestion is to use np.flatnonzero() instead of np.nonzero()[0].

@nabenabe0928 nabenabe0928 merged commit e6b2e42 into optuna:master Aug 13, 2025
14 checks passed
@nabenabe0928 nabenabe0928 added this to the v4.5.0 milestone Aug 13, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement Change that does not break compatibility and not affect public interfaces, but improves performance.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants