@@ -227,6 +227,7 @@ def dependencies_key(x: Key) -> tuple:
227227 max_heights ,
228228 ) in metrics .items ()
229229 }
230+ pkey_getitem = partition_keys .__getitem__
230231 result : dict [Key , int ] = {root : len (dsk ) - 1 }
231232 i = 0
232233
@@ -241,17 +242,15 @@ def process_runnables(layers_loaded: int = 0) -> None:
241242 runnable_candidates = set_difference (
242243 set_difference (set (runnable ), result ), seen
243244 )
244- runnable_sorted = sorted (
245- runnable_candidates , key = partition_keys .__getitem__ , reverse = True
246- )
245+ runnable_sorted = sorted (runnable_candidates , key = pkey_getitem , reverse = True )
247246 while runnable_sorted :
248247 task = runnable_sorted .pop ()
249248 if task in runnable :
250249 if (
251250 len (set_difference (dependents [runnable [task ]], result ))
252251 > 1 + layers_loaded
253252 ):
254- next_nodes [partition_keys [ task ] ].add (task )
253+ next_nodes [pkey_getitem ( task ) ].add (task )
255254 continue
256255 result [task ] = i
257256 i += 1
@@ -261,7 +260,7 @@ def process_runnables(layers_loaded: int = 0) -> None:
261260 if not num_needed [dep ]:
262261 runnable_sorted .append (dep )
263262 else :
264- next_nodes [partition_keys [ dep ] ].add (dep )
263+ next_nodes [pkey_getitem ( dep ) ].add (dep )
265264
266265 layers_loaded = 0
267266 dep_pools = defaultdict (set )
@@ -299,14 +298,14 @@ def process_runnables(layers_loaded: int = 0) -> None:
299298 for dep in deps :
300299 if dep in seen :
301300 continue
302- pkey = partition_keys [ dep ]
301+ pkey = pkey_getitem ( dep )
303302 dep_pools [pkey ].add (dep )
304303 all_keys .append (pkey )
305304 all_keys .sort ()
306305 target_key : tuple [int , ...] | None = None
307306 for pkey in reversed (all_keys ):
308307 if inner_stack :
309- target_key = target_key or partition_keys [ inner_stack [0 ]]
308+ target_key = target_key or pkey_getitem ( inner_stack [0 ])
310309 if pkey < target_key :
311310 next_nodes [target_key ].update (inner_stack )
312311 inner_stack = list (dep_pools [pkey ])
@@ -331,30 +330,22 @@ def process_runnables(layers_loaded: int = 0) -> None:
331330 seen_update (inner_stack )
332331 continue
333332
334- # This is just for perf reasons to cut the set differences down
335- for k in list (runnable ):
336- if k in result :
337- del runnable [k ]
338- seen = seen - set (result )
339-
340333 if inner_stack :
341334 continue
342335
343336 if len (result ) == len (dsk ):
344337 break
345338
346339 if not is_init_sorted :
347- # TODO: Is this even worth it?
348340 init_stack = set (init_stack )
349341 init_stack = set_difference (init_stack , result )
350342 if len (init_stack ) < 10000 :
351343 init_stack = sorted (init_stack , key = initial_stack_key , reverse = True )
352344 else :
353345 init_stack = list (init_stack )
354346 is_init_sorted = True
355- assert is_init_sorted
356- assert isinstance (init_stack , list )
357- inner_stack = [init_stack .pop ()]
347+
348+ inner_stack = [init_stack .pop ()] # type: ignore[call-overload]
358349 inner_stack_pop = inner_stack .pop
359350
360351 return result
0 commit comments