
Add counts option to sample() #870

@rhettinger


Add a counts option to more_itertools.sample():

  • Improve compatibility with random.sample().
  • Improve ease of use for populations with repeated values.
  • Fewer iterations than with the equivalent expanded input.
  • Same lazy consumption of inputs.
  • Still memory-friendly: tracks only W, the reservoir, the current element, and that element's remaining count.
  • Makes urn problems easy to model.
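For comparison, the standard library's random.sample() has accepted a keyword-only counts argument since Python 3.9, and that is the API this option mirrors. The sketch below is a minimal urn model using the stdlib function, not the proposed more_itertools code. The urn contents, the trial count, and the helper name estimate_two_blue are illustrative assumptions. It estimates the chance of drawing both blue balls when drawing 5 from an urn of 4 red and 2 blue, and compares that with the exact hypergeometric value.

```python
import random
from math import comb


def estimate_two_blue(trials: int, rng: random.Random) -> float:
    """Hypothetical helper: Monte Carlo estimate of P(both blues drawn).

    Draws 5 balls without replacement from an urn of 4 red and 2 blue,
    using the stdlib random.sample(..., counts=...) (Python 3.9+).
    """
    hits = 0
    for _ in range(trials):
        draw = rng.sample(['red', 'blue'], counts=[4, 2], k=5)
        hits += draw.count('blue') == 2
    return hits / trials


# Exact hypergeometric probability: pick both blues and 3 of the 4 reds.
exact = comb(2, 2) * comb(4, 3) / comb(6, 5)

rng = random.Random(8675309)
estimate = estimate_two_blue(20_000, rng)
print(f'exact={exact:.4f} estimate={estimate:.4f}')
```

The urn is described by two short lists rather than a six-element expansion, which is the ergonomic point of the counts option.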

These are equivalent calls:

  • sample(['red', 'blue'], counts=[4, 2], k=5)
  • sample(['red', 'red', 'red', 'red', 'blue', 'blue'], k=5)
  • sample(run_length.decode(zip(['red', 'blue'], [4, 2])), k=5)

Only the first variant avoids looping over the full expanded input.
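The same equivalence can be checked against the standard library's random.sample(), whose counts argument has the same meaning. With the same seed, the counted call and the call on the expanded list return the same sample, because the stdlib draws indices into the expanded range and maps them back through cumulative counts. This is a sketch against the stdlib function, not the proposed more_itertools code; expand() is a hypothetical helper built from the chain/repeat idiom.

```python
import random
from itertools import chain, repeat


def expand(population, counts):
    """Hypothetical helper: repeat each population element counts[i] times."""
    return list(chain.from_iterable(map(repeat, population, counts)))


population = ['red', 'blue']
counts = [4, 2]
expanded = expand(population, counts)
# → ['red', 'red', 'red', 'red', 'blue', 'blue']

# Same seed, same random stream: counted and expanded calls agree.
random.seed(8675309)
counted_sample = random.sample(population, counts=counts, k=5)
random.seed(8675309)
expanded_sample = random.sample(expanded, k=5)
assert counted_sample == expanded_sample
```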

New helper function:

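from contextlib import suppress
from math import exp, floor, log, log1p
from random import random, randrange, shuffle


def _sample_counted(population, k, counts, strict):
    # population and counts are parallel iterators: the i-th element of
    # population occurs counts[i] times in the (never materialized) input.
    element = None
    remaining = 0   # copies of the current element not yet consumed

    def feed(i):
        'Advance *i* steps ahead and consume an element'
        nonlocal element, remaining

        # Jump over whole runs until the target lies inside the current run.
        while i + 1 > remaining:
            i = i - remaining
            element = next(population)
            remaining = next(counts)
        remaining -= (i + 1)
        return element

    # Fill the reservoir. The strict check must run outside suppress(),
    # otherwise a short population exits the block and skips the check.
    reservoir = []
    with suppress(StopIteration):
        for _ in range(k):
            reservoir.append(feed(0))
    if strict and len(reservoir) < k:
        raise ValueError('Sample larger than population')

    # Algorithm L: draw the gap to the next replacement, then jump to it.
    with suppress(StopIteration):
        W = 1.0
        while True:
            W *= exp(log(random()) / k)
            skip = floor(log(random()) / log1p(-W))
            reservoir[randrange(k)] = feed(skip)

    shuffle(reservoir)
    return reservoir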
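The key step is feed(): rather than stepping through the expanded input one item at a time, it subtracts whole runs from the skip distance, so a skip costs one iteration per run crossed instead of one per element. The self-contained sketch below isolates that logic in a hypothetical RunCursor class (not part of more_itertools) and checks it against next(islice(expanded, skip, None)) on the explicitly expanded stream, which is how an unweighted reservoir sampler would skip ahead.

```python
from itertools import chain, islice, repeat


class RunCursor:
    """Hypothetical helper mirroring feed(): walks (element, count) runs lazily."""

    def __init__(self, population, counts):
        self.population = iter(population)
        self.counts = iter(counts)
        self.element = None
        self.remaining = 0      # unconsumed copies of self.element
        self.runs_fetched = 0   # number of (element, count) pairs pulled so far

    def advance(self, i):
        """Skip i items of the expanded stream, then consume and return the next one."""
        while i + 1 > self.remaining:
            i -= self.remaining  # jump past the rest of the current run
            self.element = next(self.population)
            self.remaining = next(self.counts)
            self.runs_fetched += 1
        self.remaining -= i + 1
        return self.element


population = ['a', 'b', 'c', 'd']
counts = [3, 0, 1000, 2]          # 1005 items when expanded; 'b' never appears
skips = [0, 1, 500, 498, 1]

cursor = RunCursor(population, counts)
expanded = chain.from_iterable(map(repeat, population, counts))
for skip in skips:
    assert cursor.advance(skip) == next(islice(expanded, skip, None))

print(cursor.runs_fetched)        # → 4
```

All 1005 expanded positions are covered, yet the cursor touches only the 4 runs, which is where the "fewer iterations" claim comes from.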

Updated calling code:

def sample(iterable, k, weights=None, *, counts=None, strict=False):
    iterator = iter(iterable)

    if k < 0:
        raise ValueError('k must be non-negative')
    if k == 0:
        return []

    if weights is not None and counts is not None:
        raise TypeError('weights and counts are mutually exclusive')

    elif weights is not None:
        weights = iter(weights)
        return _sample_weighted(iterator, k, weights, strict)

    elif counts is not None:
        counts = iter(counts)
        return _sample_counted(iterator, k, counts, strict)

    else:
        return _sample_unweighted(iterator, k, strict)

Simple tests:

from collections import Counter
from itertools import chain, repeat
from random import seed

from more_itertools import sample  # with this patch applied

# Invariant: sampling a counted population matches sampling the equivalent expanded population
population = ('red', 'green', 'blue')
counts = [4, 10, 6]
expanded_population = list(chain.from_iterable(map(repeat, population, counts)))
k = 15
seed(8675309)
s1 = sample(population, k=k, counts=counts)
seed(8675309)
s2 = sample(expanded_population, k=k)
assert s1 == s2

# Sample size equals the entire population
assert Counter(sample('uwxyz', 35, counts=(1, 0, 4, 10, 20))) == Counter(
    {'z': 20, 'y': 10, 'x': 4, 'u': 1}
)
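The two asserts above check single cases. A broader check is sketched below, assuming the unweighted path is Li's Algorithm L as in more_itertools' _sample_unweighted(). It reimplements both paths in plain Python under hypothetical names (algorithm_l and algorithm_l_counted) and compares them across many seeds and sample sizes, including a zero count and k larger than the population. It does not import the patched library, so it exercises the technique rather than the shipped code.

```python
import random
from contextlib import suppress
from itertools import chain, islice, repeat
from math import exp, floor, log, log1p


def algorithm_l(iterable, k):
    """Hypothetical reference: Algorithm L reservoir sampling on an explicit iterable (k >= 1)."""
    it = iter(iterable)
    reservoir = list(islice(it, k))
    W = 1.0
    with suppress(StopIteration):
        while True:
            W *= exp(log(random.random()) / k)
            skip = floor(log(random.random()) / log1p(-W))
            reservoir[random.randrange(k)] = next(islice(it, skip, None))
    random.shuffle(reservoir)
    return reservoir


def algorithm_l_counted(population, counts, k):
    """Hypothetical counted variant: same random draws, but walks runs lazily (k >= 1)."""
    pop, cnt = iter(population), iter(counts)
    element, remaining = None, 0

    def feed(i):
        nonlocal element, remaining
        while i + 1 > remaining:  # skip whole runs
            i -= remaining
            element = next(pop)
            remaining = next(cnt)
        remaining -= i + 1
        return element

    reservoir = []
    with suppress(StopIteration):
        for _ in range(k):
            reservoir.append(feed(0))
    W = 1.0
    with suppress(StopIteration):
        while True:
            W *= exp(log(random.random()) / k)
            skip = floor(log(random.random()) / log1p(-W))
            reservoir[random.randrange(k)] = feed(skip)
    random.shuffle(reservoir)
    return reservoir


population = ['u', 'w', 'x', 'y', 'z']
counts = [1, 0, 4, 10, 20]
expanded = list(chain.from_iterable(map(repeat, population, counts)))

mismatches = 0
for seed_value in range(200):
    for k in (1, 5, 15, 35, 40):   # 40 exceeds the 35-item population
        random.seed(seed_value)
        a = algorithm_l_counted(population, counts, k)
        random.seed(seed_value)
        b = algorithm_l(expanded, k)
        mismatches += a != b

print(mismatches)  # → 0
```

Both functions consume the random stream identically (no draws during the fill, two per gap, one randrange per replacement, then the shuffle), which is why seeded outputs can be compared element by element.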
