Wrong batch behavior with non-round batch division #2055
Closed
Labels
bug: Something isn't working
Description
Snakemake version
7.8.2
Describe the bug
When (number of rule outputs) / (number of batches) is not a whole number, the batches are not partitioned correctly.
Here is a short procedural translation of the current batch logic, for a theoretical job with 476 outputs and 100 batches. Each loop iteration prints the one-based batch number followed by a tuple of the zero-based batch index, the start index, and the batch size.
import math

def get_batch(items, idx, batches):
    batch_len = math.floor(len(items) / batches)
    # self.batch is one-based, hence we have to subtract 1
    idx = idx - 1
    i = idx * batch_len
    if idx == (batches - 1):
        # extend the last batch to cover rest of list
        return idx, i, len(items[i:])
    else:
        return idx, i, len(items[i : i + batch_len])

items = list(range(476))
for i in range(1, 101):
    print(i, get_batch(items, i, 100))

Output tail:
90 (89, 356, 4)
91 (90, 360, 4)
92 (91, 364, 4)
93 (92, 368, 4)
94 (93, 372, 4)
95 (94, 376, 4)
96 (95, 380, 4)
97 (96, 384, 4)
98 (97, 388, 4)
99 (98, 392, 4)
100 (99, 396, 80)
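You can see that the start index does not follow a correct partitioning of the items list. Batch 100 starts at i == 396, which leaves 80 outputs for the last batch while all the others have 4. This is because i is computed from a floored batch length: floor(476 / 100) = 4, so the first 99 batches cover only 99 * 4 = 396 items and the last batch absorbs the remaining 80.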
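To make the imbalance easier to see at a glance, here is a small sketch that only computes the batch sizes under the floor-based scheme. The helper name `floor_batch_sizes` is mine, not something from Snakemake.

```python
import math
from collections import Counter


def floor_batch_sizes(n_items: int, n_batches: int) -> list[int]:
    """Batch sizes under the floor-based scheme (hypothetical helper)."""
    batch_len = math.floor(n_items / n_batches)
    sizes = []
    for idx in range(n_batches):
        start = idx * batch_len
        if idx == n_batches - 1:
            # the last batch absorbs everything that is left
            sizes.append(n_items - start)
        else:
            sizes.append(batch_len)
    return sizes


sizes = floor_batch_sizes(476, 100)
print(Counter(sizes))  # 99 batches of 4 items and 1 batch of 80 items
print(sum(sizes))      # 476: nothing is lost, but the split is very uneven
```

No item is dropped, so the workflow still completes. The issue is that the last batch does roughly 20 times the work of any other batch.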
Potential fix
Inspired by this Stack Overflow post:
Procedural code:
def get_batch(items, idx, batches):
    k, m = divmod(len(items), batches)
    # self.batch is one-based, hence we have to subtract 1
    idx = idx - 1
    i = idx * k + min(idx, m)
    # note: despite its name, batch_len is the exclusive end index of the batch
    batch_len = (idx + 1) * k + min(idx + 1, m)
    if idx == batches - 1:
        # extend the last batch to cover rest of list
        return idx, i, len(items[i:])
    else:
        return idx, i, len(items[i : batch_len])

items = list(range(476))
for i in range(1, 101):
    print(i, get_batch(items, i, 100))

Output:
1 (0, 0, 5)
2 (1, 5, 5)
3 (2, 10, 5)
4 (3, 15, 5)
5 (4, 20, 5)
6 (5, 25, 5)
7 (6, 30, 5)
8 (7, 35, 5)
9 (8, 40, 5)
10 (9, 45, 5)
...
90 (89, 432, 4)
91 (90, 436, 4)
92 (91, 440, 4)
93 (92, 444, 4)
94 (93, 448, 4)
95 (94, 452, 4)
96 (95, 456, 4)
97 (96, 460, 4)
98 (97, 464, 4)
99 (98, 468, 4)
100 (99, 472, 4)
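As a sanity check on the divmod-based indices, the following sketch asserts that the batches tile the list without gaps or overlaps and that their sizes differ by at most one. The helper names `balanced_bounds` and `check_partition` are hypothetical and only used for this check.

```python
def balanced_bounds(n_items: int, n_batches: int, idx: int) -> tuple[int, int]:
    """Start and exclusive end of zero-based batch idx under the divmod scheme."""
    k, m = divmod(n_items, n_batches)
    start = idx * k + min(idx, m)
    end = (idx + 1) * k + min(idx + 1, m)
    return start, end


def check_partition(n_items: int, n_batches: int) -> list[int]:
    """Assert that the batches tile range(n_items) and return their sizes."""
    sizes = []
    expected_start = 0
    for idx in range(n_batches):
        start, end = balanced_bounds(n_items, n_batches, idx)
        assert start == expected_start  # no gap and no overlap
        sizes.append(end - start)
        expected_start = end
    assert expected_start == n_items  # the last batch ends exactly at the end
    assert max(sizes) - min(sizes) <= 1  # sizes differ by at most one
    return sizes


sizes = check_partition(476, 100)
print(sizes.count(5), sizes.count(4))  # 76 batches of 5 and 24 batches of 4
```

Since m is always smaller than the number of batches, the end index of the last batch works out to exactly len(items). The "extend the last batch" branch therefore looks redundant with this scheme, although keeping it is harmless.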
Translation to Snakemake:
def get_batch(self, items: list):
    """Return the defined batch of the given items.
    Items are usually input files."""
    if len(items) < self.batches:
        raise WorkflowError(
            "Batching rule {} has less input files than batches. "
            "Please choose a smaller number of batches.".format(self.rulename)
        )
    # make sure that we always consider items in the same order
    items = sorted(items)
    k, m = divmod(len(items), self.batches)
    # self.batch is one-based, hence we have to subtract 1
    idx = self.idx - 1
    i = idx * k + min(idx, m)
    # batch_len is the exclusive end index of this batch
    batch_len = (idx + 1) * k + min(idx + 1, m)
    if self.is_final:
        # extend the last batch to cover rest of list
        return items[i:]
    else:
        return items[i : batch_len]
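For anyone who wants to try the proposed method outside of Snakemake, here is a rough standalone sketch. `BatchSketch` is a hypothetical stand-in, not Snakemake's real class. It assumes `is_final` means "idx equals the number of batches" and raises `ValueError` instead of `WorkflowError` so that it runs without Snakemake installed. The rule name and file names are made up.

```python
from dataclasses import dataclass


@dataclass
class BatchSketch:
    """Hypothetical stand-in for Snakemake's batch object, for illustration only."""

    rulename: str
    idx: int      # one-based batch index
    batches: int  # total number of batches

    @property
    def is_final(self) -> bool:
        # assumption: the final batch is the one whose index equals the batch count
        return self.idx == self.batches

    def get_batch(self, items: list) -> list:
        if len(items) < self.batches:
            # the proposed code raises WorkflowError; ValueError keeps this standalone
            raise ValueError(f"Rule {self.rulename} has fewer input files than batches.")
        items = sorted(items)
        k, m = divmod(len(items), self.batches)
        idx = self.idx - 1
        start = idx * k + min(idx, m)
        end = (idx + 1) * k + min(idx + 1, m)
        return items[start:] if self.is_final else items[start:end]


# made-up input files, zero-padded so that sorted order matches numeric order
files = [f"sample_{n:03d}.txt" for n in range(476)]
batches = [BatchSketch("aggregate", i, 100).get_batch(files) for i in range(1, 101)]
print(len(batches[0]), len(batches[-1]))
```

With 476 files and 100 batches, the first batch should contain 5 files and the last one 4, matching the procedural output above.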