diff --git a/dask/order.py b/dask/order.py index 9a867866027..f5e50bfec8a 100644 --- a/dask/order.py +++ b/dask/order.py @@ -79,7 +79,7 @@ """ from collections import defaultdict, namedtuple from collections.abc import Mapping, MutableMapping -from math import log +from heapq import heappop, heappush, nsmallest from typing import Any, cast from dask.core import get_dependencies, get_deps, getcycle, istask, reverse_dict @@ -95,16 +95,8 @@ def order( This produces an ordering over our tasks that we use to break ties when executing. We do this ahead of time to reduce a bit of stress on the scheduler and also to assist in static analysis. - This currently traverses the graph as a single-threaded scheduler would - traverse it. It breaks ties in the following ways: - - 1. Begin at a leaf node that is a dependency of a root node that has the - largest subgraph (start hard things first) - 2. Prefer tall branches with few dependents (start hard things first and - try to avoid memory usage) - 3. Prefer dependents that are dependencies of root nodes that have - the smallest subgraph (do small goals that can terminate quickly) + traverse it. Examples -------- @@ -123,7 +115,7 @@ def order( dependencies = {k: get_dependencies(dsk, k) for k in dsk} dependents = reverse_dict(dependencies) num_needed, total_dependencies = ndependencies(dependencies, dependents) - metrics = graph_metrics(dependencies, dependents, total_dependencies) + metrics = graph_metrics(dependencies, dependents) if len(metrics) != len(dsk): cycle = getcycle(dsk, None) @@ -132,13 +124,6 @@ def order( % "\n -> ".join(str(x) for x in cycle) ) - # Single root nodes that depend on everything. These cause issues for - # the current ordering algorithm, since we often hit the root node - # and fell back to the key tie-breaker to choose which immediate dependency - # to finish next, rather than finishing off subtrees. 
- # So under the special case of a single root node that depends on the entire - # tree, we skip processing it normally. - # See https://github.com/dask/dask/issues/6745 root_nodes = {k for k, v in dependents.items() if not v} if len(root_nodes) > 1: # This is also nice because it makes us robust to difference when @@ -153,7 +138,7 @@ def _f(*args: Any, **kwargs: Any) -> None: o = order(dsk, dependencies) del o[root] return o - + root = list(root_nodes)[0] init_stack: dict[Key, tuple] | set[Key] | list[Key] # Leaf nodes. We choose one--the initial node--for each weakly connected subgraph. # Let's calculate the `initial_stack_key` as we determine `init_stack` set. @@ -172,8 +157,6 @@ def _f(*args: Any, **kwargs: Any) -> None: ) for key, num_dependents, ( total_dependents, - _, - _, min_heights, max_heights, ) in ( @@ -182,6 +165,7 @@ def _f(*args: Any, **kwargs: Any) -> None: if not val ) } + is_init_sorted = False # `initial_stack_key` chooses which task to run at the very beginning. # This value is static, so we pre-compute as the value of this dict. initial_stack_key = init_stack.__getitem__ @@ -200,7 +184,7 @@ def dependents_key(x: Key) -> tuple: # Do we favor deep or shallow branches? 
# -1: deep # +1: shallow - -metrics[x][3], # min_heights + -metrics[x][1], # min_heights # tie-breaker StrComparable(x), ) @@ -214,8 +198,6 @@ def dependencies_key(x: Key) -> tuple: num_dependents = len(dependents[x]) ( total_dependents, - _, - _, min_heights, max_heights, ) = metrics[x] @@ -233,540 +215,202 @@ def dependencies_key(x: Key) -> tuple: StrComparable(x), ) + seen = set(root_nodes) + seen_update = seen.update root_total_dependencies = total_dependencies[list(root_nodes)[0]] - # Computing this for all keys can sometimes be relatively expensive :( partition_keys = { key: ( - (root_total_dependencies - total_dependencies[key] + 1) - * (total_dependents - min_heights) + (root_total_dependencies - total_dependencies[key] + 1), + (total_dependents - min_heights), + -max_heights, ) for key, ( total_dependents, - _, - _, min_heights, - _, + max_heights, ) in metrics.items() } - - result: dict[Key, int] = {} + result: dict[Key, int] = {root: len(dsk) - 1} i = 0 - # `inner_stack` is used to perform a DFS along dependencies. Once emptied - # (when traversing dependencies), this continue down a path along dependents - # until a root node is reached. - # - # Sometimes, a better path along a dependent is discovered (i.e., something - # that is easier to compute and doesn't requiring holding too much in memory). - # In this case, the current `inner_stack` is appended to `inner_stacks` and - # we begin a new DFS from the better node. - # - # A "better path" is determined by comparing `partition_keys`. inner_stack = [min(init_stack, key=initial_stack_key)] inner_stack_pop = inner_stack.pop - inner_stacks: list[list[Key]] = [] - inner_stacks_append = inner_stacks.append - inner_stacks_extend = inner_stacks.extend - inner_stacks_pop = inner_stacks.pop - - # Okay, now we get to the data structures used for fancy behavior. - # - # As we traverse nodes in the DFS along dependencies, we partition the dependents - # via `partition_key`. 
A dependent goes to: - # 1) `inner_stack` if it's better than our current target, - # 2) `next_nodes` if the partition key is lower than it's parent, - # When the inner stacks are depleted, we process `next_nodes`. - # These dicts use `partition_keys` as keys. We process them by placing the values - # in `outer_stack` so that the smallest keys will be processed first. - next_nodes: defaultdict[int, list[list[Key] | set[Key]]] = defaultdict(list) - - # `outer_stack` is used to populate `inner_stacks`. From the time we partition the - # dependents of a node, we group them: one list per partition key per parent node. - # This likely results in many small lists. We do this to avoid sorting many larger - # lists (i.e., to avoid n*log(n) behavior). So, we have many small lists that we - # partitioned, and we keep them in the order that we saw them (we will process them - # in a FIFO manner). By delaying sorting for as long as we can, we can first filter - # out nodes that have already been computed. All this complexity is worth it! - outer_stack: list[list[Key]] = [] - outer_stack_extend = outer_stack.extend - outer_stack_pop = outer_stack.pop - - # Keep track of nodes that are in `inner_stack` or `inner_stacks` so we don't - # process them again. - seen = set(root_nodes) - seen_update = seen.update - seen_add = seen.add - - # "singles" are tasks that are available to run, and when run may free a dependency. - # Although running a task to free a dependency may seem like a wash (net zero), it - # can be beneficial by providing more opportunities for a later task to free even - # more data. So, it costs us little in the short term to more eagerly compute - # chains of tasks that keep the same number of data in memory, and the longer term - # rewards are potentially high. I would expect a dynamic scheduler to have similar - # behavior, so I think it makes sense to do the same thing here in `dask.order`. 
- # - # When we gather tasks in `singles`, we do so optimistically: running the task *may* - # free the parent, but it also may not, because other dependents of the parent may - # be in the inner stacks. When we process `singles`, we run tasks that *will* free - # the parent, otherwise we move the task to `later_singles`. `later_singles` is run - # when there are no inner stacks, so it is safe to run all of them (because no other - # dependents will be hiding in the inner stacks to keep hold of the parent). - # `singles` is processed when the current item on the stack needs to compute - # dependencies before it can be run. - # - # Processing singles is meant to be a detour. Doing so should not change our - # tactical goal in most cases. Hence, we set `add_to_inner_stack = False`. - # - # In reality, this is a pretty limited strategy for running a task to free a - # dependency. A thorough strategy would be to check whether running a dependent - # with `num_needed[dep] == 0` would free *any* of its dependencies. This isn't - # what we do. This data isn't readily or cheaply available. We only check whether - # it will free its last dependency that was computed (the current `item`). This is - # probably okay. In general, our tactics and strategies for ordering try to be - # memory efficient, so we shouldn't try too hard to work around what we already do. - # However, sometimes the DFS nature of it leaves "easy-to-compute" stragglers behind. - # The current approach is very fast to compute, can be beneficial, and is generally - # low-risk. There could be more we could do here though. Our static scheduling - # here primarily looks at "what dependent should we run next?" instead of "what - # dependency should we try to free?" Two sides to the same question, but a dynamic - # scheduler is much better able to answer the latter one, because it knows the size - # of data and can react to current state. 
Does adding a little more dynamic-like - # behavior to `dask.order` add any tension to running with an actual dynamic - # scheduler? Should we defer to dynamic schedulers and let them behave like this - # if they so choose? Maybe. However, I'm sensitive to the multithreaded scheduler, - # which is heavily dependent on the ordering obtained here. - singles: dict[Key, Key] = {} - singles_clear = singles.clear - later_singles: list[Key] = [] - later_singles_append = later_singles.append - later_singles_clear = later_singles.clear - - # Priority of being processed - # 1. inner_stack - # 2. singles (may be moved to later_singles) - # 3. inner_stacks - # 4. later_singles - # 5. next_nodes - # 6. outer_stack - # 7. init_stack - - # alias for speed - set_difference = set.difference - - is_init_sorted = False - + next_nodes: defaultdict[tuple[int, ...], set[Key]] = defaultdict(set) + in_next_nodes: set[Key] = set() + min_key_next_nodes: list[tuple[int, ...]] = [] + runnable_by_parent: defaultdict[Key, set[Key]] = defaultdict(set) + + def process_runnables(layers_loaded: int) -> None: + nonlocal i + # Sort by number of dependents such that we process parents with few dependents first. + # This is a performance optimization that allows us to break the for + # loop early if we find a parent that is not allowed to proceed. This is + # merely an assumption that is not generally true but has been proven to + # be effective in practice. + for parent, runnable_tasks in sorted( + runnable_by_parent.items(), key=lambda x: len(dependents[x[0]]) + ): + pkey = partition_keys[parent] + deps_parent = dependents[parent] + deps_not_in_result = deps_parent.difference(result) + # We only want to process nodes that guarantee to release the + # parent, i.e. 
len(deps_not_in_result) == 1 + # However, the more aggressively the DFS has to backtrack, the more + # eagerly we are willing to process other runnable tasks to release + # as many parents as possible before loading more data (which + # typically happens when backtracking). + if len(deps_not_in_result) > 1 + layers_loaded: + new_tasks = runnable_tasks - in_next_nodes + if new_tasks: + heappush(min_key_next_nodes, pkey) + next_nodes[pkey].update(new_tasks) + in_next_nodes.update(new_tasks) + break + del runnable_by_parent[parent] + runnable_candidates = runnable_tasks - seen + runnable_sorted = sorted( + runnable_candidates, key=partition_keys.__getitem__, reverse=True + ) + while runnable_sorted: + task = runnable_sorted.pop() + result[task] = i + i += 1 + deps = dependents[task] + for dep in deps: + num_needed[dep] -= 1 + if not num_needed[dep]: + runnable_sorted.append(dep) + elif dep not in in_next_nodes: + pkey = partition_keys[dep] + heappush(min_key_next_nodes, pkey) + next_nodes[pkey].add(dep) + + layers_loaded = 0 + dep_pools = defaultdict(set) while True: - while True: - # Perform a DFS along dependencies until we complete our tactical goal - deps = set() - add_to_inner_stack = True - if inner_stack: - item = inner_stack_pop() - if item in result: - continue - if num_needed[item]: - if item not in root_nodes: - inner_stack.append(item) - deps = set_difference(dependencies[item], result) - if 1 < len(deps) < 1000: - inner_stack.extend( - sorted(deps, key=dependencies_key, reverse=True) - ) - else: - inner_stack.extend(deps) - seen_update(deps) - if not singles: - continue - # Only process singles once the inner_stack is fully - # resolved. 
This is important because the singles path later - # on verifies that running the single indeed opens an - # opportunity to release soon by comparing the singles - # parent's dependents with the inner_stack(s) - if inner_stack and num_needed[inner_stack[-1]]: - continue - process_singles = True + while inner_stack: + item = inner_stack_pop() + if item in result: + continue + if num_needed[item]: + inner_stack.append(item) + deps = dependencies[item].difference(result) + if 1 < len(deps) < 1000: + inner_stack.extend(sorted(deps, key=dependencies_key, reverse=True)) else: - result[item] = i - i += 1 - deps = dependents[item] - add_to_inner_stack = True - - if deps: - for dep in deps: - num_needed[dep] -= 1 - process_singles = False - else: - continue - elif inner_stacks: - inner_stack = inner_stacks_pop() - inner_stack_pop = inner_stack.pop + inner_stack.extend(deps) + seen_update(deps) + if not num_needed[inner_stack[-1]]: + process_runnables(layers_loaded) + layers_loaded += 1 continue - elif singles: - process_singles = True - elif later_singles: - # No need to be optimistic: all nodes in `later_singles` will free a dependency - # when run, so no need to check whether dependents are in `seen`. - for single in later_singles: - if single in result: - continue - while True: - deps_singles = dependents[single] - result[single] = i - i += 1 - if deps_singles: - for dep in deps_singles: - num_needed[dep] -= 1 - if len(deps_singles) == 1: - # Fast path! We trim down `dep2` above hoping to reach here. - (single,) = deps_singles - if not num_needed[single]: - # Keep it going! 
- deps_singles = dependents[single] - continue - deps |= deps_singles - del deps_singles - break - later_singles_clear() - deps = set_difference(deps, result) - if not deps: - continue - add_to_inner_stack = False - process_singles = True - else: - break + result[item] = i + i += 1 + deps = dependents[item] + all_keys = [] + target_key = None + if inner_stack: + target_key = partition_keys[inner_stack[0]] + for dep in deps: + num_needed[dep] -= 1 + if not num_needed[dep]: + runnable_by_parent[item].add(dep) - if process_singles and singles: - # We gather all dependents of all singles into `deps`, which we then process below. - - add_to_inner_stack = True if inner_stack or inner_stacks else False - singles_keys = set_difference(set(singles), result) - - # NOTE: If this was too slow, LIFO would be a decent - # approximation - for single in sorted(singles_keys, key=lambda x: partition_keys[x]): - # We want to run the singles if they are either releasing a - # dependency directly or that they may be releasing a - # dependency once the current critical path / inner_stack is - # walked. - # By using `seen` here this is more permissive since it also - # includes tasks in a future critical path / inner_stacks - # but it would require additional state to make this - # distinction and we don't have enough data to dermine if - # this is worth it. 
- parent = singles[single] - if ( - len( - set_difference( - set_difference(dependents[parent], result), - seen, - ) - ) - > 1 - ): - later_singles_append(single) - continue - while True: - deps_singles = dependents[single] - result[single] = i - i += 1 - if deps_singles: - for dep in deps_singles: - num_needed[dep] -= 1 - if add_to_inner_stack: - already_seen = deps_singles & seen - if already_seen: - # This means that the singles path also - # leads to the current or previous strategic - # path - if len(deps_singles) == len(already_seen): - if len(already_seen) == 1: - (single,) = already_seen - if not num_needed[single]: - deps_singles = dependents[single] - continue - break - deps_singles = deps_singles - already_seen - else: - already_seen = set() - if len(deps_singles) == 1: - # Fast path! We trim down `dep2` above hoping to reach here. - (single,) = deps_singles - if not num_needed[single]: - if not already_seen: - # Keep it going! - deps_singles = dependents[single] - continue - later_singles_append(single) - break - deps |= deps_singles - del deps_singles - break - del singles_keys - deps = set_difference(deps, result) - singles_clear() - if not deps: + if dep in seen: continue - add_to_inner_stack = False - - # If inner_stack is empty, then we typically add the best dependent to it. - # However, we don't add to it if a dependent is already on an inner_stack. In this case, we add the - # dependents (not in an inner_stack) to next_nodes or later_nodes to handle later. - # This serves three purposes: - # 1. shrink `deps` so that it can be processed faster, - # 2. make sure we don't process the same dependency repeatedly, and - # 3. make sure we don't accidentally continue down an expensive-to-compute path. 
- already_seen = deps & seen - if already_seen: - if len(deps) == len(already_seen): - if len(already_seen) == 1: - (dep,) = already_seen - if not num_needed[dep]: - singles[dep] = item - del dep - continue - add_to_inner_stack = False - deps = deps - already_seen - - if len(deps) == 1: - # Fast path! We trim down `deps` above hoping to reach here. - (dep,) = deps - if add_to_inner_stack and not inner_stack: - inner_stack = [dep] - inner_stack_pop = inner_stack.pop - seen_add(dep) - continue - key = partition_keys[dep] - if not num_needed[dep]: - # We didn't put the single dependency on the stack, but we should still - # run it soon, because doing so may free its parent. - singles[dep] = item - else: - next_nodes[key].append(deps) - del dep, key - elif len(deps) == 2: - # We special-case when len(deps) == 2 so that we may place a dep on singles. - # Otherwise, the logic here is the same as when `len(deps) > 2` below. - # - # Let me explain why this is a special case. If we put the better dependent - # onto the inner stack, then it's guaranteed to run next. After it's run, - # then running the other dependent *may* allow their parent to be freed. 
- dep, dep2 = deps - key = partition_keys[dep] - key2 = partition_keys[dep2] - if ( - key2 < key - or key == key2 - and dependents_key(dep2) < dependents_key(dep) - ): - dep, dep2 = dep2, dep - key, key2 = key2, key - if inner_stack: - prev_key = partition_keys[inner_stack[0]] - if key2 < prev_key: - inner_stacks_append(inner_stack) - inner_stacks_append([dep2]) - inner_stack = [dep] - inner_stack_pop = inner_stack.pop - seen_update(deps) - if not num_needed[dep2]: - if process_singles: - later_singles_append(dep2) - else: - singles[dep2] = item - elif key < prev_key: - inner_stacks_append(inner_stack) - inner_stack = [dep] - inner_stack_pop = inner_stack.pop - seen_add(dep) - if not num_needed[dep2]: - if process_singles: - later_singles_append(dep2) - else: - singles[dep2] = item - else: - next_nodes[key2].append([dep2]) - else: - item_key = partition_keys[item] - for k, d in [(key, dep), (key2, dep2)]: - if not num_needed[d]: - if process_singles: - later_singles_append(d) - else: - singles[d] = item - else: - next_nodes[k].append([d]) - del item_key - del prev_key - else: - assert not inner_stack - if add_to_inner_stack: - inner_stack = [dep] - inner_stack_pop = inner_stack.pop - seen_add(dep) - if not num_needed[dep2]: - singles[dep2] = item - elif key == key2 and 5 * partition_keys[item] > 22 * key: - inner_stacks_append([dep2]) - seen_add(dep2) - else: - next_nodes[key2].append([dep2]) - else: - for k, d in [(key, dep), (key2, dep2)]: - next_nodes[k].append([d]) - del dep, dep2, key, key2 - else: - # Slow path :(. This requires grouping by partition_key. - dep_pools = defaultdict(set) - possible_singles = defaultdict(set) - for dep in deps: - pkey = partition_keys[dep] - if not num_needed[dep] and not process_singles: - possible_singles[pkey].add(dep) + pkey = partition_keys[dep] + all_keys.append(pkey) + + if not all_keys: + continue + + all_keys.sort() + change_target_key: tuple[int, ...] 
| None = None + if target_key is not None and all_keys[0] < target_key: + change_target_key = all_keys[0] + + new_stack = [] + for dep in deps: + pkey = partition_keys[dep] + if pkey == change_target_key: + new_stack.append(dep) + elif dep not in in_next_nodes: dep_pools[pkey].add(dep) - item_key = partition_keys[item] - if inner_stack: - # If we have an inner_stack, we need to look for a "better" path - prev_key = partition_keys[inner_stack[0]] - now_keys = [] # < inner_stack[0] - psingles = set() - for key, vals in dep_pools.items(): - if key < prev_key: - now_keys.append(key) - else: - psingles = possible_singles[key] - for s in psingles: - singles[s] = item - vals -= psingles - next_nodes[key].append(vals) - del vals, key - del psingles - if now_keys: - # Run before `inner_stack` (change tactical goal!) - inner_stacks_append(inner_stack) - if 1 < len(now_keys): - now_keys.sort(reverse=True) - for key in now_keys: - pool: set[Key] | list[Key] - pool = dep_pools[key] - if 1 < len(pool) < 100: - pool = sorted(pool, key=dependents_key, reverse=True) - inner_stacks_extend([dep] for dep in pool) - seen_update(pool) - del pool - inner_stack = inner_stacks_pop() - inner_stack_pop = inner_stack.pop - del now_keys, prev_key - else: - # If we don't have an inner_stack, then we don't need to look - # for a "better" path, but we do need traverse along dependents. - if add_to_inner_stack: - min_pool: list[Key] | set[Key] - min_key = min(dep_pools) - min_pool = dep_pools.pop(min_key) - if len(min_pool) == 1: - inner_stack = list(min_pool) - seen_update(inner_stack) - elif ( - 10 * item_key > 11 * len(min_pool) * len(min_pool) * min_key - ): - # Put all items in min_pool onto inner_stacks. - # I know this is a weird comparison. Hear me out. - # Although it is often beneficial to put all of the items in `min_pool` - # onto `inner_stacks` to process next, it is very easy to be overzealous. - # Sometimes it is actually better to defer until `next_nodes` is handled. 
- # We should only put items onto `inner_stacks` that we're reasonably - # confident about. The above formula is a best effort heuristic given - # what we have easily available. It is obviously very specific to our - # choice of partition_key. Dask tests take this route about 40%. - if len(min_pool) < 100: - min_pool = sorted( - min_pool, key=dependents_key, reverse=True - ) - inner_stacks_extend([dep] for dep in min_pool) - inner_stack = inner_stacks_pop() - seen_update(min_pool) - else: - # Put one item in min_pool onto inner_stack and the rest into next_nodes. - if len(min_pool) < 100: - inner_stack = [min(min_pool, key=dependents_key)] - else: - inner_stack = [min_pool.pop()] - next_nodes[min_key].append(min_pool) - seen_update(inner_stack) - del min_pool, min_key - inner_stack_pop = inner_stack.pop - for key, vals in dep_pools.items(): - psingles = possible_singles[key] - for s in psingles: - singles[s] = item - vals -= psingles - next_nodes[key].append(vals) - del key, vals - - if len(dependencies) == len(result): - break # all done! - - if next_nodes: - for key in sorted(next_nodes, reverse=True): - # `outer_stacks` may not be empty here--it has data from previous `next_nodes`. - # Since we pop things off of it (onto `inner_nodes`), this means we handle - # multiple `next_nodes` in a LIFO manner. 
- outer_stack_extend(list(el) for el in reversed(next_nodes[key])) - next_nodes.clear() - - outer_deps = [] - while outer_stack: - # Try to add a few items to `inner_stacks` - outer_deps = [x for x in outer_stack_pop() if x not in result] - if outer_deps: - if 1 < len(outer_deps) < 100: - outer_deps.sort(key=dependents_key, reverse=True) - inner_stacks_extend([dep] for dep in outer_deps) - seen_update(outer_deps) - break - del outer_deps - if inner_stacks: + if new_stack: + assert change_target_key is not None + assert target_key is not None + next_nodes[target_key].update(inner_stack) + heappush(min_key_next_nodes, target_key) + inner_stack = sorted(new_stack, key=dependents_key, reverse=True) + inner_stack_pop = inner_stack.pop + seen_update(inner_stack) + + for pkey in reversed(all_keys): + next_nodes[pkey].update(dep_pools[pkey]) + in_next_nodes.update(dep_pools[pkey]) + heappush(min_key_next_nodes, pkey) + + dep_pools.clear() + + process_runnables(layers_loaded) + layers_loaded = 0 + + if next_nodes and not inner_stack: + # there may be duplicates on the heap + min_key = heappop(min_key_next_nodes) + while min_key not in next_nodes: + min_key = heappop(min_key_next_nodes) + next_stack = next_nodes.pop(min_key) + next_stack = next_stack.difference(result) + # We have to sort the inner_stack but sorting is + # on average O(n log n). Particularly with the custom key + # `dependents_key`, this sorting operation can be quite expensive + # and dominate the entire ordering. + # There is also no guarantee that even if we sorted the entire + # stack, that we can actually process it until the end since there + # is logic that will switch the stack if a better target is found. + # Therefore, in case of large stacks, we break it up and take only + # the best nodes. 
This runs in linear time and will possibly allow + # us to release a couple of dangling runnables or find a better + # target before we come back to process the next batch + cutoff = 50 + if len(next_stack) > cutoff: + inner_stack = nsmallest(cutoff, list(next_stack), key=dependents_key)[ + ::-1 + ] + next_nodes[min_key].update(next_stack) + heappush(min_key_next_nodes, min_key) + else: + inner_stack = sorted(next_stack, key=dependents_key, reverse=True) + inner_stack_pop = inner_stack.pop + seen_update(inner_stack) continue - # We just finished computing a connected group. - # Let's choose the first `item` in the next group to compute. - # If we have few large groups left, then it's best to find `item` by taking a minimum. - # If we have many small groups left, then it's best to sort. - # If we have many tiny groups left, then it's best to simply iterate. + if inner_stack: + continue + + if len(result) == len(dsk): + break + if not is_init_sorted: - prev_len = len(init_stack) init_stack = set(init_stack) - init_stack = set_difference(init_stack, result) - N = len(init_stack) - m = prev_len - N - # is `min` likely better than `sort`? 
- if m >= N or N + (N - m) * log(N - m) < N * log(N): - item = min(init_stack, key=initial_stack_key) - continue - + init_stack = init_stack.difference(result) if len(init_stack) < 10000: init_stack = sorted(init_stack, key=initial_stack_key, reverse=True) else: init_stack = list(init_stack) - init_stack_pop = init_stack.pop is_init_sorted = True - if item in root_nodes: - item = init_stack_pop() - - while item in result: - item = init_stack_pop() - inner_stack.append(item) - + inner_stack = [init_stack.pop()] # type: ignore[call-overload] + inner_stack_pop = inner_stack.pop return result def graph_metrics( dependencies: Mapping[Key, set[Key]], dependents: Mapping[Key, set[Key]], - total_dependencies: Mapping[Key, int], -) -> dict[Key, tuple[int, int, int, int, int]]: +) -> dict[Key, tuple[int, int, int]]: r"""Useful measures of a graph used by ``dask.order.order`` Example DAG (a1 has no dependencies; b2 and c1 are root nodes): @@ -791,29 +435,7 @@ def graph_metrics( \ / 4 - 2. **min_dependencies**: The minimum value of the total number of - dependencies of all final dependents (see module-level comment for more). - In other words, the minimum of ``ndependencies`` of root - nodes connected to the current node. - - 3 - | - 3 2 - \ / - 2 - - 3. **max_dependencies**: The maximum value of the total number of - dependencies of all final dependents (see module-level comment for more). - In other words, the maximum of ``ndependencies`` of root - nodes connected to the current node. - - 3 - | - 3 2 - \ / - 3 - - 4. **min_height**: The minimum height from a root node + 2. **min_height**: The minimum height from a root node 0 | @@ -821,7 +443,7 @@ def graph_metrics( \ / 1 - 5. **max_height**: The maximum height from a root node + 3. 
**max_height**: The maximum height from a root node 0 | @@ -834,10 +456,9 @@ def graph_metrics( >>> inc = lambda x: x + 1 >>> dsk = {'a1': 1, 'b1': (inc, 'a1'), 'b2': (inc, 'a1'), 'c1': (inc, 'b1')} >>> dependencies, dependents = get_deps(dsk) - >>> _, total_dependencies = ndependencies(dependencies, dependents) - >>> metrics = graph_metrics(dependencies, dependents, total_dependencies) + >>> metrics = graph_metrics(dependencies, dependents) >>> sorted(metrics.items()) - [('a1', (4, 2, 3, 1, 2)), ('b1', (2, 3, 3, 1, 1)), ('b2', (1, 2, 2, 0, 0)), ('c1', (1, 3, 3, 0, 0))] + [('a1', (4, 1, 2)), ('b1', (2, 1, 1)), ('b2', (1, 0, 0)), ('c1', (1, 0, 0))] Returns ------- @@ -850,8 +471,7 @@ def graph_metrics( current_append = current.append for key, deps in dependents.items(): if not deps: - val = total_dependencies[key] - result[key] = (1, val, val, 0, 0) + result[key] = (1, 0, 0) for child in dependencies[key]: num_needed[child] -= 1 if not num_needed[child]: @@ -864,30 +484,22 @@ def graph_metrics( (parent,) = parents ( total_dependents, - min_dependencies, - max_dependencies, min_heights, max_heights, ) = result[parent] result[key] = ( 1 + total_dependents, - min_dependencies, - max_dependencies, 1 + min_heights, 1 + max_heights, ) else: ( total_dependents_, - min_dependencies_, - max_dependencies_, min_heights_, max_heights_, ) = zip(*(result[parent] for parent in dependents[key])) result[key] = ( 1 + sum(total_dependents_), - min(min_dependencies_), - max(max_dependencies_), 1 + min(min_heights_), 1 + max(max_heights_), ) @@ -1068,8 +680,10 @@ def _convert_task(task: Any) -> Any: elif isinstance(el, list): new_spec.append([_convert_task(e) for e in el]) return (_f, *new_spec) + elif isinstance(task, tuple): + return (_f, task) else: - return task + return (_f, *task) def sanitize_dsk(dsk: MutableMapping[Key, Any]) -> dict: diff --git a/dask/tests/test_order.py b/dask/tests/test_order.py index de43cc14a37..6c85fcc4bb7 100644 --- a/dask/tests/test_order.py +++ 
b/dask/tests/test_order.py @@ -3,21 +3,23 @@ import pytest import dask -from dask.base import collections_to_dsk +from dask import delayed +from dask.base import collections_to_dsk, key_split from dask.core import get_deps from dask.order import diagnostics, ndependencies, order from dask.utils_test import add, inc -@pytest.fixture(params=["abcde", "edcba"]) +@pytest.fixture( + params=[ + "abcde", + "edcba", + ] +) def abcde(request): return request.param -def issorted(L, reverse=False): - return sorted(L, reverse=reverse) == L - - def f(*args): pass @@ -589,8 +591,14 @@ def test_dont_run_all_dependents_too_early(abcde): dsk[(c, i)] = (f, (c, 0)) dsk[(d, i)] = (f, (d, i - 1), (b, i), (c, i)) o = order(dsk) - - expected = [3, 6, 9, 12, 15, 18, 21, 24, 27, 30] + # The expected numbers here stand for the central computation branch. + # Ideally, they are exactly three apart since we only want to load two nodes + # before moving forward with a central node. + # This pattern is currently broken for the next to last node since we're + # eagerly freeing up the root node before processing once there are no + # further dangling dependents. Earlier versions of dask.order behaved + # differently, but this is a net-zero memory pressure operation so it is fine. + expected = [3, 6, 9, 12, 15, 18, 21, 24, 28, 30] actual = sorted(v for (letter, num), v in o.items() if letter == d) assert expected == actual @@ -653,61 +661,6 @@ def test_order_empty(): assert order({}) == {} -def test_switching_dependents(abcde): - r""" - - a7 a8 <-- do these last - | / - a6 e6 - | / - a5 c5 d5 e5 - | | / / - a4 c4 d4 e4 - | \ | / / - a3 b3---/ - | - a2 - | - a1 - | - a0 <-- start here - - Test that we are able to switch to better dependents. - In this graph, we expect to start at a0. To compute a4, we need to compute b3. - After computing b3, three "better" paths become available. - Confirm that we take the better paths before continuing down `a` path. 
- - This test is pretty specific to how `order` is implemented - and is intended to increase code coverage. - """ - a, b, c, d, e = abcde - dsk = { - (a, 0): 0, - (a, 1): (f, (a, 0)), - (a, 2): (f, (a, 1)), - (a, 3): (f, (a, 2)), - (a, 4): (f, (a, 3), (b, 3)), - (a, 5): (f, (a, 4)), - (a, 6): (f, (a, 5)), - (a, 7): (f, (a, 6)), - (a, 8): (f, (a, 6)), - (b, 3): 1, - (c, 4): (f, (b, 3)), - (c, 5): (f, (c, 4)), - (d, 4): (f, (b, 3)), - (d, 5): (f, (d, 4)), - (e, 4): (f, (b, 3)), - (e, 5): (f, (e, 4)), - (e, 6): (f, (e, 5)), - } - o = order(dsk) - - assert o[(a, 0)] == 0 # probably - assert o[(a, 5)] > o[(c, 5)] - assert o[(a, 5)] > o[(d, 5)] - assert o[(a, 5)] > o[(e, 6)] - - def test_order_with_equal_dependents(abcde): """From https://github.com/dask/dask/issues/5859#issuecomment-608422198 @@ -719,7 +672,7 @@ def test_order_with_equal_dependents(abcde): # Lower pressure is better but this is where we are right now. Important is # that no variation below should be worse since all variations below should # reduce to the same graph when optimized/fused. 
- max_pressure = 11 + max_pressure = 10 a, b, c, d, e = abcde dsk = {} abc = [a, b, c, d] @@ -747,13 +700,12 @@ def test_order_with_equal_dependents(abcde): total = 0 for x in abc: for i in range(len(abc)): - val = o[(x, 6, i, 1)] - o[(x, 6, i, 0)] - assert val > 0 # ideally, val == 2 + val = abs(o[(x, 6, i, 1)] - o[(x, 6, i, 0)]) total += val - assert total <= 56 # ideally, this should be 2 * 16 == 32 + + assert total <= 32 # ideally, this should be 2 * 16 == 32 pressure = diagnostics(dsk, o=o)[1] assert max(pressure) <= max_pressure - # Add one to the end of the nine bundles dsk2 = dict(dsk) for x in abc: @@ -763,10 +715,9 @@ def test_order_with_equal_dependents(abcde): total = 0 for x in abc: for i in range(len(abc)): - val = o[(x, 6, i, 1)] - o[(x, 7, i, 0)] - assert val > 0 # ideally, val == 3 + val = abs(o[(x, 6, i, 1)] - o[(x, 7, i, 0)]) total += val - assert total <= 75 # ideally, this should be 3 * 16 == 48 + assert total <= 48 # ideally, this should be 3 * 16 == 48 pressure = diagnostics(dsk2, o=o)[1] assert max(pressure) <= max_pressure @@ -779,14 +730,13 @@ def test_order_with_equal_dependents(abcde): total = 0 for x in abc: for i in range(len(abc)): - val = o[(x, 5, i, 1)] - o[(x, 6, i, 0)] - assert val > 0 + val = abs(o[(x, 5, i, 1)] - o[(x, 6, i, 0)]) total += val - assert total <= 45 # ideally, this should be 2 * 16 == 32 + assert total <= 32 # ideally, this should be 2 * 16 == 32 pressure = diagnostics(dsk3, o=o)[1] assert max(pressure) <= max_pressure - # Remove another one from each of the nine bundles + # # Remove another one from each of the nine bundles dsk4 = dict(dsk3) for x in abc: for i in range(len(abc)): @@ -796,7 +746,7 @@ def test_order_with_equal_dependents(abcde): assert max(pressure) <= max_pressure for x in abc: for i in range(len(abc)): - assert abs(o[(x, 5, i, 1)] - o[(x, 5, i, 0)]) <= 10 + assert abs(o[(x, 5, i, 1)] - o[(x, 5, i, 0)]) <= 2 def test_terminal_node_backtrack(): @@ -872,7 +822,6 @@ def 
test_array_store_final_order(tmpdir): dest = root.empty_like(name="dest", data=x, chunks=x.chunksize, overwrite=True) d = x.store(dest, lock=False, compute=False) o = order(d.dask) - # Find the lowest store. Dask starts here. stores = [k for k in o if isinstance(k, tuple) and k[0].startswith("store-map-")] first_store = min(stores, key=lambda k: o[k]) @@ -1010,6 +959,7 @@ def test_diagnostics(abcde): (e, 1): (f, (e, 0)), } o = order(dsk) + assert o[(e, 1)] == len(dsk) - 1 assert o[(d, 1)] == len(dsk) - 2 assert o[(c, 1)] == len(dsk) - 3 @@ -1214,7 +1164,6 @@ def test_anom_mean_raw(): } o = order(dsk) - # The left hand computation branch should complete before we start loading # more data nodes_to_finish_before_loading_more_data = [ @@ -1668,3 +1617,43 @@ def test_flox_reduction(): } o = order(dsk) assert max(o[("F1", ix)] for ix in range(3)) < min(o[("F2", ix)] for ix in range(3)) + + +@pytest.mark.parametrize("ndeps", [2, 5]) +@pytest.mark.parametrize("n_reducers", [4, 7]) +def test_reduce_with_many_common_dependents(ndeps, n_reducers): + da = pytest.importorskip("dask.array") + import numpy as np + + def random(**kwargs): + assert len(kwargs) == ndeps + return np.random.random((10, 10)) + + trivial_deps = { + f"k{i}": delayed(object(), name=f"object-{i}") for i in range(ndeps) + } + x = da.blockwise( + random, + "yx", + new_axes={"y": (10,) * n_reducers, "x": (10,) * n_reducers}, + dtype=float, + **trivial_deps, + ) + graph = x.sum(axis=1, split_every=20) + from dask.order import order + + dsk = collections_to_dsk([graph]) + dependencies, dependents = get_deps(dsk) + # Verify assumptions + o = order(dsk) + # Verify assumptions (specifically that the reducers are sum-aggregate) + assert {key_split(k) for k in o} == {"object", "sum", "sum-aggregate"} + + reducers = {k for k in o if key_split(k) == "sum-aggregate"} + drift = dict() + for r in reducers: + prios_deps = [] + for dep in dependencies[r]: + prios_deps.append(o[dep]) + drift[r] = (min(prios_deps), 
max(prios_deps)) + assert max(prios_deps) - min(prios_deps) == len(dependencies[r]) - 1