8080from collections import defaultdict , namedtuple
8181from collections .abc import Mapping , MutableMapping
8282from heapq import heappop , heappush
83- from typing import Any , cast
83+ from typing import Any , Literal , cast , overload
8484
8585from dask .core import get_dependencies , get_deps , getcycle , istask , reverse_dict
8686from dask .typing import Key
8787
8888
89+ @overload
8990def order (
9091 dsk : MutableMapping [Key , Any ],
91- dependencies : MutableMapping [Key , set [Key ]] | None = None ,
92+ dependencies : MutableMapping [Key , set [Key ]] | None ,
93+ group : Literal [False ],
9294) -> dict [Key , int ]:
95+ ...
96+
97+
98+ @overload
99+ def order (
100+ dsk : MutableMapping [Key , Any ],
101+ dependencies : MutableMapping [Key , set [Key ]] | None ,
102+ group : Literal [True ],
103+ ) -> tuple [dict [Key , int ], dict [int , list [Key ]]]:
104+ ...
105+
106+
107+ def order (
108+ dsk : MutableMapping [Key , Any ],
109+ dependencies : MutableMapping [Key , set [Key ]] | None = None ,
110+ group : bool = False ,
111+ ) -> dict [Key , int ] | tuple [dict [Key , int ], dict [int , list [Key ]]]:
93112 if not dsk :
94113 return {}
95-
114+ groups = defaultdict (list )
115+ groups_by_key = dict ()
116+ group_ix = 0
96117 dsk = dict (dsk )
97118
98119 if dependencies is None :
@@ -126,9 +147,14 @@ def _f(*args: Any, **kwargs: Any) -> None:
126147
127148 dsk [root ] = (_f , * root_nodes )
128149 dependencies [root ] = root_nodes
129- o = order (dsk , dependencies )
130- del o [root ]
131- return o
150+ if not group :
151+ o = order (dsk , dependencies , group = group )
152+ del o [root ]
153+ return o
154+ else :
155+ o , g = order (dsk , dependencies , group = group )
156+ del o [root ]
157+ return o , g
132158 root = list (root_nodes )[0 ]
133159 init_stack : dict [Key , tuple ] | set [Key ] | list [Key ]
134160 # Leaf nodes. We choose one--the initial node--for each weakly connected subgraph.
@@ -243,6 +269,7 @@ def process_runnables(layers_loaded: int = 0) -> None:
243269 nonlocal i
244270 runnable_candidates = set_difference (set (runnable ), seen )
245271 runnable_sorted = sorted (runnable_candidates , key = pkey_getitem , reverse = True )
272+ prev = None
246273 while runnable_sorted :
247274 task = runnable_sorted .pop ()
248275 if task in runnable :
@@ -255,6 +282,11 @@ def process_runnables(layers_loaded: int = 0) -> None:
255282 next_nodes [pkey ].add (task )
256283 continue
257284 result [task ] = i
285+ # groups[group_ix].append(task)
286+ group_ix = groups_by_key [runnable .get (task , prev )]
287+ groups [group_ix ].append (task )
288+ groups_by_key [task ] = group_ix
289+ prev = task
258290 runnable .pop (task , None )
259291 i += 1
260292 deps = dependents [task ]
@@ -287,6 +319,8 @@ def process_runnables(layers_loaded: int = 0) -> None:
287319 layers_loaded += 1
288320 continue
289321 result [item ] = i
322+ groups [group_ix ].append (item )
323+ groups_by_key [item ] = group_ix
290324 runnable .pop (item , None )
291325 i += 1
292326 deps = dependents [item ]
@@ -318,6 +352,8 @@ def process_runnables(layers_loaded: int = 0) -> None:
318352 inner_stack = list (dep_pools [pkey ])
319353 inner_stack_pop = inner_stack .pop
320354 seen_update (inner_stack )
355+ if group_ix in groups and len (groups [group_ix ]) > 1 :
356+ group_ix += 1
321357 continue
322358 next_nodes [pkey ].update (dep_pools [pkey ])
323359 heappush (min_key_next_nodes , pkey )
@@ -336,6 +372,8 @@ def process_runnables(layers_loaded: int = 0) -> None:
336372 next_nodes .pop (min_key ), key = dependents_key , reverse = True
337373 )
338374 inner_stack_pop = inner_stack .pop
375+ if group_ix in groups and len (groups [group_ix ]) > 1 :
376+ group_ix += 1
339377 seen_update (inner_stack )
340378 continue
341379
@@ -344,7 +382,9 @@ def process_runnables(layers_loaded: int = 0) -> None:
344382
345383 if len (result ) == len (dsk ):
346384 break
347-
385+ # Increasing here is very conservative
386+ if group_ix in groups and len (groups [group_ix ]) > 1 :
387+ group_ix += 1
348388 if not is_init_sorted :
349389 init_stack = set (init_stack )
350390 init_stack = set_difference (init_stack , result )
@@ -356,8 +396,10 @@ def process_runnables(layers_loaded: int = 0) -> None:
356396
357397 inner_stack = [init_stack .pop ()] # type: ignore[call-overload]
358398 inner_stack_pop = inner_stack .pop
359-
360- return result
399+ if group :
400+ return result , dict (groups )
401+ else :
402+ return result
361403
362404
363405def graph_metrics (
0 commit comments