Add counts option to sample() #870
Closed
Description
Add a counts option to more_itertools.sample():
- Improve compatibility with random.sample().
- Improve ease of use for populations with repeated values.
- Fewer iterations than with the equivalent expanded input, because runs of repeated values are skipped arithmetically using the counts.
- Same lazy consumption of inputs.
- Still memory friendly: tracks only W, the reservoir, the current element, and the number of its copies remaining.
- Easily model urn problems.
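The compatibility target is the standard library's random.sample(), which has accepted a keyword-only counts argument since Python 3.9. As a small illustration of the urn-problem use case, here is a stdlib-only sketch that draws 5 balls from an urn of 4 red and 2 blue and compares the observed rate of drawing both blue balls with the exact hypergeometric value 4/6. The seed, the trial count, and the tolerance used when checking it are arbitrary choices, not part of the proposal:

```python
import random
from collections import Counter
from math import comb

# Urn with 4 red balls and 2 blue balls; draw 5 without replacement.
colors = ['red', 'blue']
counts = [4, 2]
k = 5

# Exact hypergeometric probability that both blue balls are drawn:
# C(2, 2) * C(4, 3) / C(6, 5) = 4 / 6
exact = comb(2, 2) * comb(4, 3) / comb(6, 5)

# Monte Carlo estimate using the stdlib counts= option (Python 3.9+).
rng = random.Random(2024)  # arbitrary seed for reproducibility
trials = 20_000
hits = sum(
    Counter(rng.sample(colors, k=k, counts=counts))['blue'] == 2
    for _ in range(trials)
)
estimate = hits / trials
print(f'exact={exact:.4f} estimate={estimate:.4f}')
```

With counts available in more_itertools.sample() as well, the same call shape would work when the urn's contents arrive as a lazy stream rather than a list.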
These are equivalent calls:
sample(['red', 'blue'], counts=[4, 2], k=5)
sample(['red', 'red', 'red', 'red', 'blue', 'blue'], k=5)
sample(run_length.decode(zip(['red', 'blue'], [4, 2])), k=5)
Only the first variant avoids looping over the full expanded input.
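To make the cost difference concrete, the sketch below builds the expanded stream with itertools and compares how many items a sampler would have to step through in each form. The helper name expand is hypothetical; it yields the same sequence as run_length.decode(zip(...)) in the third call above:

```python
from itertools import chain, repeat


def expand(population, counts):
    """Hypothetical helper: lazily repeat each value `count` times, the same
    sequence that run_length.decode(zip(population, counts)) yields."""
    return chain.from_iterable(map(repeat, population, counts))


population = ['red', 'blue']
counts = [4, 2]
print(list(expand(population, counts)))
# ['red', 'red', 'red', 'red', 'blue', 'blue']

# Skipping through the expanded stream (e.g. with islice) still costs one
# iteration per copy, while the counted form hands over only one
# (value, count) pair per distinct value.
big_counts = [10**6, 2 * 10**6]
expanded_steps = sum(big_counts)      # items in the expanded stream
counted_pairs = len(population)       # pairs in the counted input
print(expanded_steps, counted_pairs)  # 3000000 2
```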
New helper function:
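from contextlib import suppress
from math import exp, floor, log, log1p
from random import random, randrange, shuffle

def _sample_counted(population, k, counts, strict):
    element = None
    remaining = 0  # copies of *element* not yet consumed

    def feed(i):
        'Advance *i* steps ahead and consume an element'
        nonlocal element, remaining
        while i + 1 > remaining:
            i = i - remaining
            element = next(population)
            remaining = next(counts)
        remaining -= (i + 1)
        return element

    # Fill the reservoir with the first k elements of the expanded population.
    # The strict check sits outside the suppress block so that it still runs
    # when the population is exhausted early.
    with suppress(StopIteration):
        reservoir = []
        for _ in range(k):
            reservoir.append(feed(0))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Algorithm L: jump ahead by randomly drawn skips until the input runs out.
    with suppress(StopIteration):
        W = 1.0
        while True:
            W *= exp(log(random()) / k)
            skip = floor(log(random()) / log1p(-W))
            element = feed(skip)
            reservoir[randrange(k)] = element

    shuffle(reservoir)
    return reservoir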
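The key step in the helper is feed(i), which should behave exactly like skipping i items of the expanded input and taking the next one. The sketch below checks that on random inputs, including zero counts and empty populations, by comparing against next(islice(iterator, skip, None)) on the expanded stream, which is assumed (from my reading of the current source, not verified here) to be how the unweighted path skips. make_feed, matches_expanded, and the sentinel are hypothetical names for illustration only:

```python
import random
from itertools import chain, islice, repeat

_EXHAUSTED = object()  # sentinel meaning "the input ran out"


def make_feed(population, counts):
    """Hypothetical factory wrapping the feed() closure from _sample_counted."""
    population, counts = iter(population), iter(counts)
    element = None
    remaining = 0  # copies of `element` not yet consumed

    def feed(i):
        # Skip i expanded items, then consume and return the next one.
        nonlocal element, remaining
        while i + 1 > remaining:
            i -= remaining  # jump over the rest of the current run
            element = next(population)
            remaining = next(counts)
        remaining -= i + 1
        return element

    return feed


def matches_expanded(population, counts, skips):
    """Hypothetical check: feed(skip) vs. islice skipping on the expanded stream."""
    feed = make_feed(population, counts)
    expanded = chain.from_iterable(map(repeat, population, counts))
    for skip in skips:
        try:
            got = feed(skip)
        except StopIteration:
            got = _EXHAUSTED
        expected = next(islice(expanded, skip, None), _EXHAUSTED)
        if got != expected:
            return False
        if got is _EXHAUSTED:
            return True  # both sides ran out at the same point
    return True


# Randomized comparison, including zero counts and empty populations.
rng = random.Random(0)
all_ok = True
for _ in range(500):
    pop = list(range(rng.randint(0, 6)))
    cnts = [rng.randint(0, 4) for _ in pop]
    skips = [rng.randint(0, 5) for _ in range(10)]
    all_ok = all_ok and matches_expanded(pop, cnts, skips)
print(all_ok)  # True
```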
Updated calling code:
def sample(iterable, k, weights=None, *, counts=None, strict=False):
    iterator = iter(iterable)

    if k < 0:
        raise ValueError('k must be non-negative')

    if k == 0:
        return []

    if weights is not None and counts is not None:
        raise TypeError('weights and counts are mutually exclusive')
    elif weights is not None:
        weights = iter(weights)
        return _sample_weighted(iterator, k, weights, strict)
    elif counts is not None:
        counts = iter(counts)
        return _sample_counted(iterator, k, counts, strict)
    else:
        return _sample_unweighted(iterator, k, strict)
Simple tests:
from collections import Counter
from itertools import chain, repeat
from random import seed

from more_itertools import sample

# Invariant: sample of a counted population matches sample of an equivalent expanded population
population = ('red', 'green', 'blue')
counts = [4, 10, 6]
expanded_population = list(chain.from_iterable(map(repeat, population, counts)))
k = 15
seed(8675309)
s1 = sample(population, k=k, counts=counts)
seed(8675309)
s2 = sample(expanded_population, k=k)
assert s1 == s2

# Sample size equals the entire population (1 + 0 + 4 + 10 + 20 == 35)
assert Counter(sample('uwxyz', 35, counts=(1, 0, 4, 10, 20))) == Counter({'z': 20, 'y': 10, 'x': 4, 'u': 1})
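The invariant test above needs a build of more-itertools that already includes this change. A standalone sketch can check the same property by pairing a copy of the counted helper with an islice-based Algorithm L sampler that is assumed (not verified here) to mirror the library's unweighted path. The function names are hypothetical, strict handling and the weights path are left out, and the seeds and sample sizes are arbitrary:

```python
from collections import Counter
from contextlib import suppress
from itertools import chain, islice, repeat
from math import exp, floor, log, log1p
from random import random, randrange, seed, shuffle


def sample_expanded(iterable, k):
    """Algorithm L reservoir sample that skips with islice (assumed to match
    the existing unweighted path in more-itertools)."""
    iterator = iter(iterable)
    reservoir = list(islice(iterator, k))
    with suppress(StopIteration):
        W = 1.0
        while True:
            W *= exp(log(random()) / k)
            skip = floor(log(random()) / log1p(-W))
            reservoir[randrange(k)] = next(islice(iterator, skip, None))
    shuffle(reservoir)
    return reservoir


def sample_counted(iterable, k, counts):
    """Same sampler over (value, count) runs, following _sample_counted."""
    population, counts = iter(iterable), iter(counts)
    element, remaining = None, 0

    def feed(i):
        nonlocal element, remaining
        while i + 1 > remaining:
            i -= remaining
            element = next(population)
            remaining = next(counts)
        remaining -= i + 1
        return element

    reservoir = []
    with suppress(StopIteration):
        for _ in range(k):
            reservoir.append(feed(0))
    with suppress(StopIteration):
        W = 1.0
        while True:
            W *= exp(log(random()) / k)
            skip = floor(log(random()) / log1p(-W))
            reservoir[randrange(k)] = feed(skip)
    shuffle(reservoir)
    return reservoir


def invariant_holds(population, counts, ks, seeds):
    """Same seed => counted and expanded samplers return identical lists."""
    expanded = list(chain.from_iterable(map(repeat, population, counts)))
    for s in seeds:
        for k in ks:
            seed(s)
            counted = sample_counted(population, k, counts)
            seed(s)
            plain = sample_expanded(expanded, k)
            if counted != plain:
                return False
    return True


print(invariant_holds(('red', 'green', 'blue'), [4, 10, 6],
                      ks=(1, 5, 15, 20, 25), seeds=range(200)))  # True
```

Neither sampler draws random numbers while filling the reservoir, and both make the same random(), random(), randrange() calls per skip followed by one shuffle(), so the two runs stay in lockstep for a given seed. That is what lets the counted path promise identical results rather than merely the same distribution.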