From 27bb6f5ef1949a1d8f6346ef2217035d297d9986 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 11 Oct 2023 17:37:57 +0200 Subject: [PATCH 1/8] more fixes to dask.order --- dask/base.py | 9 +- dask/order.py | 111 ++++++++++++++---------- dask/tests/test_order.py | 177 ++++++++++++++++++++++++++++++--------- 3 files changed, 209 insertions(+), 88 deletions(-) diff --git a/dask/base.py b/dask/base.py index 044beb7ac6c..b2bdafe7b42 100644 --- a/dask/base.py +++ b/dask/base.py @@ -637,6 +637,7 @@ def visualize( optimize_graph=False, maxval=None, engine: Literal["cytoscape", "ipycytoscape", "graphviz"] | None = None, + o=None, **kwargs, ): """ @@ -718,9 +719,10 @@ def visualize( https://docs.dask.org/en/latest/optimize.html """ - args, _ = unpack_collections(*args, traverse=traverse) + dsk = args[0] + # args, _ = unpack_collections(*args, traverse=traverse) - dsk = dict(collections_to_dsk(args, optimize_graph=optimize_graph)) + # dsk = dict(collections_to_dsk(args, optimize_graph=optimize_graph)) color = kwargs.get("color") @@ -741,7 +743,8 @@ def visualize( from dask.order import diagnostics, order - o = order(dsk) + if o is None: + o = order(dsk) try: cmap = kwargs.pop("cmap") except KeyError: diff --git a/dask/order.py b/dask/order.py index 9a867866027..2fbbd46abf9 100644 --- a/dask/order.py +++ b/dask/order.py @@ -337,9 +337,7 @@ def dependencies_key(x: Key) -> tuple: # which is heavily dependent on the ordering obtained here. singles: dict[Key, Key] = {} singles_clear = singles.clear - later_singles: list[Key] = [] - later_singles_append = later_singles.append - later_singles_clear = later_singles.clear + later_singles = dict() # Priority of being processed # 1. inner_stack @@ -386,6 +384,7 @@ def dependencies_key(x: Key) -> tuple: continue process_singles = True else: + print(f"Set result from inner_stack {item=} {i=}") result[item] = i i += 1 deps = dependents[item] @@ -403,42 +402,14 @@ def dependencies_key(x: Key) -> tuple: continue elif singles: process_singles = True - elif later_singles: - # No need to be optimistic: all nodes in `later_singles` will free a dependency - # when run, so no need to check whether dependents are in `seen`. - for single in later_singles: - if single in result: - continue - while True: - deps_singles = dependents[single] - result[single] = i - i += 1 - if deps_singles: - for dep in deps_singles: - num_needed[dep] -= 1 - if len(deps_singles) == 1: - # Fast path! We trim down `dep2` above hoping to reach here. - (single,) = deps_singles - if not num_needed[single]: - # Keep it going! - deps_singles = dependents[single] - continue - deps |= deps_singles - del deps_singles - break - later_singles_clear() - deps = set_difference(deps, result) - if not deps: - continue - add_to_inner_stack = False - process_singles = True else: break if process_singles and singles: # We gather all dependents of all singles into `deps`, which we then process below. - add_to_inner_stack = True if inner_stack or inner_stacks else False + add_to_inner_stack = True + # if inner_stack or inner_stacks else False singles_keys = set_difference(set(singles), result) # NOTE: If this was too slow, LIFO would be a decent @@ -463,10 +434,12 @@ def dependencies_key(x: Key) -> tuple: ) > 1 ): - later_singles_append(single) + print(f"Skipping single {single=} {parent=}") + later_singles[single] = parent continue while True: deps_singles = dependents[single] + print(f"Set result from single {single=} {i=}") result[single] = i i += 1 if deps_singles: @@ -482,7 +455,6 @@ def dependencies_key(x: Key) -> tuple: if len(already_seen) == 1: (single,) = already_seen if not num_needed[single]: - deps_singles = dependents[single] continue break deps_singles = deps_singles - already_seen @@ -496,7 +468,10 @@ def dependencies_key(x: Key) -> tuple: # Keep it going! deps_singles = dependents[single] continue - later_singles_append(single) + print( + f"Setting later single after walking singles {single=} {parent=}" + ) + later_singles[single] = parent break deps |= deps_singles del deps_singles @@ -521,6 +496,7 @@ def dependencies_key(x: Key) -> tuple: if len(already_seen) == 1: (dep,) = already_seen if not num_needed[dep]: + print(f"[already seen] Set single {dep=} {item=}") singles[dep] = item del dep continue @@ -539,6 +515,7 @@ def dependencies_key(x: Key) -> tuple: if not num_needed[dep]: # We didn't put the single dependency on the stack, but we should still # run it soon, because doing so may free its parent. + print(f"[single dep] Set single {dep=} {item=}") singles[dep] = item else: next_nodes[key].append(deps) @@ -570,8 +547,10 @@ def dependencies_key(x: Key) -> tuple: seen_update(deps) if not num_needed[dep2]: if process_singles: - later_singles_append(dep2) + print(f"later_single key2< {dep2=} {item=}") + later_singles[dep2] = item else: + print(f"[key2<] Set single {dep2=} {item=}") singles[dep2] = item elif key < prev_key: inner_stacks_append(inner_stack) @@ -580,8 +559,10 @@ def dependencies_key(x: Key) -> tuple: seen_add(dep) if not num_needed[dep2]: if process_singles: - later_singles_append(dep2) + print(f"later_single key< {dep2=} {item=}") + later_singles[dep2] = item else: + print(f"[key<] Set single {dep2=} {item=}") singles[dep2] = item else: next_nodes[key2].append([dep2]) @@ -590,8 +571,10 @@ def dependencies_key(x: Key) -> tuple: for k, d in [(key, dep), (key2, dep2)]: if not num_needed[d]: if process_singles: - later_singles_append(d) + print(f"later_single else {d=} {item=}") + later_singles[d] = item else: + print(f"[else<] Set single {d=} {item=}") singles[d] = item else: next_nodes[k].append([d]) @@ -604,6 +587,7 @@ def dependencies_key(x: Key) -> tuple: inner_stack_pop = inner_stack.pop seen_add(dep) if not num_needed[dep2]: + print(f"[no stack] Set single {dep2=} {item=}") singles[dep2] = item elif key == key2 and 5 * partition_keys[item] > 22 * key: inner_stacks_append([dep2]) @@ -635,6 +619,7 @@ def dependencies_key(x: Key) -> tuple: else: psingles = possible_singles[key] for s in psingles: + print(f"[many singles] Set single {s=} {item=}") singles[s] = item vals -= psingles next_nodes[key].append(vals) @@ -664,8 +649,9 @@ def dependencies_key(x: Key) -> tuple: min_key = min(dep_pools) min_pool = dep_pools.pop(min_key) if len(min_pool) == 1: - inner_stack = list(min_pool) - seen_update(inner_stack) + new_stack = list(min_pool) + seen_update(new_stack) + inner_stacks_extend([new_stack]) elif ( 10 * item_key > 11 * len(min_pool) * len(min_pool) * min_key ): @@ -706,6 +692,9 @@ def dependencies_key(x: Key) -> tuple: if len(dependencies) == len(result): break # all done! + # TODO: Perf: Set differences will be cheaper and seen tells us what's + # on the inner stack. Should rename before enabling this. + # seen.clear() if next_nodes: for key in sorted(next_nodes, reverse=True): # `outer_stacks` may not be empty here--it has data from previous `next_nodes`. @@ -714,17 +703,45 @@ def dependencies_key(x: Key) -> tuple: outer_stack_extend(list(el) for el in reversed(next_nodes[key])) next_nodes.clear() - outer_deps = [] while outer_stack: # Try to add a few items to `inner_stacks` outer_deps = [x for x in outer_stack_pop() if x not in result] if outer_deps: if 1 < len(outer_deps) < 100: outer_deps.sort(key=dependents_key, reverse=True) - inner_stacks_extend([dep] for dep in outer_deps) - seen_update(outer_deps) + new_stack = [outer_deps.pop()] + inner_stacks_extend([new_stack]) + seen_update(new_stack) + outer_stack_extend([outer_deps]) + del new_stack break - del outer_deps + del outer_deps + + if later_singles: + later_singles_keys = set_difference(set(later_singles), result) + for single in sorted(later_singles_keys, key=lambda x: partition_keys[x]): + while True: + deps_singles = dependents[single] + print(f"Set result from later single {single=} {i=}") + result[single] = i + i += 1 + if deps_singles: + for dep in deps_singles: + num_needed[dep] -= 1 + if len(deps_singles) == 1: + # Fast path! We trim down `dep2` above hoping to reach here. + (single,) = deps_singles + if not num_needed[single]: + # Keep it going! + deps_singles = dependents[single] + continue + deps |= deps_singles + del deps_singles + break + later_singles.clear() + deps = set_difference(deps, result) + if not deps: + continue if inner_stacks: continue @@ -1068,8 +1085,10 @@ def _convert_task(task: Any) -> Any: elif isinstance(el, list): new_spec.append([_convert_task(e) for e in el]) return (_f, *new_spec) + elif isinstance(task, tuple): + return (_f, task) else: - return task + return (_f, *task) def sanitize_dsk(dsk: MutableMapping[Key, Any]) -> dict: diff --git a/dask/tests/test_order.py b/dask/tests/test_order.py index de43cc14a37..6290be5fe9f 100644 --- a/dask/tests/test_order.py +++ b/dask/tests/test_order.py @@ -3,13 +3,19 @@ import pytest import dask +from dask import delayed from dask.base import collections_to_dsk from dask.core import get_deps from dask.order import diagnostics, ndependencies, order from dask.utils_test import add, inc -@pytest.fixture(params=["abcde", "edcba"]) +@pytest.fixture( + params=[ + "abcde", + "edcba", + ] +) def abcde(request): return request.param @@ -743,23 +749,30 @@ def test_order_with_equal_dependents(abcde): (x, 6, i, 1): (f, (x, 5, i, 1)), } ) - o = order(dsk) - total = 0 - for x in abc: - for i in range(len(abc)): - val = o[(x, 6, i, 1)] - o[(x, 6, i, 0)] - assert val > 0 # ideally, val == 2 - total += val - assert total <= 56 # ideally, this should be 2 * 16 == 32 - pressure = diagnostics(dsk, o=o)[1] - assert max(pressure) <= max_pressure - + # o = order(dsk) + # total = 0 + # for x in abc: + # for i in range(len(abc)): + # val = o[(x, 6, i, 1)] - o[(x, 6, i, 0)] + # assert val > 0 # ideally, val == 2 + # total += val + from dask.base import visualize + + # # visualize(dsk, filename="test_order_with_equal_dependents-good", color='order') + # assert total <= 74 # ideally, this should be 2 * 16 == 32 + # pressure = diagnostics(dsk, o=o)[1] + # assert max(pressure) <= max_pressure # Add one to the end of the nine bundles dsk2 = dict(dsk) for x in abc: for i in range(len(abc)): dsk2[(x, 7, i, 0)] = (f, (x, 6, i, 0)) o = order(dsk2) + visualize( + dsk, + filename="test_order_with_equal_dependents", + # color='order' + ) total = 0 for x in abc: for i in range(len(abc)): @@ -770,33 +783,33 @@ def test_order_with_equal_dependents(abcde): pressure = diagnostics(dsk2, o=o)[1] assert max(pressure) <= max_pressure - # Remove one from each of the nine bundles - dsk3 = dict(dsk) - for x in abc: - for i in range(len(abc)): - del dsk3[(x, 6, i, 1)] - o = order(dsk3) - total = 0 - for x in abc: - for i in range(len(abc)): - val = o[(x, 5, i, 1)] - o[(x, 6, i, 0)] - assert val > 0 - total += val - assert total <= 45 # ideally, this should be 2 * 16 == 32 - pressure = diagnostics(dsk3, o=o)[1] - assert max(pressure) <= max_pressure - - # Remove another one from each of the nine bundles - dsk4 = dict(dsk3) - for x in abc: - for i in range(len(abc)): - del dsk4[(x, 6, i, 0)] - o = order(dsk4) - pressure = diagnostics(dsk4, o=o)[1] - assert max(pressure) <= max_pressure - for x in abc: - for i in range(len(abc)): - assert abs(o[(x, 5, i, 1)] - o[(x, 5, i, 0)]) <= 10 + # # Remove one from each of the nine bundles + # dsk3 = dict(dsk) + # for x in abc: + # for i in range(len(abc)): + # del dsk3[(x, 6, i, 1)] + # o = order(dsk3) + # total = 0 + # for x in abc: + # for i in range(len(abc)): + # val = o[(x, 5, i, 1)] - o[(x, 6, i, 0)] + # assert val > 0 + # total += val + # assert total <= 46 # ideally, this should be 2 * 16 == 32 + # pressure = diagnostics(dsk3, o=o)[1] + # assert max(pressure) <= max_pressure + + # # Remove another one from each of the nine bundles + # dsk4 = dict(dsk3) + # for x in abc: + # for i in range(len(abc)): + # del dsk4[(x, 6, i, 0)] + # o = order(dsk4) + # pressure = diagnostics(dsk4, o=o)[1] + # assert max(pressure) <= max_pressure + # for x in abc: + # for i in range(len(abc)): + # assert abs(o[(x, 5, i, 1)] - o[(x, 5, i, 0)]) <= 10 def test_terminal_node_backtrack(): @@ -872,7 +885,20 @@ def test_array_store_final_order(tmpdir): dest = root.empty_like(name="dest", data=x, chunks=x.chunksize, overwrite=True) d = x.store(dest, lock=False, compute=False) o = order(d.dask) + from dask.order import sanitize_dsk + visualize( + sanitize_dsk(collections_to_dsk([d])), + filename="test_array_store_final_order-order", + color="order", + o=o, + ) + visualize( + sanitize_dsk(collections_to_dsk([d])), + filename="test_array_store_final_order", + # color='order', + # o=o + ) # Find the lowest store. Dask starts here. stores = [k for k in o if isinstance(k, tuple) and k[0].startswith("store-map-")] first_store = min(stores, key=lambda k: o[k]) @@ -996,6 +1022,7 @@ def test_diagnostics(abcde): / |/ \|/ \|/ \|/ a0 b0 c0 d0 e0 """ + print("") a, b, c, d, e = abcde dsk = { (a, 0): (f,), @@ -1667,4 +1694,76 @@ def test_flox_reduction(): ("F2", 2): (f, "A0", ("EE", 1)), } o = order(dsk) + visualize( + dsk, + filename="test_flox_reduction.png", + optimize_graph=True, + ) + visualize( + dsk, + filename="test_flox_reduction-order.png", + optimize_graph=True, + color="order", + o=o, + ) assert max(o[("F1", ix)] for ix in range(3)) < min(o[("F2", ix)] for ix in range(3)) + + +import numpy as np + +import dask.array as da +from dask.base import key_split, visualize + + +def test_reduce_with_many_common_dependents(): + ndeps = 3 + + def random(**kwargs): + assert len(kwargs) == ndeps + return np.random.random((10, 10)) + + trivial_deps = { + f"k{i}": delayed(object(), name=f"object-{i}") for i in range(ndeps) + } + n_reducers = 4 + x = da.blockwise( + random, + "yx", + new_axes={"y": (10,) * n_reducers, "x": (10,) * n_reducers}, + dtype=float, + **trivial_deps, + ) + graph = x.sum(axis=1, split_every=20) + from dask.order import order + + dsk = collections_to_dsk([graph]) + dependencies, dependents = get_deps(dsk) + # Verify assumptions + o = order(dsk) + # Verify assumptions (specifically that the reducers are sum-aggregate) + assert {key_split(k) for k in o} == {"object", "sum", "sum-aggregate"} + + reducers = {k for k in o if key_split(k) == "sum-aggregate"} + drift = dict() + for r in reducers: + prios_deps = [] + for dep in dependencies[r]: + prios_deps.append(o[dep]) + drift[r] = (min(prios_deps), max(prios_deps)) + # assert max(prios_deps) - min(prios_deps) == len(dependencies[r]) - 1 + + print(f"{drift=}") + from dask.base import visualize + + visualize( + collections_to_dsk([graph]), + filename="test_decide_worker_coschedule_order_neighbors-color.png", + optimize_graph=True, + color="order", + o=o, + ) + visualize( + collections_to_dsk([graph]), + filename="test_decide_worker_coschedule_order_neighbors.png", + optimize_graph=True, + ) From bbfa1d8c3804d37cd8f9c4df92e2adc2dcfbea42 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 12 Oct 2023 14:06:33 +0200 Subject: [PATCH 2/8] Rewrite dask.order --- dask/order.py | 638 +++++++-------------------------------- dask/tests/test_order.py | 140 +++------ 2 files changed, 151 insertions(+), 627 deletions(-) diff --git a/dask/order.py b/dask/order.py index 2fbbd46abf9..3047c77e05b 100644 --- a/dask/order.py +++ b/dask/order.py @@ -79,7 +79,6 @@ """ from collections import defaultdict, namedtuple from collections.abc import Mapping, MutableMapping -from math import log from typing import Any, cast from dask.core import get_dependencies, get_deps, getcycle, istask, reverse_dict @@ -90,30 +89,6 @@ def order( dsk: MutableMapping[Key, Any], dependencies: MutableMapping[Key, set[Key]] | None = None, ) -> dict[Key, int]: - """Order nodes in dask graph - - This produces an ordering over our tasks that we use to break ties when - executing. We do this ahead of time to reduce a bit of stress on the - scheduler and also to assist in static analysis. - - This currently traverses the graph as a single-threaded scheduler would - traverse it. It breaks ties in the following ways: - - 1. Begin at a leaf node that is a dependency of a root node that has the - largest subgraph (start hard things first) - 2. Prefer tall branches with few dependents (start hard things first and - try to avoid memory usage) - 3. Prefer dependents that are dependencies of root nodes that have - the smallest subgraph (do small goals that can terminate quickly) - - Examples - -------- - >>> inc = lambda x: x + 1 - >>> add = lambda x, y: x + y - >>> dsk = {'a': 1, 'b': 2, 'c': (inc, 'a'), 'd': (add, 'b', 'c')} - >>> order(dsk) - {'a': 0, 'c': 1, 'b': 2, 'd': 3} - """ if not dsk: return {} @@ -153,7 +128,7 @@ def _f(*args: Any, **kwargs: Any) -> None: o = order(dsk, dependencies) del o[root] return o - + root = list(root_nodes)[0] init_stack: dict[Key, tuple] | set[Key] | list[Key] # Leaf nodes. We choose one--the initial node--for each weakly connected subgraph. # Let's calculate the `initial_stack_key` as we determine `init_stack` set. @@ -182,6 +157,7 @@ def _f(*args: Any, **kwargs: Any) -> None: if not val ) } + is_init_sorted = False # `initial_stack_key` chooses which task to run at the very beginning. # This value is static, so we pre-compute as the value of this dict. initial_stack_key = init_stack.__getitem__ @@ -233,548 +209,154 @@ def dependencies_key(x: Key) -> tuple: StrComparable(x), ) + seen = set(root_nodes) + seen_update = seen.update root_total_dependencies = total_dependencies[list(root_nodes)[0]] # Computing this for all keys can sometimes be relatively expensive :( partition_keys = { key: ( - (root_total_dependencies - total_dependencies[key] + 1) - * (total_dependents - min_heights) + (root_total_dependencies - total_dependencies[key] + 1), + (total_dependents - min_heights), + -max_heights, ) for key, ( total_dependents, _, _, min_heights, - _, + max_heights, ) in metrics.items() } - - result: dict[Key, int] = {} + result: dict[Key, int] = {root: len(dsk) - 1} i = 0 - # `inner_stack` is used to perform a DFS along dependencies. Once emptied - # (when traversing dependencies), this continue down a path along dependents - # until a root node is reached. - # - # Sometimes, a better path along a dependent is discovered (i.e., something - # that is easier to compute and doesn't requiring holding too much in memory). - # In this case, the current `inner_stack` is appended to `inner_stacks` and - # we begin a new DFS from the better node. - # - # A "better path" is determined by comparing `partition_keys`. inner_stack = [min(init_stack, key=initial_stack_key)] inner_stack_pop = inner_stack.pop - inner_stacks: list[list[Key]] = [] - inner_stacks_append = inner_stacks.append - inner_stacks_extend = inner_stacks.extend - inner_stacks_pop = inner_stacks.pop - - # Okay, now we get to the data structures used for fancy behavior. - # - # As we traverse nodes in the DFS along dependencies, we partition the dependents - # via `partition_key`. A dependent goes to: - # 1) `inner_stack` if it's better than our current target, - # 2) `next_nodes` if the partition key is lower than it's parent, - # When the inner stacks are depleted, we process `next_nodes`. - # These dicts use `partition_keys` as keys. We process them by placing the values - # in `outer_stack` so that the smallest keys will be processed first. - next_nodes: defaultdict[int, list[list[Key] | set[Key]]] = defaultdict(list) - - # `outer_stack` is used to populate `inner_stacks`. From the time we partition the - # dependents of a node, we group them: one list per partition key per parent node. - # This likely results in many small lists. We do this to avoid sorting many larger - # lists (i.e., to avoid n*log(n) behavior). So, we have many small lists that we - # partitioned, and we keep them in the order that we saw them (we will process them - # in a FIFO manner). By delaying sorting for as long as we can, we can first filter - # out nodes that have already been computed. All this complexity is worth it! - outer_stack: list[list[Key]] = [] - outer_stack_extend = outer_stack.extend - outer_stack_pop = outer_stack.pop - - # Keep track of nodes that are in `inner_stack` or `inner_stacks` so we don't - # process them again. - seen = set(root_nodes) - seen_update = seen.update - seen_add = seen.add - - # "singles" are tasks that are available to run, and when run may free a dependency. - # Although running a task to free a dependency may seem like a wash (net zero), it - # can be beneficial by providing more opportunities for a later task to free even - # more data. So, it costs us little in the short term to more eagerly compute - # chains of tasks that keep the same number of data in memory, and the longer term - # rewards are potentially high. I would expect a dynamic scheduler to have similar - # behavior, so I think it makes sense to do the same thing here in `dask.order`. - # - # When we gather tasks in `singles`, we do so optimistically: running the task *may* - # free the parent, but it also may not, because other dependents of the parent may - # be in the inner stacks. When we process `singles`, we run tasks that *will* free - # the parent, otherwise we move the task to `later_singles`. `later_singles` is run - # when there are no inner stacks, so it is safe to run all of them (because no other - # dependents will be hiding in the inner stacks to keep hold of the parent). - # `singles` is processed when the current item on the stack needs to compute - # dependencies before it can be run. - # - # Processing singles is meant to be a detour. Doing so should not change our - # tactical goal in most cases. Hence, we set `add_to_inner_stack = False`. - # - # In reality, this is a pretty limited strategy for running a task to free a - # dependency. A thorough strategy would be to check whether running a dependent - # with `num_needed[dep] == 0` would free *any* of its dependencies. This isn't - # what we do. This data isn't readily or cheaply available. We only check whether - # it will free its last dependency that was computed (the current `item`). This is - # probably okay. In general, our tactics and strategies for ordering try to be - # memory efficient, so we shouldn't try too hard to work around what we already do. - # However, sometimes the DFS nature of it leaves "easy-to-compute" stragglers behind. - # The current approach is very fast to compute, can be beneficial, and is generally - # low-risk. There could be more we could do here though. Our static scheduling - # here primarily looks at "what dependent should we run next?" instead of "what - # dependency should we try to free?" Two sides to the same question, but a dynamic - # scheduler is much better able to answer the latter one, because it knows the size - # of data and can react to current state. Does adding a little more dynamic-like - # behavior to `dask.order` add any tension to running with an actual dynamic - # scheduler? Should we defer to dynamic schedulers and let them behave like this - # if they so choose? Maybe. However, I'm sensitive to the multithreaded scheduler, - # which is heavily dependent on the ordering obtained here. - singles: dict[Key, Key] = {} - singles_clear = singles.clear - later_singles = dict() - - # Priority of being processed - # 1. inner_stack - # 2. singles (may be moved to later_singles) - # 3. inner_stacks - # 4. later_singles - # 5. next_nodes - # 6. outer_stack - # 7. init_stack - - # alias for speed + next_nodes: defaultdict[tuple[int, ...], set[Key]] = defaultdict(set) + runnable: dict[Key, Key] = dict() set_difference = set.difference - is_init_sorted = False + def process_runnables(layers_loaded: int = 0) -> None: + nonlocal i + runnable_candidates = set_difference( + set_difference(set(runnable), result), seen + ) + runnable_sorted = sorted( + runnable_candidates, key=partition_keys.__getitem__, reverse=True + ) + while runnable_sorted: + task = runnable_sorted.pop() - while True: - while True: - # Perform a DFS along dependencies until we complete our tactical goal - deps = set() - add_to_inner_stack = True - if inner_stack: - item = inner_stack_pop() - if item in result: + if task in runnable: + if ( + len(set_difference(dependents[runnable[task]], result)) + > 1 + layers_loaded + ): + next_nodes[partition_keys[task]].add(task) continue - if num_needed[item]: - if item not in root_nodes: - inner_stack.append(item) - deps = set_difference(dependencies[item], result) - if 1 < len(deps) < 1000: - inner_stack.extend( - sorted(deps, key=dependencies_key, reverse=True) - ) - else: - inner_stack.extend(deps) - seen_update(deps) - if not singles: - continue - # Only process singles once the inner_stack is fully - # resolved. This is important because the singles path later - # on verifies that running the single indeed opens an - # opportunity to release soon by comparing the singles - # parent's dependents with the inner_stack(s) - if inner_stack and num_needed[inner_stack[-1]]: - continue - process_singles = True + print(f"Set result from runnable {task=} {i=}") + result[task] = i + i += 1 + deps = dependents[task] + for dep in deps: + num_needed[dep] -= 1 + if not num_needed[dep]: + runnable_sorted.append(dep) else: - print(f"Set result from inner_stack {item=} {i=}") - result[item] = i - i += 1 - deps = dependents[item] - add_to_inner_stack = True - - if deps: - for dep in deps: - num_needed[dep] -= 1 - process_singles = False - else: - continue - elif inner_stacks: - inner_stack = inner_stacks_pop() - inner_stack_pop = inner_stack.pop + next_nodes[partition_keys[dep]].add(dep) + # runnable.clear() + + dep_pools = defaultdict(set) + layers_loaded = 0 + while True: + while inner_stack: + item = inner_stack_pop() + if item in result: continue - elif singles: - process_singles = True - else: - break - - if process_singles and singles: - # We gather all dependents of all singles into `deps`, which we then process below. - - add_to_inner_stack = True - # if inner_stack or inner_stacks else False - singles_keys = set_difference(set(singles), result) - - # NOTE: If this was too slow, LIFO would be a decent - # approximation - for single in sorted(singles_keys, key=lambda x: partition_keys[x]): - # We want to run the singles if they are either releasing a - # dependency directly or that they may be releasing a - # dependency once the current critical path / inner_stack is - # walked. - # By using `seen` here this is more permissive since it also - # includes tasks in a future critical path / inner_stacks - # but it would require additional state to make this - # distinction and we don't have enough data to dermine if - # this is worth it. - parent = singles[single] - if ( - len( - set_difference( - set_difference(dependents[parent], result), - seen, - ) - ) - > 1 - ): - print(f"Skipping single {single=} {parent=}") - later_singles[single] = parent - continue - while True: - deps_singles = dependents[single] - print(f"Set result from single {single=} {i=}") - result[single] = i - i += 1 - if deps_singles: - for dep in deps_singles: - num_needed[dep] -= 1 - if add_to_inner_stack: - already_seen = deps_singles & seen - if already_seen: - # This means that the singles path also - # leads to the current or previous strategic - # path - if len(deps_singles) == len(already_seen): - if len(already_seen) == 1: - (single,) = already_seen - if not num_needed[single]: - continue - break - deps_singles = deps_singles - already_seen - else: - already_seen = set() - if len(deps_singles) == 1: - # Fast path! We trim down `dep2` above hoping to reach here. - (single,) = deps_singles - if not num_needed[single]: - if not already_seen: - # Keep it going! - deps_singles = dependents[single] - continue - print( - f"Setting later single after walking singles {single=} {parent=}" - ) - later_singles[single] = parent - break - deps |= deps_singles - del deps_singles - break - del singles_keys - deps = set_difference(deps, result) - singles_clear() - if not deps: - continue - add_to_inner_stack = False - - # If inner_stack is empty, then we typically add the best dependent to it. - # However, we don't add to it if a dependent is already on an inner_stack. In this case, we add the - # dependents (not in an inner_stack) to next_nodes or later_nodes to handle later. - # This serves three purposes: - # 1. shrink `deps` so that it can be processed faster, - # 2. make sure we don't process the same dependency repeatedly, and - # 3. make sure we don't accidentally continue down an expensive-to-compute path. - already_seen = deps & seen - if already_seen: - if len(deps) == len(already_seen): - if len(already_seen) == 1: - (dep,) = already_seen - if not num_needed[dep]: - print(f"[already seen] Set single {dep=} {item=}") - singles[dep] = item - del dep - continue - add_to_inner_stack = False - deps = deps - already_seen - - if len(deps) == 1: - # Fast path! We trim down `deps` above hoping to reach here. - (dep,) = deps - if add_to_inner_stack and not inner_stack: - inner_stack = [dep] - inner_stack_pop = inner_stack.pop - seen_add(dep) - continue - key = partition_keys[dep] - if not num_needed[dep]: - # We didn't put the single dependency on the stack, but we should still - # run it soon, because doing so may free its parent. - print(f"[single dep] Set single {dep=} {item=}") - singles[dep] = item + if num_needed[item]: + inner_stack.append(item) + deps = set_difference(dependencies[item], result) + if 1 < len(deps) < 1000: + inner_stack.extend(sorted(deps, key=dependencies_key, reverse=True)) else: - next_nodes[key].append(deps) - del dep, key - elif len(deps) == 2: - # We special-case when len(deps) == 2 so that we may place a dep on singles. - # Otherwise, the logic here is the same as when `len(deps) > 2` below. - # - # Let me explain why this is a special case. If we put the better dependent - # onto the inner stack, then it's guaranteed to run next. After it's run, - # then running the other dependent *may* allow their parent to be freed. - dep, dep2 = deps - key = partition_keys[dep] - key2 = partition_keys[dep2] + inner_stack.extend(deps) + seen_update(deps) + if not num_needed[inner_stack[-1]]: + process_runnables(layers_loaded) + layers_loaded += 1 + continue + layers_loaded = 0 + print(f"Set result from inner_stack {item=} {i=}") + result[item] = i + i += 1 + deps = dependents[item] + for dep in deps: + num_needed[dep] -= 1 if ( - key2 < key - or key == key2 - and dependents_key(dep2) < dependents_key(dep) + not num_needed[dep] + # optimization. We skip this anyhow below + and dep not in seen ): - dep, dep2 = dep2, dep - key, key2 = key2, key + runnable[dep] = item + + # Heap? + all_keys = [] + for dep in deps: + pkey = partition_keys[dep] + dep_pools[pkey].add(dep) + all_keys.append(pkey) + all_keys.sort() + target_key: tuple[int, ...] | None = None + for pkey in reversed(all_keys): if inner_stack: - prev_key = partition_keys[inner_stack[0]] - if key2 < prev_key: - inner_stacks_append(inner_stack) - inner_stacks_append([dep2]) - inner_stack = [dep] + target_key = target_key or partition_keys[inner_stack[0]] + if pkey < target_key: + next_nodes[target_key].update(inner_stack) + inner_stack = list(dep_pools[pkey]) inner_stack_pop = inner_stack.pop - seen_update(deps) - if not num_needed[dep2]: - if process_singles: - print(f"later_single key2< {dep2=} {item=}") - later_singles[dep2] = item - else: - print(f"[key2<] Set single {dep2=} {item=}") - singles[dep2] = item - elif key < prev_key: - inner_stacks_append(inner_stack) - inner_stack = [dep] - inner_stack_pop = inner_stack.pop - seen_add(dep) - if not num_needed[dep2]: - if process_singles: - print(f"later_single key< {dep2=} {item=}") - later_singles[dep2] = item - else: - print(f"[key<] Set single {dep2=} {item=}") - singles[dep2] = item - else: - next_nodes[key2].append([dep2]) - else: - item_key = partition_keys[item] - for k, d in [(key, dep), (key2, dep2)]: - if not num_needed[d]: - if process_singles: - print(f"later_single else {d=} {item=}") - later_singles[d] = item - else: - print(f"[else<] Set single {d=} {item=}") - singles[d] = item - else: - next_nodes[k].append([d]) - del item_key - del prev_key - else: - assert not inner_stack - if add_to_inner_stack: - inner_stack = [dep] - inner_stack_pop = inner_stack.pop - seen_add(dep) - if not num_needed[dep2]: - print(f"[no stack] Set single {dep2=} {item=}") - singles[dep2] = item - elif key == key2 and 5 * partition_keys[item] > 22 * key: - inner_stacks_append([dep2]) - seen_add(dep2) - else: - next_nodes[key2].append([dep2]) + seen_update(inner_stack) else: - for k, d in [(key, dep), (key2, dep2)]: - next_nodes[k].append([d]) - del dep, dep2, key, key2 - else: - # Slow path :(. This requires grouping by partition_key. - dep_pools = defaultdict(set) - possible_singles = defaultdict(set) - for dep in deps: - pkey = partition_keys[dep] - if not num_needed[dep] and not process_singles: - possible_singles[pkey].add(dep) - dep_pools[pkey].add(dep) - item_key = partition_keys[item] - if inner_stack: - # If we have an inner_stack, we need to look for a "better" path - prev_key = partition_keys[inner_stack[0]] - now_keys = [] # < inner_stack[0] - psingles = set() - for key, vals in dep_pools.items(): - if key < prev_key: - now_keys.append(key) - else: - psingles = possible_singles[key] - for s in psingles: - print(f"[many singles] Set single {s=} {item=}") - singles[s] = item - vals -= psingles - next_nodes[key].append(vals) - del vals, key - del psingles - if now_keys: - # Run before `inner_stack` (change tactical goal!) - inner_stacks_append(inner_stack) - if 1 < len(now_keys): - now_keys.sort(reverse=True) - for key in now_keys: - pool: set[Key] | list[Key] - pool = dep_pools[key] - if 1 < len(pool) < 100: - pool = sorted(pool, key=dependents_key, reverse=True) - inner_stacks_extend([dep] for dep in pool) - seen_update(pool) - del pool - inner_stack = inner_stacks_pop() - inner_stack_pop = inner_stack.pop - del now_keys, prev_key + pass else: - # If we don't have an inner_stack, then we don't need to look - # for a "better" path, but we do need traverse along dependents. - if add_to_inner_stack: - min_pool: list[Key] | set[Key] - min_key = min(dep_pools) - min_pool = dep_pools.pop(min_key) - if len(min_pool) == 1: - new_stack = list(min_pool) - seen_update(new_stack) - inner_stacks_extend([new_stack]) - elif ( - 10 * item_key > 11 * len(min_pool) * len(min_pool) * min_key - ): - # Put all items in min_pool onto inner_stacks. - # I know this is a weird comparison. Hear me out. - # Although it is often beneficial to put all of the items in `min_pool` - # onto `inner_stacks` to process next, it is very easy to be overzealous. - # Sometimes it is actually better to defer until `next_nodes` is handled. - # We should only put items onto `inner_stacks` that we're reasonably - # confident about. The above formula is a best effort heuristic given - # what we have easily available. It is obviously very specific to our - # choice of partition_key. Dask tests take this route about 40%. - if len(min_pool) < 100: - min_pool = sorted( - min_pool, key=dependents_key, reverse=True - ) - inner_stacks_extend([dep] for dep in min_pool) - inner_stack = inner_stacks_pop() - seen_update(min_pool) - else: - # Put one item in min_pool onto inner_stack and the rest into next_nodes. - if len(min_pool) < 100: - inner_stack = [min(min_pool, key=dependents_key)] - else: - inner_stack = [min_pool.pop()] - next_nodes[min_key].append(min_pool) - seen_update(inner_stack) - del min_pool, min_key - inner_stack_pop = inner_stack.pop - for key, vals in dep_pools.items(): - psingles = possible_singles[key] - for s in psingles: - singles[s] = item - vals -= psingles - next_nodes[key].append(vals) - del key, vals - - if len(dependencies) == len(result): - break # all done! - - # TODO: Perf: Set differences will be cheaper and seen tells us what's - # on the inner stack. Should rename before enabling this. - # seen.clear() - if next_nodes: - for key in sorted(next_nodes, reverse=True): - # `outer_stacks` may not be empty here--it has data from previous `next_nodes`. - # Since we pop things off of it (onto `inner_nodes`), this means we handle - # multiple `next_nodes` in a LIFO manner. - outer_stack_extend(list(el) for el in reversed(next_nodes[key])) - next_nodes.clear() - - while outer_stack: - # Try to add a few items to `inner_stacks` - outer_deps = [x for x in outer_stack_pop() if x not in result] - if outer_deps: - if 1 < len(outer_deps) < 100: - outer_deps.sort(key=dependents_key, reverse=True) - new_stack = [outer_deps.pop()] - inner_stacks_extend([new_stack]) - seen_update(new_stack) - outer_stack_extend([outer_deps]) - del new_stack - break - del outer_deps - - if later_singles: - later_singles_keys = set_difference(set(later_singles), result) - for single in sorted(later_singles_keys, key=lambda x: partition_keys[x]): - while True: - deps_singles = dependents[single] - print(f"Set result from later single {single=} {i=}") - result[single] = i - i += 1 - if deps_singles: - for dep in deps_singles: - num_needed[dep] -= 1 - if len(deps_singles) == 1: - # Fast path! We trim down `dep2` above hoping to reach here. - (single,) = deps_singles - if not num_needed[single]: - # Keep it going! - deps_singles = dependents[single] - continue - deps |= deps_singles - del deps_singles - break - later_singles.clear() - deps = set_difference(deps, result) - if not deps: - continue + next_nodes[pkey].update(dep_pools[pkey]) + dep_pools.clear() + + process_runnables() - if inner_stacks: + if next_nodes and not inner_stack: + min_key = min(next_nodes) + inner_stack = sorted( + next_nodes.pop(min_key), key=dependents_key, reverse=True + ) + inner_stack_pop = inner_stack.pop + seen_update(inner_stack) continue - # We just finished computing a connected group. - # Let's choose the first `item` in the next group to compute. - # If we have few large groups left, then it's best to find `item` by taking a minimum. - # If we have many small groups left, then it's best to sort. - # If we have many tiny groups left, then it's best to simply iterate. + # This is just for perf reasons to cut the set differences down + for k in list(runnable): + if k in result: + del runnable[k] + seen = seen - set(result) + + if inner_stack: + continue + + if len(result) == len(dsk): + break + if not is_init_sorted: - prev_len = len(init_stack) + # TODO: Is this even worth it? init_stack = set(init_stack) init_stack = set_difference(init_stack, result) - N = len(init_stack) - m = prev_len - N - # is `min` likely better than `sort`? - if m >= N or N + (N - m) * log(N - m) < N * log(N): - item = min(init_stack, key=initial_stack_key) - continue - if len(init_stack) < 10000: init_stack = sorted(init_stack, key=initial_stack_key, reverse=True) else: init_stack = list(init_stack) - init_stack_pop = init_stack.pop is_init_sorted = True - - if item in root_nodes: - item = init_stack_pop() - - while item in result: - item = init_stack_pop() - inner_stack.append(item) + assert is_init_sorted + assert isinstance(init_stack, list) + inner_stack = [init_stack.pop()] + inner_stack_pop = inner_stack.pop return result diff --git a/dask/tests/test_order.py b/dask/tests/test_order.py index 6290be5fe9f..67ab8ca157c 100644 --- a/dask/tests/test_order.py +++ b/dask/tests/test_order.py @@ -1,10 +1,12 @@ from __future__ import annotations +import numpy as np import pytest import dask +import dask.array as da from dask import delayed -from dask.base import collections_to_dsk +from dask.base import collections_to_dsk, key_split from dask.core import get_deps from dask.order import diagnostics, ndependencies, order from dask.utils_test import add, inc @@ -725,7 +727,7 @@ def test_order_with_equal_dependents(abcde): # Lower pressure is better but this is where we are right now. Important is # that no variation below should be worse since all variations below should # reduce to the same graph when optimized/fused. - max_pressure = 11 + max_pressure = 10 a, b, c, d, e = abcde dsk = {} abc = [a, b, c, d] @@ -749,67 +751,57 @@ def test_order_with_equal_dependents(abcde): (x, 6, i, 1): (f, (x, 5, i, 1)), } ) - # o = order(dsk) - # total = 0 - # for x in abc: - # for i in range(len(abc)): - # val = o[(x, 6, i, 1)] - o[(x, 6, i, 0)] - # assert val > 0 # ideally, val == 2 - # total += val - from dask.base import visualize - - # # visualize(dsk, filename="test_order_with_equal_dependents-good", color='order') - # assert total <= 74 # ideally, this should be 2 * 16 == 32 - # pressure = diagnostics(dsk, o=o)[1] - # assert max(pressure) <= max_pressure + o = order(dsk) + total = 0 + for x in abc: + for i in range(len(abc)): + val = abs(o[(x, 6, i, 1)] - o[(x, 6, i, 0)]) + total += val + + assert total <= 32 # ideally, this should be 2 * 16 == 32 + pressure = diagnostics(dsk, o=o)[1] + assert max(pressure) <= max_pressure # Add one to the end of the nine bundles dsk2 = dict(dsk) for x in abc: for i in range(len(abc)): dsk2[(x, 7, i, 0)] = (f, (x, 6, i, 0)) o = order(dsk2) - visualize( - dsk, - filename="test_order_with_equal_dependents", - # color='order' - ) total = 0 for x in abc: for i in range(len(abc)): - val = o[(x, 6, i, 1)] - o[(x, 7, i, 0)] - assert val > 0 # ideally, val == 3 + val = abs(o[(x, 6, i, 1)] - o[(x, 7, i, 0)]) total += val - assert total <= 75 # ideally, this should be 3 * 16 == 48 + assert total <= 48 # ideally, this should be 3 * 16 == 48 pressure = diagnostics(dsk2, o=o)[1] assert max(pressure) <= max_pressure - # # Remove one from each of the nine bundles - # dsk3 = dict(dsk) - # for x in abc: - # for i in range(len(abc)): - # del dsk3[(x, 6, i, 1)] - # o = order(dsk3) - # total = 0 - # for x in abc: - # for i in range(len(abc)): - # val = o[(x, 5, i, 1)] - o[(x, 6, i, 0)] - # assert val > 0 - # total += val - # assert total <= 46 # ideally, this should be 2 * 16 == 32 - # pressure = diagnostics(dsk3, o=o)[1] - # assert max(pressure) <= max_pressure + # Remove one from each of the nine bundles + dsk3 = dict(dsk) + for x in abc: + for i in range(len(abc)): + del dsk3[(x, 6, i, 1)] + o = order(dsk3) + total = 0 + for x in abc: + for i in range(len(abc)): + val = abs(o[(x, 5, i, 1)] - o[(x, 6, i, 0)]) + total += val + assert total <= 32 # ideally, this should be 2 * 16 == 32 + pressure = diagnostics(dsk3, o=o)[1] + assert max(pressure) <= max_pressure # # Remove another one from each of the nine bundles - # dsk4 = dict(dsk3) - # for x in abc: - # for i in range(len(abc)): - # del dsk4[(x, 6, i, 0)] - # o = order(dsk4) - # pressure = diagnostics(dsk4, o=o)[1] - # assert max(pressure) <= max_pressure - # for x in abc: - # for i in range(len(abc)): - # assert abs(o[(x, 5, i, 1)] - o[(x, 5, i, 0)]) <= 10 + dsk4 = dict(dsk3) + for x in abc: + for i in range(len(abc)): + del dsk4[(x, 6, i, 0)] + o = order(dsk4) + pressure = diagnostics(dsk4, o=o)[1] + assert max(pressure) <= max_pressure + for x in abc: + for i in range(len(abc)): + assert abs(o[(x, 5, i, 1)] - o[(x, 5, i, 0)]) <= 2 def test_terminal_node_backtrack(): @@ -885,20 +877,6 @@ def test_array_store_final_order(tmpdir): dest = root.empty_like(name="dest", data=x, chunks=x.chunksize, overwrite=True) d = x.store(dest, lock=False, compute=False) o = order(d.dask) - from dask.order import sanitize_dsk - - visualize( - sanitize_dsk(collections_to_dsk([d])), - filename="test_array_store_final_order-order", - color="order", - o=o, - ) - visualize( - sanitize_dsk(collections_to_dsk([d])), - filename="test_array_store_final_order", - # color='order', - # o=o - ) # Find the lowest store. Dask starts here. stores = [k for k in o if isinstance(k, tuple) and k[0].startswith("store-map-")] first_store = min(stores, key=lambda k: o[k]) @@ -1022,7 +1000,6 @@ def test_diagnostics(abcde): / |/ \|/ \|/ \|/ a0 b0 c0 d0 e0 """ - print("") a, b, c, d, e = abcde dsk = { (a, 0): (f,), @@ -1241,7 +1218,6 @@ def test_anom_mean_raw(): } o = order(dsk) - # The left hand computation branch should complete before we start loading # more data nodes_to_finish_before_loading_more_data = [ @@ -1694,27 +1670,9 @@ def test_flox_reduction(): ("F2", 2): (f, "A0", ("EE", 1)), } o = order(dsk) - visualize( - dsk, - filename="test_flox_reduction.png", - optimize_graph=True, - ) - visualize( - dsk, - filename="test_flox_reduction-order.png", - optimize_graph=True, - color="order", - o=o, - ) assert max(o[("F1", ix)] for ix in range(3)) < min(o[("F2", ix)] for ix in range(3)) -import numpy as np - -import dask.array as da -from dask.base import key_split, visualize - - def test_reduce_with_many_common_dependents(): ndeps = 3 @@ -1750,20 +1708,4 @@ def random(**kwargs): for dep in dependencies[r]: prios_deps.append(o[dep]) drift[r] = (min(prios_deps), max(prios_deps)) - # assert max(prios_deps) - min(prios_deps) == len(dependencies[r]) - 1 - - print(f"{drift=}") - from dask.base import visualize - - visualize( - collections_to_dsk([graph]), - filename="test_decide_worker_coschedule_order_neighbors-color.png", - optimize_graph=True, - color="order", - o=o, - ) - visualize( - collections_to_dsk([graph]), - filename="test_decide_worker_coschedule_order_neighbors.png", - optimize_graph=True, - ) + assert max(prios_deps) - min(prios_deps) == len(dependencies[r]) - 1 From d7493b9d870cb78bd75d4fcbc0bdcccf669ffb34 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 12 Oct 2023 18:33:48 +0200 Subject: [PATCH 3/8] Only reset layers_loaded after init_stack is empty --- dask/order.py | 15 +++++++-------- dask/tests/test_order.py | 8 ++++++++ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/dask/order.py b/dask/order.py index 3047c77e05b..373f393224e 100644 --- a/dask/order.py +++ b/dask/order.py @@ -246,7 +246,6 @@ def process_runnables(layers_loaded: int = 0) -> None: ) while runnable_sorted: task = runnable_sorted.pop() - if task in runnable: if ( len(set_difference(dependents[runnable[task]], result)) @@ -254,7 +253,6 @@ def process_runnables(layers_loaded: int = 0) -> None: ): next_nodes[partition_keys[task]].add(task) continue - print(f"Set result from runnable {task=} {i=}") result[task] = i i += 1 deps = dependents[task] @@ -264,10 +262,9 @@ def process_runnables(layers_loaded: int = 0) -> None: runnable_sorted.append(dep) else: next_nodes[partition_keys[dep]].add(dep) - # runnable.clear() - dep_pools = defaultdict(set) layers_loaded = 0 + dep_pools = defaultdict(set) while True: while inner_stack: item = inner_stack_pop() @@ -285,8 +282,6 @@ def process_runnables(layers_loaded: int = 0) -> None: process_runnables(layers_loaded) layers_loaded += 1 continue - layers_loaded = 0 - print(f"Set result from inner_stack {item=} {i=}") result[item] = i i += 1 deps = dependents[item] @@ -302,6 +297,8 @@ def process_runnables(layers_loaded: int = 0) -> None: # Heap? all_keys = [] for dep in deps: + if dep in seen: + continue pkey = partition_keys[dep] dep_pools[pkey].add(dep) all_keys.append(pkey) @@ -315,13 +312,15 @@ def process_runnables(layers_loaded: int = 0) -> None: inner_stack = list(dep_pools[pkey]) inner_stack_pop = inner_stack.pop seen_update(inner_stack) + continue else: - pass + next_nodes[pkey].update(dep_pools[pkey]) else: next_nodes[pkey].update(dep_pools[pkey]) dep_pools.clear() - process_runnables() + process_runnables(layers_loaded) + layers_loaded = 0 if next_nodes and not inner_stack: min_key = min(next_nodes) diff --git a/dask/tests/test_order.py b/dask/tests/test_order.py index 67ab8ca157c..af285c58e69 100644 --- a/dask/tests/test_order.py +++ b/dask/tests/test_order.py @@ -661,6 +661,7 @@ def test_order_empty(): assert order({}) == {} +@pytest.mark.xfail(reason="Why is `cde` a better path? Why even start at a0?") def test_switching_dependents(abcde): r""" @@ -1014,6 +1015,13 @@ def test_diagnostics(abcde): (e, 1): (f, (e, 0)), } o = order(dsk) + from dask.base import visualize + + visualize( + dsk, + filename="test_diagnostics", + ) + visualize(dsk, o=o, filename="test_diagnostics-order", color="order") assert o[(e, 1)] == len(dsk) - 1 assert o[(d, 1)] == len(dsk) - 2 assert o[(c, 1)] == len(dsk) - 3 From 638734422aad5dd895c02fbd97d4a48b2540bae0 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 13 Oct 2023 08:39:22 +0200 Subject: [PATCH 4/8] remove numpy import --- dask/tests/test_order.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dask/tests/test_order.py b/dask/tests/test_order.py index af285c58e69..ba1ed410670 100644 --- a/dask/tests/test_order.py +++ b/dask/tests/test_order.py @@ -1,10 +1,8 @@ from __future__ import annotations -import numpy as np import pytest import dask -import dask.array as da from dask import delayed from dask.base import collections_to_dsk, key_split from dask.core import get_deps @@ -1682,6 +1680,9 @@ def test_flox_reduction(): def test_reduce_with_many_common_dependents(): + da = pytest.importorskip("dask.array") + import numpy as np + ndeps = 3 def random(**kwargs): From f7f1e42863d29f9906920e3ce6b23e1cc4459061 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 13 Oct 2023 08:40:03 +0200 Subject: [PATCH 5/8] revert visualize --- dask/base.py | 5 ++--- dask/tests/test_order.py | 6 ------ 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/dask/base.py b/dask/base.py index b2bdafe7b42..bde015b77b1 100644 --- a/dask/base.py +++ b/dask/base.py @@ -719,10 +719,9 @@ def visualize( https://docs.dask.org/en/latest/optimize.html """ - dsk = args[0] - # args, _ = unpack_collections(*args, traverse=traverse) + args, _ = unpack_collections(*args, traverse=traverse) - # dsk = dict(collections_to_dsk(args, optimize_graph=optimize_graph)) + dsk = dict(collections_to_dsk(args, optimize_graph=optimize_graph)) color = kwargs.get("color") diff --git a/dask/tests/test_order.py b/dask/tests/test_order.py index ba1ed410670..4cc20eb8762 100644 --- a/dask/tests/test_order.py +++ b/dask/tests/test_order.py @@ -1013,13 +1013,7 @@ def test_diagnostics(abcde): (e, 1): (f, (e, 0)), } o = order(dsk) - from dask.base import visualize - visualize( - dsk, - filename="test_diagnostics", - ) - visualize(dsk, o=o, filename="test_diagnostics-order", color="order") assert o[(e, 1)] == len(dsk) - 1 assert o[(d, 1)] == len(dsk) - 2 assert o[(c, 1)] == len(dsk) - 3 From dcd17f123aee813b7d5defeb2d3ab109661fb849 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 13 Oct 2023 10:23:06 +0200 Subject: [PATCH 6/8] perf fixes --- dask/order.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/dask/order.py b/dask/order.py index 373f393224e..6119da20e99 100644 --- a/dask/order.py +++ b/dask/order.py @@ -227,6 +227,7 @@ def dependencies_key(x: Key) -> tuple: max_heights, ) in metrics.items() } + pkey_getitem = partition_keys.__getitem__ result: dict[Key, int] = {root: len(dsk) - 1} i = 0 @@ -241,9 +242,7 @@ def process_runnables(layers_loaded: int = 0) -> None: runnable_candidates = set_difference( set_difference(set(runnable), result), seen ) - runnable_sorted = sorted( - runnable_candidates, key=partition_keys.__getitem__, reverse=True - ) + runnable_sorted = sorted(runnable_candidates, key=pkey_getitem, reverse=True) while runnable_sorted: task = runnable_sorted.pop() if task in runnable: @@ -251,7 +250,7 @@ def process_runnables(layers_loaded: int = 0) -> None: len(set_difference(dependents[runnable[task]], result)) > 1 + layers_loaded ): - next_nodes[partition_keys[task]].add(task) + next_nodes[pkey_getitem(task)].add(task) continue result[task] = i i += 1 @@ -261,7 +260,7 @@ def process_runnables(layers_loaded: int = 0) -> None: if not num_needed[dep]: runnable_sorted.append(dep) else: - next_nodes[partition_keys[dep]].add(dep) + next_nodes[pkey_getitem(dep)].add(dep) layers_loaded = 0 dep_pools = defaultdict(set) @@ -299,14 +298,14 @@ def process_runnables(layers_loaded: int = 0) -> None: for dep in deps: if dep in seen: continue - pkey = partition_keys[dep] + pkey = pkey_getitem(dep) dep_pools[pkey].add(dep) all_keys.append(pkey) all_keys.sort() target_key: tuple[int, ...] | None = None for pkey in reversed(all_keys): if inner_stack: - target_key = target_key or partition_keys[inner_stack[0]] + target_key = target_key or pkey_getitem(inner_stack[0]) if pkey < target_key: next_nodes[target_key].update(inner_stack) inner_stack = list(dep_pools[pkey]) @@ -331,12 +330,6 @@ def process_runnables(layers_loaded: int = 0) -> None: seen_update(inner_stack) continue - # This is just for perf reasons to cut the set differences down - for k in list(runnable): - if k in result: - del runnable[k] - seen = seen - set(result) - if inner_stack: continue @@ -344,7 +337,6 @@ def process_runnables(layers_loaded: int = 0) -> None: break if not is_init_sorted: - # TODO: Is this even worth it? init_stack = set(init_stack) init_stack = set_difference(init_stack, result) if len(init_stack) < 10000: @@ -352,9 +344,8 @@ def process_runnables(layers_loaded: int = 0) -> None: else: init_stack = list(init_stack) is_init_sorted = True - assert is_init_sorted - assert isinstance(init_stack, list) - inner_stack = [init_stack.pop()] + + inner_stack = [init_stack.pop()] # type: ignore[call-overload] inner_stack_pop = inner_stack.pop return result From f297ad47f060ff93e73867e273057ce88729b3c2 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 13 Oct 2023 10:53:58 +0200 Subject: [PATCH 7/8] more perf fixes --- dask/order.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/dask/order.py b/dask/order.py index 6119da20e99..59f15fdaa9d 100644 --- a/dask/order.py +++ b/dask/order.py @@ -79,6 +79,7 @@ """ from collections import defaultdict, namedtuple from collections.abc import Mapping, MutableMapping +from heapq import heappop, heappush from typing import Any, cast from dask.core import get_dependencies, get_deps, getcycle, istask, reverse_dict @@ -234,14 +235,13 @@ def dependencies_key(x: Key) -> tuple: inner_stack = [min(init_stack, key=initial_stack_key)] inner_stack_pop = inner_stack.pop next_nodes: defaultdict[tuple[int, ...], set[Key]] = defaultdict(set) + min_key_next_nodes: list[tuple[int, ...]] = [] runnable: dict[Key, Key] = dict() set_difference = set.difference def process_runnables(layers_loaded: int = 0) -> None: nonlocal i - runnable_candidates = set_difference( - set_difference(set(runnable), result), seen - ) + runnable_candidates = set_difference(set(runnable), seen) runnable_sorted = sorted(runnable_candidates, key=pkey_getitem, reverse=True) while runnable_sorted: task = runnable_sorted.pop() @@ -250,9 +250,12 @@ def process_runnables(layers_loaded: int = 0) -> None: len(set_difference(dependents[runnable[task]], result)) > 1 + layers_loaded ): - next_nodes[pkey_getitem(task)].add(task) + pkey = pkey_getitem(task) + heappush(min_key_next_nodes, pkey) + next_nodes[pkey].add(task) continue result[task] = i + runnable.pop(task, None) i += 1 deps = dependents[task] for dep in deps: @@ -260,7 +263,9 @@ def process_runnables(layers_loaded: int = 0) -> None: if not num_needed[dep]: runnable_sorted.append(dep) else: - next_nodes[pkey_getitem(dep)].add(dep) + pkey = pkey_getitem(dep) + heappush(min_key_next_nodes, pkey) + next_nodes[pkey].add(dep) layers_loaded = 0 dep_pools = defaultdict(set) @@ -282,6 +287,7 @@ def process_runnables(layers_loaded: int = 0) -> None: layers_loaded += 1 continue result[item] = i + runnable.pop(item, None) i += 1 deps = dependents[item] for dep in deps: @@ -308,21 +314,24 @@ def process_runnables(layers_loaded: int = 0) -> None: target_key = target_key or pkey_getitem(inner_stack[0]) if pkey < target_key: next_nodes[target_key].update(inner_stack) + heappush(min_key_next_nodes, target_key) inner_stack = list(dep_pools[pkey]) inner_stack_pop = inner_stack.pop seen_update(inner_stack) continue - else: - next_nodes[pkey].update(dep_pools[pkey]) - else: - next_nodes[pkey].update(dep_pools[pkey]) + next_nodes[pkey].update(dep_pools[pkey]) + heappush(min_key_next_nodes, pkey) + dep_pools.clear() process_runnables(layers_loaded) layers_loaded = 0 if next_nodes and not inner_stack: - min_key = min(next_nodes) + # there may be duplicates on the heap + min_key = heappop(min_key_next_nodes) + while min_key not in next_nodes: + min_key = heappop(min_key_next_nodes) inner_stack = sorted( next_nodes.pop(min_key), key=dependents_key, reverse=True ) From aaf3f1f2a795ff46399c9bc23a78779360627536 Mon Sep 17 00:00:00 2001 From: fjetter Date: Mon, 16 Oct 2023 16:22:50 +0200 Subject: [PATCH 8/8] Prototype for assignment groups --- dask/base.py | 17 +++++++++--- dask/order.py | 60 ++++++++++++++++++++++++++++++++++------ dask/tests/test_order.py | 50 ++++++++++++++++++++++++++++++++- 3 files changed, 113 insertions(+), 14 deletions(-) diff --git a/dask/base.py b/dask/base.py index bde015b77b1..d6c356c9b66 100644 --- a/dask/base.py +++ b/dask/base.py @@ -719,9 +719,10 @@ def visualize( https://docs.dask.org/en/latest/optimize.html """ - args, _ = unpack_collections(*args, traverse=traverse) + dsk = args[0] + # args, _ = unpack_collections(*args, traverse=traverse) - dsk = dict(collections_to_dsk(args, optimize_graph=optimize_graph)) + # dsk = dict(collections_to_dsk(args, optimize_graph=optimize_graph)) color = kwargs.get("color") @@ -737,13 +738,17 @@ def visualize( "memoryincreases", "memorydecreases", "memorypressure", + "group", + "order-group", }: import matplotlib.pyplot as plt from dask.order import diagnostics, order - if o is None: - o = order(dsk) + if "group" in color: + o, groups = order(dsk, group=True) + elif o is None: + o = order(dsk, group=False) try: cmap = kwargs.pop("cmap") except KeyError: @@ -773,6 +778,10 @@ def label(x): key: max(0, val.num_data_when_released - val.num_data_when_run) for key, val in info.items() } + elif color.endswith("group"): + values = { + key: group_ix for group_ix, keys in groups.items() for key in keys + } else: # memorydecreases values = { key: max(0, val.num_data_when_run - val.num_data_when_released) diff --git a/dask/order.py b/dask/order.py index 59f15fdaa9d..d420d5e1a42 100644 --- a/dask/order.py +++ b/dask/order.py @@ -80,19 +80,40 @@ from collections import defaultdict, namedtuple from collections.abc import Mapping, MutableMapping from heapq import heappop, heappush -from typing import Any, cast +from typing import Any, Literal, cast, overload from dask.core import get_dependencies, get_deps, getcycle, istask, reverse_dict from dask.typing import Key +@overload def order( dsk: MutableMapping[Key, Any], - dependencies: MutableMapping[Key, set[Key]] | None = None, + dependencies: MutableMapping[Key, set[Key]] | None, + group: Literal[False], ) -> dict[Key, int]: + ... + + +@overload +def order( + dsk: MutableMapping[Key, Any], + dependencies: MutableMapping[Key, set[Key]] | None, + group: Literal[True], +) -> tuple[dict[Key, int], dict[int, list[Key]]]: + ... + + +def order( + dsk: MutableMapping[Key, Any], + dependencies: MutableMapping[Key, set[Key]] | None = None, + group: bool = False, +) -> dict[Key, int] | tuple[dict[Key, int], dict[int, list[Key]]]: if not dsk: return {} - + groups = defaultdict(list) + groups_by_key = dict() + group_ix = 0 dsk = dict(dsk) if dependencies is None: @@ -126,9 +147,14 @@ def _f(*args: Any, **kwargs: Any) -> None: dsk[root] = (_f, *root_nodes) dependencies[root] = root_nodes - o = order(dsk, dependencies) - del o[root] - return o + if not group: + o = order(dsk, dependencies, group=group) + del o[root] + return o + else: + o, g = order(dsk, dependencies, group=group) + del o[root] + return o, g root = list(root_nodes)[0] init_stack: dict[Key, tuple] | set[Key] | list[Key] # Leaf nodes. We choose one--the initial node--for each weakly connected subgraph. @@ -243,6 +269,7 @@ def process_runnables(layers_loaded: int = 0) -> None: nonlocal i runnable_candidates = set_difference(set(runnable), seen) runnable_sorted = sorted(runnable_candidates, key=pkey_getitem, reverse=True) + prev = None while runnable_sorted: task = runnable_sorted.pop() if task in runnable: @@ -255,6 +282,11 @@ def process_runnables(layers_loaded: int = 0) -> None: next_nodes[pkey].add(task) continue result[task] = i + # groups[group_ix].append(task) + group_ix = groups_by_key[runnable.get(task, prev)] + groups[group_ix].append(task) + groups_by_key[task] = group_ix + prev = task runnable.pop(task, None) i += 1 deps = dependents[task] @@ -287,6 +319,8 @@ def process_runnables(layers_loaded: int = 0) -> None: layers_loaded += 1 continue result[item] = i + groups[group_ix].append(item) + groups_by_key[item] = group_ix runnable.pop(item, None) i += 1 deps = dependents[item] @@ -318,6 +352,8 @@ def process_runnables(layers_loaded: int = 0) -> None: inner_stack = list(dep_pools[pkey]) inner_stack_pop = inner_stack.pop seen_update(inner_stack) + if group_ix in groups and len(groups[group_ix]) > 1: + group_ix += 1 continue next_nodes[pkey].update(dep_pools[pkey]) heappush(min_key_next_nodes, pkey) @@ -336,6 +372,8 @@ def process_runnables(layers_loaded: int = 0) -> None: next_nodes.pop(min_key), key=dependents_key, reverse=True ) inner_stack_pop = inner_stack.pop + if group_ix in groups and len(groups[group_ix]) > 1: + group_ix += 1 seen_update(inner_stack) continue @@ -344,7 +382,9 @@ def process_runnables(layers_loaded: int = 0) -> None: if len(result) == len(dsk): break - + # Increasing here is very conservative + if group_ix in groups and len(groups[group_ix]) > 1: + group_ix += 1 if not is_init_sorted: init_stack = set(init_stack) init_stack = set_difference(init_stack, result) @@ -356,8 +396,10 @@ def process_runnables(layers_loaded: int = 0) -> None: inner_stack = [init_stack.pop()] # type: ignore[call-overload] inner_stack_pop = inner_stack.pop - - return result + if group: + return result, dict(groups) + else: + return result def graph_metrics( diff --git a/dask/tests/test_order.py b/dask/tests/test_order.py index 4cc20eb8762..91a45630009 100644 --- a/dask/tests/test_order.py +++ b/dask/tests/test_order.py @@ -13,7 +13,7 @@ @pytest.fixture( params=[ "abcde", - "edcba", + # "edcba", ] ) def abcde(request): @@ -751,6 +751,12 @@ def test_order_with_equal_dependents(abcde): } ) o = order(dsk) + + import inspect + + from dask.base import visualize + + visualize(dsk, filename=inspect.stack()[0][3], color="group") total = 0 for x in abc: for i in range(len(abc)): @@ -857,6 +863,11 @@ def test_terminal_node_backtrack(): ), } o = order(dsk) + import inspect + + from dask.base import visualize + + visualize(dsk, filename=inspect.stack()[0][3], color="group") assert o[("a", 2)] < o[("a", 3)] @@ -876,6 +887,11 @@ def test_array_store_final_order(tmpdir): dest = root.empty_like(name="dest", data=x, chunks=x.chunksize, overwrite=True) d = x.store(dest, lock=False, compute=False) o = order(d.dask) + import inspect + + from dask.base import visualize + + visualize(d.dask, filename=inspect.stack()[0][3], color="group") # Find the lowest store. Dask starts here. stores = [k for k in o if isinstance(k, tuple) and k[0].startswith("store-map-")] first_store = min(stores, key=lambda k: o[k]) @@ -979,6 +995,11 @@ def test_eager_to_compute_dependent_to_free_parent(): } dependencies, dependents = get_deps(dsk) o = order(dsk) + import inspect + + from dask.base import visualize + + visualize(dsk, filename=inspect.stack()[0][3], color="group") parents = {deps.pop() for key, deps in dependents.items() if not dependencies[key]} def cost(deps): @@ -1013,7 +1034,11 @@ def test_diagnostics(abcde): (e, 1): (f, (e, 0)), } o = order(dsk) + import inspect + from dask.base import visualize + + visualize(dsk, filename=inspect.stack()[0][3], color="group") assert o[(e, 1)] == len(dsk) - 1 assert o[(d, 1)] == len(dsk) - 2 assert o[(c, 1)] == len(dsk) - 3 @@ -1135,10 +1160,23 @@ def test_array_vs_dataframe(optimize): quad = ds**2 quad["uv"] = ds.anom_u * ds.anom_v mean = quad.mean("time") + dsk = collections_to_dsk([mean], optimize_graph=optimize) + o, g = order(dsk, group=True) + print(len(g)) + print({ix: len(keys) for ix, keys in g.items()}) diag_array = diagnostics(collections_to_dsk([mean], optimize_graph=optimize)) diag_df = diagnostics( collections_to_dsk([mean.to_dask_dataframe()], optimize_graph=optimize) ) + import inspect + + from dask.base import visualize + + visualize( + collections_to_dsk([mean], optimize_graph=optimize), + filename=inspect.stack()[0][3], + color="group", + ) assert max(diag_df[1]) == max(diag_array[1]) assert max(diag_array[1]) < 50 @@ -1218,6 +1256,11 @@ def test_anom_mean_raw(): } o = order(dsk) + import inspect + + from dask.base import visualize + + visualize(dsk, filename=inspect.stack()[0][3], color="group") # The left hand computation branch should complete before we start loading # more data nodes_to_finish_before_loading_more_data = [ @@ -1701,6 +1744,11 @@ def random(**kwargs): dependencies, dependents = get_deps(dsk) # Verify assumptions o = order(dsk) + import inspect + + from dask.base import visualize + + visualize(dsk, filename=inspect.stack()[0][3], color="group") # Verify assumptions (specifically that the reducers are sum-aggregate) assert {key_split(k) for k in o} == {"object", "sum", "sum-aggregate"}