
Checkpoints that can be done in parallel are evaluated and executed sequentially #439

@zuny26

Description


Snakemake version
Reproducible in 5.19.2

Describe the bug
When a rule depends on multiple checkpoints, i.e. when its input function contains multiple checkpoints.x.get() calls, the checkpoints are evaluated and executed sequentially. Initially, Snakemake considers that the rule depends only on the first checkpoint, so the second checkpoint is not queued until the DAG re-evaluation that follows completion of the first checkpoint.
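A likely explanation for this behavior is that checkpoints.x.get() raises an exception as soon as it hits a checkpoint whose output does not yet exist, so the input function aborts at the first unfinished call and Snakemake only learns about one missing dependency per DAG re-evaluation. The following stand-alone sketch (with hypothetical class and function names, not the real Snakemake API) reproduces that discover-one-per-round cycle:

```python
# Hypothetical stand-in for snakemake's IncompleteCheckpointException.
class IncompleteCheckpoint(Exception):
    def __init__(self, name):
        super().__init__(name)
        self.name = name

class Checkpoints:
    """Toy checkpoint registry: .get() raises if the checkpoint hasn't run."""
    def __init__(self):
        self.completed = set()

    def get(self, name):
        if name not in self.completed:
            raise IncompleteCheckpoint(name)
        return name

def input_function(checkpoints):
    # Two dependencies, but the second .get() is never reached while the
    # first checkpoint is still incomplete.
    a = checkpoints.get("shard_src")
    b = checkpoints.get("shard_trg")
    return [a, b]

def discover_dependencies(checkpoints):
    """Re-run the input function after each discovered checkpoint,
    mimicking the DAG re-evaluation cycle described in the report."""
    rounds = []
    while True:
        try:
            return rounds, input_function(checkpoints)
        except IncompleteCheckpoint as e:
            rounds.append(e.name)              # one checkpoint per round
            checkpoints.completed.add(e.name)  # "execute" it

rounds, deps = discover_dependencies(Checkpoints())
# rounds == ["shard_src", "shard_trg"]: the two checkpoints are discovered
# (and hence executed) sequentially, one DAG re-evaluation each.
```

Under this model, parallel execution would require collecting all IncompleteCheckpoint-style failures in one evaluation pass instead of stopping at the first.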

Minimal example

from itertools import product

src_lang = config['src_lang']
trg_lang = config['trg_lang']

rule all:
    input: "data/corpus.txt"

checkpoint shard:
    output: "data/batches.{lang}"
    shell: '''
        for i in $(seq 0 $(( $RANDOM % 5 + 1))); do
            for j in $(seq 0 $(( $RANDOM % 2 + 1))); do
                mkdir -p data/{wildcards.lang}/$i/$j
                echo 'hello {wildcards.lang}' > data/{wildcards.lang}/$i/$j/text
            done
        done
        ls -d data/{wildcards.lang}/*/* > {output}
        '''

rule combine:
    input:
        l1='data/{src_lang}/{shard}/{src_batch}/text',
        l2='data/{trg_lang}/{shard}/{trg_batch}/text'
    output: 'data/{src_lang}_{trg_lang}/{shard}.{src_batch}_{trg_batch}.combined'
    shell: ''' paste {input.l1} {input.l2} > {output} '''

def get_batches_pairs(src_lang, trg_lang):
    src_batches = []
    trg_batches = []
    with checkpoints.shard.get(lang=src_lang).output[0].open() as src_f, \
            checkpoints.shard.get(lang=trg_lang).output[0].open() as trg_f:
        for line in src_f:
            src_batches.append(line.strip().split('/')[-2:])
        for line in trg_f:
            trg_batches.append(line.strip().split('/')[-2:])
    iterator = product(src_batches, trg_batches)
    return [(src_shard, (src_batch, trg_batch))
            for ((src_shard, src_batch), (trg_shard, trg_batch)) in iterator
            if src_shard == trg_shard]

rule corpus:
    input: lambda wildcards: [f'data/{src_lang}_{trg_lang}/{shard}.{src_batch}_{trg_batch}.combined'
                              for (shard, (src_batch, trg_batch)) in get_batches_pairs(src_lang, trg_lang)]
    output: 'data/corpus.txt'
    shell: ''' cat {input} > {output} '''
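For reference, the pairing logic in get_batches_pairs can be checked in isolation with plain Python, using hypothetical batch listings in place of the checkpoint output files:

```python
from itertools import product

# Hypothetical (shard, batch) pairs, as get_batches_pairs would parse them
# from the two batches.* files.
src_batches = [["0", "0"], ["0", "1"], ["1", "0"]]
trg_batches = [["0", "0"], ["2", "0"]]

# Same filter as in the Snakefile: keep only combinations whose shard
# matches on both the source and the target side.
pairs = [(src_shard, (src_batch, trg_batch))
         for (src_shard, src_batch), (trg_shard, trg_batch)
         in product(src_batches, trg_batches)
         if src_shard == trg_shard]
# pairs == [("0", ("0", "0")), ("0", ("1", "0"))]
```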

When running this example, first shard(lang=src_lang) is executed; only when it finishes does Snakemake re-evaluate the DAG and queue up shard(lang=trg_lang).

Additional context
Related to #16
A possible workaround is to add all of the checkpoint outputs to the final target:

rule all:
    input: "data/corpus.txt", f'data/batches.{src_lang}', f'data/batches.{trg_lang}'

but that could potentially get messy if there are a lot of outputs.
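If there are many languages, the extra targets could at least be generated instead of listed by hand; a minimal sketch, assuming the languages are available as a list (the langs variable here is hypothetical):

```python
# Hypothetical list of checkpoint wildcard values; in the report these
# would come from config['src_lang'] and config['trg_lang'].
langs = ["en", "de"]

# Build the extra rule-all targets that force both checkpoints to be
# scheduled up front.
extra_targets = [f"data/batches.{lang}" for lang in langs]
# extra_targets == ["data/batches.en", "data/batches.de"]
```

This keeps the workaround manageable, but it still pollutes the final target list with intermediate outputs.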
