Skip to content

Commit dcd17f1

Browse files
committed
perf fixes
1 parent f7f1e42 commit dcd17f1

1 file changed

Lines changed: 8 additions & 17 deletions

File tree

dask/order.py

Lines changed: 8 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -227,6 +227,7 @@ def dependencies_key(x: Key) -> tuple:
227227
max_heights,
228228
) in metrics.items()
229229
}
230+
pkey_getitem = partition_keys.__getitem__
230231
result: dict[Key, int] = {root: len(dsk) - 1}
231232
i = 0
232233

@@ -241,17 +242,15 @@ def process_runnables(layers_loaded: int = 0) -> None:
241242
runnable_candidates = set_difference(
242243
set_difference(set(runnable), result), seen
243244
)
244-
runnable_sorted = sorted(
245-
runnable_candidates, key=partition_keys.__getitem__, reverse=True
246-
)
245+
runnable_sorted = sorted(runnable_candidates, key=pkey_getitem, reverse=True)
247246
while runnable_sorted:
248247
task = runnable_sorted.pop()
249248
if task in runnable:
250249
if (
251250
len(set_difference(dependents[runnable[task]], result))
252251
> 1 + layers_loaded
253252
):
254-
next_nodes[partition_keys[task]].add(task)
253+
next_nodes[pkey_getitem(task)].add(task)
255254
continue
256255
result[task] = i
257256
i += 1
@@ -261,7 +260,7 @@ def process_runnables(layers_loaded: int = 0) -> None:
261260
if not num_needed[dep]:
262261
runnable_sorted.append(dep)
263262
else:
264-
next_nodes[partition_keys[dep]].add(dep)
263+
next_nodes[pkey_getitem(dep)].add(dep)
265264

266265
layers_loaded = 0
267266
dep_pools = defaultdict(set)
@@ -299,14 +298,14 @@ def process_runnables(layers_loaded: int = 0) -> None:
299298
for dep in deps:
300299
if dep in seen:
301300
continue
302-
pkey = partition_keys[dep]
301+
pkey = pkey_getitem(dep)
303302
dep_pools[pkey].add(dep)
304303
all_keys.append(pkey)
305304
all_keys.sort()
306305
target_key: tuple[int, ...] | None = None
307306
for pkey in reversed(all_keys):
308307
if inner_stack:
309-
target_key = target_key or partition_keys[inner_stack[0]]
308+
target_key = target_key or pkey_getitem(inner_stack[0])
310309
if pkey < target_key:
311310
next_nodes[target_key].update(inner_stack)
312311
inner_stack = list(dep_pools[pkey])
@@ -331,30 +330,22 @@ def process_runnables(layers_loaded: int = 0) -> None:
331330
seen_update(inner_stack)
332331
continue
333332

334-
# This is just for perf reasons to cut the set differences down
335-
for k in list(runnable):
336-
if k in result:
337-
del runnable[k]
338-
seen = seen - set(result)
339-
340333
if inner_stack:
341334
continue
342335

343336
if len(result) == len(dsk):
344337
break
345338

346339
if not is_init_sorted:
347-
# TODO: Is this even worth it?
348340
init_stack = set(init_stack)
349341
init_stack = set_difference(init_stack, result)
350342
if len(init_stack) < 10000:
351343
init_stack = sorted(init_stack, key=initial_stack_key, reverse=True)
352344
else:
353345
init_stack = list(init_stack)
354346
is_init_sorted = True
355-
assert is_init_sorted
356-
assert isinstance(init_stack, list)
357-
inner_stack = [init_stack.pop()]
347+
348+
inner_stack = [init_stack.pop()] # type: ignore[call-overload]
358349
inner_stack_pop = inner_stack.pop
359350

360351
return result

0 commit comments

Comments
 (0)