Improve numerical performance of sample() #867

@rhettinger

In more_itertools.sample(), the variable W rapidly becomes very small. As it shrinks, the (1.0 - W) computation suffers increasing loss of significance, and once W falls far enough below machine epsilon, 1.0 - W rounds to exactly 1.0 and all information is lost.

The fix is simple. Change log(1.0 - W) to log1p(-W).
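To make the failure mode concrete, here is a minimal sketch of what happens once W is smaller than the spacing of floats just below 1.0. The choice of W = eps / 8 is only an illustrative value, not one taken from sample().

```python
import sys
from math import log, log1p

eps = sys.float_info.epsilon   # 2**-52 for IEEE 754 doubles

W = eps / 8                    # illustrative value, far below the float spacing near 1.0

# The subtraction cannot represent the difference, so it rounds to exactly 1.0.
difference = 1.0 - W

naive = log(difference)        # log(1.0) is exactly 0.0, so W has vanished
stable = log1p(-W)             # evaluates log(1 - W) without forming 1 - W

print(f'{difference == 1.0=}')
print(f'{naive=}')
print(f'{stable=}')
```

Because log1p() never forms the intermediate 1.0 - W, its result stays close to -W for small W instead of collapsing to zero.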

Here is a comparison between the two. The new version gives more accurate results from the outset and never drops to zero.

for e in range(-1, -20, -1):
    W = 10 ** e
    print(W, log1p(-W), log(1 - W), sep='\t')

0.1	-0.10536051565782631	-0.10536051565782628
0.01	-0.010050335853501442	-0.01005033585350145
0.001	-0.0010005003335835335	-0.0010005003335835344
0.0001	-0.00010000500033335834	-0.00010000500033334732
1e-05	-1.0000050000333337e-05	-1.0000050000287824e-05
1e-06	-1.0000005000003334e-06	-1.000000500029089e-06
1e-07	-1.0000000500000033e-07	-1.0000000494736474e-07
1e-08	-1.0000000050000001e-08	-1.0000000100247594e-08
1e-09	-1.0000000005000001e-09	-9.999999722180686e-10
1e-10	-1.00000000005e-10	-1.000000082790371e-10
1e-11	-1.000000000005e-11	-1.000000082745371e-11
1e-12	-1.0000000000005e-12	-9.999778782803785e-13
1e-13	-1.00000000000005e-13	-1.000310945187316e-13
1e-14	-1.000000000000005e-14	-9.99200722162646e-15
1e-15	-1.0000000000000007e-15	-9.992007221626415e-16
1e-16	-1e-16	-1.1102230246251565e-16
1e-17	-1e-17	0.0
1e-18	-1e-18	0.0
1e-19	-1e-19	0.0
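One way to quantify the difference in the table is to compare both formulas against a high-precision reference. The sketch below uses the standard decimal module for that reference; the helper name rel_errors and the 60-digit precision are choices made for this illustration only.

```python
from decimal import Decimal, getcontext
from math import log, log1p

getcontext().prec = 60  # plenty of digits for a reference value

def rel_errors(W):
    """Return (relative error of log(1.0 - W), relative error of log1p(-W))."""
    # Decimal(W) converts the float exactly, so the reference uses the same W.
    exact = (Decimal(1) - Decimal(W)).ln()
    err_log = abs((Decimal(log(1.0 - W)) - exact) / exact)
    err_log1p = abs((Decimal(log1p(-W)) - exact) / exact)
    return float(err_log), float(err_log1p)

for e in (-4, -8, -12, -16, -17):
    W = 10.0 ** e
    err_log, err_log1p = rel_errors(W)
    print(f'{W:g}\tlog: {err_log:.3g}\tlog1p: {err_log1p:.3g}')
```

A relative error of 1.0 for log(1.0 - W) means the result was 0.0, which corresponds to the last rows of the table above.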

The new version is slightly faster:

% python3.13 -m timeit -s 'W=0.01' -s 'from math import log, log1p' 'log(1.0 - W)'
10000000 loops, best of 5: 33.1 nsec per loop

% python3.13 -m timeit -s 'W=0.01' -s 'from math import log, log1p' 'log1p(-W)'
10000000 loops, best of 5: 31.4 nsec per loop

The new version can run farther before overflow/underflow:

from math import floor, exp, log, log1p
from random import random, seed
from contextlib import suppress

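def baseline_skipper(k):
    # Current formulation. W and skip are globals so the state at the
    # point of failure can be reported after the exception is suppressed.
    global W, skip
    W = 1.0
    while True:
        W *= exp(log(random()) / k)
        # Divisor becomes 0.0 once 1.0 - W rounds to 1.0.
        skip = floor(log(random()) / log(1.0 - W))

def patched_skiper(k):
    # Same loop, with log1p(-W) in place of log(1.0 - W).
    global W, skip
    W = 1.0
    while True:
        W *= exp(log(random()) / k)
        skip = floor(log(random()) / log1p(-W))

for skipper in (baseline_skipper, patched_skiper):
    seed(8675309)
    # Each skipper loops until the division or the floor() conversion fails.
    with suppress(ZeroDivisionError, OverflowError):
        skipper(5)
    print(f'=== {skipper.__name__} ===')
    print(f'Failed at: {W=}  {skip.bit_length()=}')
    print()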

This outputs:

=== baseline_skipper ===
Failed at: W=4.770429141386115e-17  skip.bit_length()=55

=== patched_skiper ===
Failed at: W=4.69269497434614e-309  skip.bit_length()=1024
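For context, here is a rough sketch of how the patched skip computation could sit inside a reservoir sampler of this kind. It is not the actual more_itertools implementation: the name reservoir_sample and the enumerate-based loop are assumptions made for illustration, and rare edge cases such as random() returning exactly 0.0 are not handled.

```python
from itertools import islice
from math import exp, floor, log, log1p
from random import random, randrange, seed

def reservoir_sample(iterable, k):
    """Hypothetical sketch: choose k items uniformly from an iterable."""
    it = iter(iterable)
    reservoir = list(islice(it, k))
    if k == 0 or len(reservoir) < k:
        return reservoir           # nothing to sample, or input shorter than k

    W = exp(log(random()) / k)
    # Index of the next element that will replace a reservoir slot.
    next_index = k + floor(log(random()) / log1p(-W))

    for index, element in enumerate(it, k):
        if index == next_index:
            reservoir[randrange(k)] = element
            W *= exp(log(random()) / k)
            # log1p(-W) stays nonzero even when W is far below machine epsilon.
            next_index += floor(log(random()) / log1p(-W)) + 1
    return reservoir

seed(8675309)
print(reservoir_sample(range(1000), 5))
```

This sketch visits every element for simplicity rather than skipping ahead in the iterator. The only part that matters for this issue is the divisor in the two skip computations.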
