Wrong batch behavior with non-round batch division #2055

@jsgounot

Snakemake version

7.8.2

Describe the bug

When (rule outputs) / (batch number) is not an integer, batching does not behave correctly.

Here is a short procedural translation of the current batching logic, applied to a hypothetical job with 476 outputs and 100 batches. Each loop iteration prints the one-based batch number followed by a tuple of the zero-based batch index, the start index, and the batch length.

import math

def get_batch(items, idx, batches):
    # returns (zero-based batch index, start index, batch length)
    batch_len = math.floor(len(items) / batches)
    # self.batch is one-based, hence we have to subtract 1
    idx = idx - 1
    i = idx * batch_len

    if idx == (batches - 1):
        # extend the last batch to cover rest of list
        return idx, i, len(items[i:])
    else:
        return idx, i, len(items[i : i + batch_len])

items = list(range(476))
for i in range(1, 101):
    print(i, get_batch(items, i, 100))

Output tail:

90 (89, 356, 4)
91 (90, 360, 4)
92 (91, 364, 4)
93 (92, 368, 4)
94 (93, 372, 4)
95 (94, 376, 4)
96 (95, 380, 4)
97 (96, 384, 4)
98 (97, 388, 4)
99 (98, 392, 4)
100 (99, 396, 80)

The start index does not partition the items list evenly: batch 100 starts at i == 396, which leaves 80 outputs for the last batch while every other batch has 4. This happens because i is computed from the floored batch length (4), so the remainder of 476 - 100 * 4 = 76 items all ends up in the final batch.
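The skew can be checked numerically. The sketch below recomputes the floor-based batch sizes for the same 476 items and 100 batches. `floor_batch_sizes` is a hypothetical helper written for this illustration, not a Snakemake function.

```python
def floor_batch_sizes(n_items, batches):
    """Sizes produced by the floor-based scheme (illustrative helper)."""
    batch_len = n_items // batches  # same value as math.floor(n_items / batches)
    sizes = []
    for idx in range(batches):
        start = idx * batch_len
        if idx == batches - 1:
            # the last batch absorbs everything that is left
            sizes.append(n_items - start)
        else:
            sizes.append(batch_len)
    return sizes


sizes = floor_batch_sizes(476, 100)
print(sizes[:3], sizes[-1])  # [4, 4, 4] 80
print(476 % 100)             # 76, the remainder that piles up in the last batch
```

The last batch therefore always holds `batch_len + (n_items % batches)` items, which grows with the number of batches.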

Potential fix

Inspired by this Stack Overflow post:

Procedural code:

def get_batch(items, idx, batches):
    # returns (zero-based batch index, start index, batch length)
    k, m = divmod(len(items), batches)

    # self.batch is one-based, hence we have to subtract 1
    idx = idx - 1
    # the first m batches get k + 1 items, the remaining ones get k
    i = idx * k + min(idx, m)
    end = (idx + 1) * k + min(idx + 1, m)  # exclusive end index

    if idx == batches - 1:
        # extend the last batch to cover rest of list
        return idx, i, len(items[i:])
    else:
        return idx, i, len(items[i:end])

items = list(range(476))
for i in range(1, 101):
    print(i, get_batch(items, i, 100))

Output (abridged):

1 (0, 0, 5)
2 (1, 5, 5)
3 (2, 10, 5)
4 (3, 15, 5)
5 (4, 20, 5)
6 (5, 25, 5)
7 (6, 30, 5)
8 (7, 35, 5)
9 (8, 40, 5)
10 (9, 45, 5)
...
90 (89, 432, 4)
91 (90, 436, 4)
92 (91, 440, 4)
93 (92, 444, 4)
94 (93, 448, 4)
95 (94, 452, 4)
96 (95, 456, 4)
97 (96, 460, 4)
98 (97, 464, 4)
99 (98, 468, 4)
100 (99, 472, 4)
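
As a sanity check on the divmod-based indexing, the sketch below verifies that the batches are contiguous, cover every item exactly once, and differ in size by at most one. `batch_bounds` and `check_partition` are hypothetical helpers introduced here for illustration, and the check assumes the number of items is at least the number of batches.

```python
def batch_bounds(n_items, idx, batches):
    """Return (start, end) for one-based batch idx (illustrative helper)."""
    k, m = divmod(n_items, batches)
    idx -= 1  # convert to zero-based
    start = idx * k + min(idx, m)
    end = (idx + 1) * k + min(idx + 1, m)
    return start, end


def check_partition(n_items, batches):
    bounds = [batch_bounds(n_items, i, batches) for i in range(1, batches + 1)]
    # the batches must span the whole list ...
    assert bounds[0][0] == 0 and bounds[-1][1] == n_items
    # ... and each batch must start where the previous one ended
    assert all(a[1] == b[0] for a, b in zip(bounds, bounds[1:]))
    sizes = [end - start for start, end in bounds]
    assert max(sizes) - min(sizes) <= 1
    return sizes


sizes = check_partition(476, 100)
print(sizes.count(5), sizes.count(4))  # 76 24
```

Because the end index of the last batch always equals the number of items, the special case for the final batch is harmless but not strictly needed with this scheme.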

Translation to Snakemake:

def get_batch(self, items: list):
    """Return the defined batch of the given items.
    Items are usually input files."""
    if len(items) < self.batches:
        raise WorkflowError(
            "Batching rule {} has less input files than batches. "
            "Please choose a smaller number of batches.".format(self.rulename)
        )
    # make sure that we always consider items in the same order
    items = sorted(items)
    k, m = divmod(len(items), self.batches)

    # self.batch is one-based, hence we have to subtract 1
    idx = self.idx - 1
    # the first m batches get k + 1 items, the remaining ones get k
    i = idx * k + min(idx, m)
    end = (idx + 1) * k + min(idx + 1, m)  # exclusive end index

    if self.is_final:
        # extend the last batch to cover rest of list
        return items[i:]
    else:
        return items[i:end]
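
To try the method outside Snakemake, a minimal stand-in can be used. Everything in the sketch below is an assumption made for illustration: `FakeBatch` is a hypothetical class that only mimics the attributes the method reads (`rulename`, `idx`, `batches`, `is_final`), the rule name and file names are made up, and `ValueError` replaces `WorkflowError` so the block has no Snakemake dependency.

```python
from dataclasses import dataclass


@dataclass
class FakeBatch:
    """Hypothetical stand-in for the batch object; not Snakemake's class."""
    rulename: str
    idx: int      # one-based batch index
    batches: int  # total number of batches

    @property
    def is_final(self):
        return self.idx == self.batches

    def get_batch(self, items):
        if len(items) < self.batches:
            # stand-in for the WorkflowError raised in the real method
            raise ValueError("fewer items than batches")
        items = sorted(items)
        k, m = divmod(len(items), self.batches)
        idx = self.idx - 1
        start = idx * k + min(idx, m)
        end = (idx + 1) * k + min(idx + 1, m)
        return items[start:] if self.is_final else items[start:end]


files = [f"sample{n:03d}.txt" for n in range(476)]
parts = [FakeBatch("aggregate", i, 100).get_batch(files) for i in range(1, 101)]
print(len(parts[0]), len(parts[-1]))  # 5 4
```

Concatenating `parts` in order gives back the sorted input list, so no file is dropped or processed twice.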
