
Use the Newton method instead of bisect in ndtri_exp #6194

Merged
not522 merged 6 commits into optuna:master from nabenabe0928:enhance/use-newton-not-bisect-in-ndtri-exp
Jul 8, 2025

Contributor

@nabenabe0928 nabenabe0928 commented Jul 4, 2025

Motivation

Since ndtri_exp is one of the bottlenecks in TPESampler, I sped up the ndtri_exp implementation.
ndtri_exp(y) essentially finds the root of f(x) = log_ndtr(x) - y = 0.
Our current implementation uses binary search, which is much slower than the Newton method, so this PR replaces the binary search with the Newton method.
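
To make the comparison concrete, here is a minimal sketch of both root finders for f(x) = log_ndtr(x) - y. It is not the Optuna implementation: `log_ndtr`, `solve_bisect`, and `solve_newton` are illustrative names, the bracket [-8, 8] and the tolerance are arbitrary choices, and the naive `math.erfc`-based `log_ndtr` is only adequate for moderate x. Since f'(x) = norm_pdf(x) / ndtr(x), the Newton step is dx = (log_ndtr(x) - y) * ndtr(x) / norm_pdf(x).

```python
import math

LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)


def log_ndtr(x):
    # Naive log of the standard normal CDF (assumption: moderate x only).
    return math.log(0.5 * math.erfc(-x / math.sqrt(2.0)))


def solve_bisect(y, lo=-8.0, hi=8.0, tol=1e-10):
    # Halve the bracket [lo, hi] until it is narrower than tol.
    calls = 0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        calls += 1
        if log_ndtr(mid) < y:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi), calls


def solve_newton(y, x=0.0, tol=1e-10, max_iter=100):
    # Newton update: x <- x - f(x) / f'(x) with f'(x) = norm_pdf(x) / ndtr(x).
    calls = 0
    for _ in range(max_iter):
        log_cdf = log_ndtr(x)
        calls += 1
        log_pdf = -0.5 * x * x - LOG_SQRT_2PI
        # Work in log space to avoid dividing two tiny numbers.
        dx = (log_cdf - y) * math.exp(log_cdf - log_pdf)
        x -= dx
        if abs(dx) < tol:
            break
    return x, calls


y = log_ndtr(1.0)  # The root is x = 1 by construction.
x_bisect, n_bisect = solve_bisect(y)
x_newton, n_newton = solve_newton(y)
print(f"bisect: {x_bisect=:.10f} after {n_bisect} calls")
print(f"newton: {x_newton=:.10f} after {n_newton} calls")
```

Each bisection step gains one bit of precision, while Newton's method roughly doubles the number of correct digits per step near the root, which is why it needs far fewer `log_ndtr` evaluations here.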

Description of the changes

  • Replace the binary search with the Newton method
  • Introduce a good initial guess for the Newton method
  • Describe the algorithms in the docstring

The landscape of the initial guess is available below:

[Figure: overview]

[Figure: zoom]
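
For readers without the figures, the following sketch evaluates a three-regime initial guess at a few values of y and compares it with a reference inverse. It assumes the regime boundaries (-1e-2 and -5) and formulas used in the reviewer's reproduction script later in this thread; `initial_guess` is a hypothetical helper name, and `statistics.NormalDist().inv_cdf` is used here only as a convenient reference, not as part of the PR.

```python
import math
from statistics import NormalDist

LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)
LOGISTIC_SCALE = math.sqrt(3.0) / math.pi


def initial_guess(y):
    # Hypothetical helper mirroring the three regimes; requires y < 0.
    if y > -1e-2:
        # abs(y) << 1: upper-tail asymptotics, log(-y) ~= -x**2 / 2,
        # refined with a second-order correction term.
        u = -2.0 * math.log(-y)
        return math.sqrt(u - math.log(u))
    if y < -5.0:
        # abs(y) >> 1: lower-tail asymptotics with the log(-x) term dropped.
        return -math.sqrt(-2.0 * (y + LOG_SQRT_2PI))
    # Moderate y: invert the logistic approximation
    # ndtr(x) ~= 1 / (1 + exp(-x / LOGISTIC_SCALE)).
    return -LOGISTIC_SCALE * math.log(math.exp(-y) - 1.0)


errors = {}
for y in [-1e-6, -1e-3, -0.05, -1.0, -4.5, -5.5, -50.0]:
    # Reference x satisfying log_ndtr(x) = y.
    reference = NormalDist().inv_cdf(math.exp(y))
    guess = initial_guess(y)
    errors[y] = abs(guess - reference)
    print(f"{y=:g}: guess={guess:+.4f}, reference={reference:+.4f}")
```

The guess only needs to land close enough for Newton's method to converge quickly, so a rough closed-form expression per regime is sufficient.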

Benchmarking Results

Important

27% speedup 🎉

Note

With our initial guess, the Newton method needs 28% fewer iterations (i.e., log_ndtr_single calls) than when it starts from x=0 😄
Compared to the binary search, the reduction is 92% 😎

| This PR | Master |
|:-:|:-:|
| 6.04 $\pm$ 0.117 | 8.27 $\pm$ 0.057 |
Code
import time
import optuna


optuna.logging.set_verbosity(optuna.logging.CRITICAL)

for seed in range(10):
    print(f"Start with {seed=}")
    # NOTE: The sampler seed is fixed to 42, so `seed` only indexes the repetition.
    sampler = optuna.samplers.TPESampler(seed=42)
    study = optuna.create_study(sampler=sampler)
    start = time.time()
    study.optimize(lambda t: sum(t.suggest_float(f"x{i}", -5, 5)**2 for i in range(10)), n_trials=500)
    print(time.time() - start)
Results by Master
Start with seed=0
7.831433057785034
Start with seed=1
8.079424858093262
Start with seed=2
8.192095518112183
Start with seed=3
8.294391393661499
Start with seed=4
8.307274580001831
Start with seed=5
8.376489162445068
Start with seed=6
8.367579698562622
Start with seed=7
8.417609453201294
Start with seed=8
8.40910816192627
Start with seed=9
8.42704153060913
Results by this PR
Start with seed=0
5.634321689605713
Start with seed=1
5.7635557651519775
Start with seed=2
5.904412508010864
Start with seed=3
6.616066932678223
Start with seed=4
6.8322083950042725
Start with seed=5
6.148069143295288
Start with seed=6
6.042054891586304
Start with seed=7
5.861287593841553
Start with seed=8
5.825707912445068
Start with seed=9
5.786619663238525
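
The summary figures can be rechecked from the two logs above. The sketch below is an assumption about how they were derived: the "±" term is taken to be the population standard deviation divided by sqrt(n), which the PR does not state, and the timings are rounded to three decimals.

```python
import math
import statistics

# Timings in seconds, rounded to three decimals from the logs above.
master = [7.831, 8.079, 8.192, 8.294, 8.307, 8.376, 8.368, 8.418, 8.409, 8.427]
this_pr = [5.634, 5.764, 5.904, 6.616, 6.832, 6.148, 6.042, 5.861, 5.826, 5.787]


def mean_and_stderr(samples):
    # Assumption: "±" is the population standard deviation divided by sqrt(n).
    mean = statistics.fmean(samples)
    stderr = statistics.pstdev(samples) / math.sqrt(len(samples))
    return mean, stderr


mean_master, se_master = mean_and_stderr(master)
mean_pr, se_pr = mean_and_stderr(this_pr)
speedup = 1.0 - mean_pr / mean_master  # Relative reduction in wall-clock time.

print(f"master : {mean_master:.2f} ± {se_master:.3f}")
print(f"this PR: {mean_pr:.2f} ± {se_pr:.3f}")
print(f"runtime reduction: {speedup:.0%}")
```

Under that assumption, the script reproduces the table entries and the 27% figure, which is a reduction in runtime relative to master.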

@nabenabe0928 nabenabe0928 added the enhancement Change that does not break compatibility and not affect public interfaces, but improves performance. label Jul 4, 2025
@y0z
Member

y0z commented Jul 4, 2025

@not522 Could you review this PR?

@nabenabe0928
Contributor Author

This PR passes the following test:

import math
import sys

from optuna.samplers._tpe._truncnorm import _ndtri_exp_single
from scipy.special import ndtri_exp


EPS = sys.float_info.min
for y in [-EPS] + [-10 ** i for i in range(-300, 10)]:
    x = _ndtri_exp_single(y)
    ans = ndtri_exp(y).item()
    diff = abs(x - ans)
    assert math.isclose(x, ans), f"{x=}, {ans=}"
    print(f"{diff=:.2e}, {y=}, {x=}, {ans=}")

Member

@contramundum53 contramundum53 left a comment


LGTM.

--> x = -sqrt(-2 * (y + 1/2 * log(2pi)))

For the moderate y, we use Eq. (13), i.e., standard logistic CDF, in the following paper:
- Approximating the Cumulative Distribution Function of the Normal Distribution.
Member


Could you change it to a standard citation format?

Contributor Author


@not522 Thank you for bringing this up! I applied your suggestion!

@nabenabe0928 nabenabe0928 added this to the v4.5.0 milestone Jul 7, 2025
Co-authored-by: Naoto Mizuno <naotomizuno@preferred.jp>
Member

@not522 not522 left a comment


LGTM!
I evaluated the error using the following code.

Details
import math
import sys
import numpy as np
import scipy.special
import mpmath
import matplotlib.pyplot as plt


_norm_pdf_C = math.sqrt(2 * math.pi)
_norm_pdf_logC = math.log(_norm_pdf_C)
_ndtri_exp_approx_C = math.sqrt(3) / math.pi


def _ndtr_single(a):
    x = a / 2**0.5

    if x < -1 / 2**0.5:
        y = 0.5 * math.erfc(-x)
    elif x < 1 / 2**0.5:
        y = 0.5 + 0.5 * math.erf(x)
    else:
        y = 1.0 - 0.5 * math.erfc(x)

    return y


def _log_ndtr_single(a):
    if a > 6:
        return -_ndtr_single(-a)
    if a > -20:
        return math.log(_ndtr_single(a))

    log_LHS = -0.5 * a**2 - math.log(-a) - 0.5 * math.log(2 * math.pi)
    last_total = 0.0
    right_hand_side = 1.0
    numerator = 1.0
    denom_factor = 1.0
    denom_cons = 1 / a**2
    sign = 1
    i = 0

    while abs(last_total - right_hand_side) > sys.float_info.epsilon:
        i += 1
        last_total = right_hand_side
        sign = -sign
        denom_factor *= denom_cons
        numerator *= 2 * i - 1
        right_hand_side += sign * numerator * denom_factor

    return log_LHS + math.log(right_hand_side)


def _bisect(f, a, b, c):
    if f(a) > c:
        a, b = b, a
    # In the algorithm, it is assumed that all of (a + b), (a * 2), and (b * 2) are finite.
    for _ in range(100):
        m = (a + b) / 2
        if a == m or b == m:
            return m
        if f(m) < c:
            a = m
        else:
            b = m
    return (a + b) / 2


def _ndtri_exp_single_master(y):
    # TODO(amylase): Justify this constant
    return _bisect(_log_ndtr_single, -100, +100, y)


def _ndtri_exp_single_pr(y):
    if y > -sys.float_info.min:
        return math.inf if y <= 0 else math.nan

    if y > -1e-2:  # Case 1. abs(y) << 1.
        u = -2.0 * math.log(-y)
        x = math.sqrt(u - math.log(u))
    elif y < -5:  # Case 2. abs(y) >> 1.
        x = -math.sqrt(-2.0 * (y + _norm_pdf_logC))
    else:  # Case 3. Moderate y.
        x = -_ndtri_exp_approx_C * math.log(math.exp(-y) - 1)

    log_ndtr_x = math.nan
    for _ in range(100):
        log_ndtr_x = _log_ndtr_single(x)
        log_norm_pdf_x = -0.5 * x**2 - _norm_pdf_logC
        # NOTE(nabenabe): Use exp(log_ndtr_x - log_norm_pdf_x) instead of ndtr_x / norm_pdf_x for
        # numerical stability.
        dx = (log_ndtr_x - y) * math.exp(log_ndtr_x - log_norm_pdf_x)
        x -= dx
        if abs(dx) < 1e-8 * abs(x):  # Equivalent to np.isclose with atol=0.0 and rtol=1e-8.
            break

    return x


def _ndtri_exp_single_mp(y):
    a = -1e9
    b = +1e9
    for _ in range(1000):
        m = (a + b) / 2
        if mpmath.log(mpmath.ncdf(m)) < y:
            a = m
        else:
            b = m
    return (a + b) / 2


mpmath.mp.dps = 100

clips = [(-1000, 0), (-10, 0), (-0.1, 0), (-0.001, 0)]

fig, axes = plt.subplots(2, len(clips)//2, figsize=(9, 5), constrained_layout=True)

for (a, b), axis in zip(clips, axes.ravel()):
    x = np.linspace(a, b, 100, endpoint=False)
    y_scipy = scipy.special.ndtri_exp(x)
    y_master = np.array([_ndtri_exp_single_master(t) for t in x])
    y_pr = np.array([_ndtri_exp_single_pr(t) for t in x])
    y_mp = np.array([_ndtri_exp_single_mp(t) for t in x])
    err_scipy = y_scipy - y_mp
    err_master = y_master - y_mp
    err_pr = y_pr - y_mp

    axis.plot(x, err_scipy, label="SciPy")
    axis.plot(x, err_master, label="master")
    axis.plot(x, err_pr, label="PR")
    axis.grid()
    axis.legend(loc='lower left')

plt.savefig("ndtri_exp.png")

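[Figure: ndtri_exp.png, errors of SciPy, master, and this PR relative to the mpmath reference over four ranges of y]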

@not522 not522 merged commit 5b54028 into optuna:master Jul 8, 2025
14 checks passed

4 participants