Improve numerical performance of sample() #867
Closed
Description
In more_itertools.sample(), the variable W rapidly becomes very small. As it shrinks, computing (1.0 - W) suffers loss of significance. Once W is at or below about 2**-54 (roughly 5.6e-17), 1.0 - W rounds to exactly 1.0, log(1.0 - W) returns 0.0, and all information is lost.
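The rounding boundary is easy to check directly. This is a small sketch assuming IEEE 754 binary64 floats, which CPython uses on mainstream platforms:

```python
from math import log, log1p

# Doubles just below 1.0 are spaced 2**-53 apart. Subtracting half of that
# spacing (2**-54) is a rounding tie, and round-half-to-even picks 1.0.
# Anything smaller than 2**-54 also rounds to 1.0.
for W in (2.0 ** -53, 2.0 ** -54, 1e-17):
    rounds_to_one = (1.0 - W == 1.0)
    print(W, rounds_to_one, log(1.0 - W), log1p(-W), sep='\t')
```

Once rounds_to_one is True, log(1.0 - W) is exactly 0.0, while log1p(-W) still returns a value close to -W.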
The fix is simple. Change log(1.0 - W) to log1p(-W).
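For context, the skip formula in the issue's code is the geometric skip used by Li's Algorithm L for reservoir sampling. The sketch below shows the patched expression inside a small Algorithm L style sampler. It is an illustration under that assumption, not the more_itertools source, and the function name reservoir_sample is hypothetical.

```python
from itertools import islice
from math import exp, floor, log, log1p
from random import random, randrange


def reservoir_sample(iterable, k):
    """Hypothetical Algorithm L style sampler: k items chosen uniformly.

    Illustrative sketch only; this is not the more_itertools code.
    """
    if k <= 0:
        return []
    iterator = iter(iterable)
    # Fill the reservoir with the first k items.
    reservoir = list(islice(iterator, k))
    if len(reservoir) < k:
        return reservoir  # fewer than k items available: keep them all

    W = exp(log(random()) / k)
    # Geometric skip; log1p(-W) replaces log(1.0 - W) so the
    # denominator stays accurate (and nonzero) as W shrinks.
    next_index = k + floor(log(random()) / log1p(-W))

    for index, element in enumerate(iterator, k):
        if index == next_index:
            # Replace a random slot, shrink W, and schedule the next pick.
            reservoir[randrange(k)] = element
            W *= exp(log(random()) / k)
            next_index += floor(log(random()) / log1p(-W)) + 1
    return reservoir


print(reservoir_sample(range(10_000), 5))
```

Only the two skip lines differ from a log(1.0 - W) version; the rest of the sampler is unchanged by the fix.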
Here is a comparison of the two expressions; the output columns are W, log1p(-W), and log(1 - W). The new version is more accurate from the outset and never drops to zero.
from math import log, log1p
for e in range(-1, -20, -1):
    W = 10 ** e
    print(W, log1p(-W), log(1 - W), sep='\t')
0.1 -0.10536051565782631 -0.10536051565782628
0.01 -0.010050335853501442 -0.01005033585350145
0.001 -0.0010005003335835335 -0.0010005003335835344
0.0001 -0.00010000500033335834 -0.00010000500033334732
1e-05 -1.0000050000333337e-05 -1.0000050000287824e-05
1e-06 -1.0000005000003334e-06 -1.000000500029089e-06
1e-07 -1.0000000500000033e-07 -1.0000000494736474e-07
1e-08 -1.0000000050000001e-08 -1.0000000100247594e-08
1e-09 -1.0000000005000001e-09 -9.999999722180686e-10
1e-10 -1.00000000005e-10 -1.000000082790371e-10
1e-11 -1.000000000005e-11 -1.000000082745371e-11
1e-12 -1.0000000000005e-12 -9.999778782803785e-13
1e-13 -1.00000000000005e-13 -1.000310945187316e-13
1e-14 -1.000000000000005e-14 -9.99200722162646e-15
1e-15 -1.0000000000000007e-15 -9.992007221626415e-16
1e-16 -1e-16 -1.1102230246251565e-16
1e-17 -1e-17 0.0
1e-18 -1e-18 0.0
1e-19 -1e-19 0.0
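To put numbers on "more accurate", both expressions can be compared against a high-precision reference. This sketch assumes that the standard decimal module's ln() at 50 significant digits is a good enough stand-in for the exact value of log(1 - W); the helper names reference_log1m and relative_error are hypothetical.

```python
from decimal import Decimal, localcontext
from math import log, log1p


def reference_log1m(W):
    """log(1 - W) for the exact binary value of W, via 50-digit decimal."""
    with localcontext() as ctx:
        ctx.prec = 50
        exact = (Decimal(1) - Decimal(W)).ln()
    return float(exact)


def relative_error(approx, exact):
    return abs(approx - exact) / abs(exact)


for e in range(-1, -20, -1):
    W = 10 ** e
    ref = reference_log1m(W)
    print(f'{W:.0e}\tlog1p rel err {relative_error(log1p(-W), ref):.1e}'
          f'\tlog rel err {relative_error(log(1 - W), ref):.1e}')
```

Decimal(W) converts the float exactly, so the reference describes the same binary W that both math functions receive.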
The new version is slightly faster:
% python3.13 -m timeit -s 'W=0.01' -s 'from math import log, log1p' 'log(1.0 - W)'
10000000 loops, best of 5: 33.1 nsec per loop
% python3.13 -m timeit -s 'W=0.01' -s 'from math import log, log1p' 'log1p(-W)'
10000000 loops, best of 5: 31.4 nsec per loop
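The same measurement can be repeated from a script with the standard timeit module. This is a sketch: the loop count here is fixed at an assumed 200,000 calls rather than chosen automatically as the command line does, and absolute numbers depend on the machine and Python build, so small differences deserve some caution.

```python
import timeit

SETUP = 'from math import log, log1p; W = 0.01'
NUMBER = 200_000

for stmt in ('log(1.0 - W)', 'log1p(-W)'):
    # Take the best of 5 repeats, mirroring the CLI's "best of 5".
    best = min(timeit.repeat(stmt, setup=SETUP, repeat=5, number=NUMBER))
    print(f'{stmt:<13} {best / NUMBER * 1e9:.1f} nsec per loop')
```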
The new version can also run much farther before the computation breaks down from overflow or underflow:
from math import floor, exp, log, log1p
from random import random, seed
from contextlib import suppress

def baseline_skipper(k):
    # Loops until an exception escapes; W and skip are globals so the
    # last values survive for inspection afterwards.
    global W, skip
    W = 1.0
    while True:
        W *= exp(log(random()) / k)
        skip = floor(log(random()) / log(1.0 - W))

def patched_skipper(k):
    global W, skip
    W = 1.0
    while True:
        W *= exp(log(random()) / k)
        skip = floor(log(random()) / log1p(-W))

for skipper in (baseline_skipper, patched_skipper):
    seed(8675309)
    # ZeroDivisionError: the denominator became 0.0
    # OverflowError: the quotient overflowed to inf and floor() rejected it
    with suppress(ZeroDivisionError, OverflowError):
        skipper(5)
    print(f'=== {skipper.__name__} ===')
    print(f'Failed at: {W=} {skip.bit_length()=}')
    print()
This outputs:
=== baseline_skipper ===
Failed at: W=4.770429141386115e-17 skip.bit_length()=55
=== patched_skipper ===
Failed at: W=4.69269497434614e-309 skip.bit_length()=1024
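The two stopping points line up with floating-point limits. A skip with bit_length() 1024 is about as large as a float can get, since the largest finite double is just under 2**1024, and W=4.69e-309 is already in the subnormal range. Presumably the next quotient overflowed to inf (or W reached 0.0) and suppress() swallowed the exception; the output above does not say which one fired. The sketch below illustrates the overflow case with an assumed subnormal W = 1e-309:

```python
import sys
from math import floor, log1p

# The largest finite double is just below 2**1024, so its integer value
# has bit_length() 1024, matching the patched run's last skip.
print(int(sys.float_info.max).bit_length())  # 1024

# Assumed example: a subnormal W. log1p(-W) is still nonzero ...
W = 1e-309
denominator = log1p(-W)
print(denominator != 0.0)  # True

# ... but dividing by it exceeds the float range, giving inf,
# and floor() refuses to convert inf to an integer.
quotient = -1.0 / denominator
print(quotient)  # inf
try:
    floor(quotient)
except OverflowError as exc:
    print('OverflowError:', exc)
```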