Description
Snakemake version
Reproducible in 5.19.2
Describe the bug
When a rule depends on multiple checkpoints, i.e. its input function makes multiple checkpoints.x.get() calls, the checkpoints are evaluated sequentially. Initially Snakemake considers the rule to depend only on the first checkpoint, so the second checkpoint isn't queued until the DAG is re-evaluated after the first checkpoint completes.
Minimal example
from itertools import product

src_lang = config['src_lang']
trg_lang = config['trg_lang']

rule all:
    input: "data/corpus.txt"

checkpoint shard:
    output: "data/batches.{lang}"
    shell: '''
        for i in $(seq 0 $(( $RANDOM % 5 + 1))); do
            for j in $(seq 0 $(( $RANDOM % 2 + 1))); do
                mkdir -p data/{wildcards.lang}/$i/$j
                echo 'hello {wildcards.lang}' > data/{wildcards.lang}/$i/$j/text
            done
        done
        ls -d data/{wildcards.lang}/*/* > {output}
    '''

rule combine:
    input:
        l1='data/{src_lang}/{shard}/{src_batch}/text',
        l2='data/{trg_lang}/{shard}/{trg_batch}/text'
    output: 'data/{src_lang}_{trg_lang}/{shard}.{src_batch}_{trg_batch}.combined'
    shell: ''' paste {input.l1} {input.l2} > {output} '''

def get_batches_pairs(src_lang, trg_lang):
    src_batches = []
    trg_batches = []
    with checkpoints.shard.get(lang=src_lang).output[0].open() as src_f, \
         checkpoints.shard.get(lang=trg_lang).output[0].open() as trg_f:
        for line in src_f:
            src_batches.append(line.strip().split('/')[-2:])
        for line in trg_f:
            trg_batches.append(line.strip().split('/')[-2:])
    iterator = product(src_batches, trg_batches)
    return [
        (src_shard, (src_batch, trg_batch))
        for ((src_shard, src_batch), (trg_shard, trg_batch)) in iterator
        if src_shard == trg_shard
    ]

rule corpus:
    input:
        lambda wildcards: [
            f'data/{src_lang}_{trg_lang}/{shard}.{src_batch}_{trg_batch}.combined'
            for (shard, (src_batch, trg_batch)) in get_batches_pairs(src_lang, trg_lang)
        ]
    output: 'data/corpus.txt'
    shell: ''' cat {input} > {output} '''
When running this example, shard(lang=src_lang) is executed first; only when it finishes does Snakemake re-evaluate the DAG and queue up shard(lang=trg_lang).
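The ordering can be reproduced with a toy model in plain Python. This is only a sketch of the suspected mechanism, not Snakemake's actual code: it assumes that a get() call on an unfinished checkpoint aborts the input function by raising an exception, so later get() calls in the same function are never reached during that DAG pass. ToyCheckpoints, IncompleteCheckpoint, input_function and simulate are hypothetical names made up for this illustration.

```python
class IncompleteCheckpoint(Exception):
    """Toy stand-in for the error raised when a checkpoint has not run yet."""


class ToyCheckpoints:
    """Hypothetical, heavily simplified model of a checkpoint registry."""

    def __init__(self):
        self.finished = set()

    def get(self, lang):
        # Assumption: an unfinished checkpoint aborts the caller immediately.
        if lang not in self.finished:
            raise IncompleteCheckpoint(lang)
        return f"data/batches.{lang}"


def input_function(checkpoints, src_lang, trg_lang):
    src = checkpoints.get(src_lang)  # raises on the first pass
    trg = checkpoints.get(trg_lang)  # only reached once src has finished
    return [src, trg]


def simulate(src_lang, trg_lang):
    """Return the checkpoints discovered on each DAG pass, in order."""
    checkpoints = ToyCheckpoints()
    discovered = []
    while True:
        try:
            input_function(checkpoints, src_lang, trg_lang)
            return discovered
        except IncompleteCheckpoint as exc:
            missing = exc.args[0]
            discovered.append(missing)
            # "Run" the missing checkpoint, then re-evaluate the input function.
            checkpoints.finished.add(missing)
```

Under these assumptions, simulate('en', 'de') discovers one checkpoint per pass, which matches the sequential behaviour described above: the second checkpoint only becomes visible after the first one has completed.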
Additional context
Related to #16
A possible workaround would be to add all of the checkpoint outputs to the final target,
rule all:
    input: "data/corpus.txt", f'data/batches.{src_lang}', f'data/batches.{trg_lang}'
but that could get messy if there are a lot of outputs.
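One way to keep the workaround manageable might be to generate the extra targets from a list of languages instead of writing them out by hand. The helper below is a minimal sketch; checkpoint_targets and its pattern argument are hypothetical names, and the default pattern simply mirrors the output of the shard checkpoint in the example above.

```python
def checkpoint_targets(langs, pattern="data/batches.{lang}"):
    """Return one checkpoint output path per language, skipping duplicates."""
    targets = []
    for lang in langs:
        path = pattern.format(lang=lang)
        # src_lang and trg_lang may coincide; list each output only once.
        if path not in targets:
            targets.append(path)
    return targets
```

In the Snakefile, rule all could then presumably list "data/corpus.txt" followed by checkpoint_targets([src_lang, trg_lang]) as its input, so every checkpoint is requested up front without hard-coding each path.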