diff --git a/dask/array/optimization.py b/dask/array/optimization.py index a6ae0b985fe..ff113a2fa71 100644 --- a/dask/array/optimization.py +++ b/dask/array/optimization.py @@ -6,7 +6,7 @@ from .core import getter, getter_nofancy, getter_inline from ..blockwise import optimize_blockwise, fuse_roots from ..core import flatten, reverse_dict -from ..optimization import cull, fuse, inline_functions +from ..optimization import fuse, inline_functions from ..utils import ensure_dict from ..highlevelgraph import HighLevelGraph @@ -35,39 +35,40 @@ def optimize( 2. Remove full slicing, e.g. x[:] 3. Inline fast functions like getitem and np.transpose """ + if not isinstance(keys, (list, set)): + keys = [keys] keys = list(flatten(keys)) - # High level stage optimization - if isinstance(dsk, HighLevelGraph): - dsk = optimize_blockwise(dsk, keys=keys) - dsk = fuse_roots(dsk, keys=keys) + if not isinstance(dsk, HighLevelGraph): + dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=()) - # Low level task optimizations + dsk = optimize_blockwise(dsk, keys=keys) + dsk = fuse_roots(dsk, keys=keys) + dsk = dsk.cull(set(keys)) + dependencies = dsk.get_dependencies() dsk = ensure_dict(dsk) + + # Low level task optimizations if fast_functions is not None: inline_functions_fast_functions = fast_functions - dsk2, dependencies = cull(dsk, keys) - hold = hold_keys(dsk2, dependencies) + hold = hold_keys(dsk, dependencies) - dsk3, dependencies = fuse( - dsk2, + dsk, dependencies = fuse( + dsk, hold + keys + (fuse_keys or []), dependencies, rename_keys=rename_fused_keys, ) if inline_functions_fast_functions: - dsk4 = inline_functions( - dsk3, + dsk = inline_functions( + dsk, keys, dependencies=dependencies, fast_functions=inline_functions_fast_functions, ) - else: - dsk4 = dsk3 - dsk5 = optimize_slices(dsk4) - return dsk5 + return optimize_slices(dsk) def hold_keys(dsk, dependencies): diff --git a/dask/blockwise.py b/dask/blockwise.py index b634563d06d..5d1ad8c2fb8 100644 --- 
a/dask/blockwise.py +++ b/dask/blockwise.py @@ -1,14 +1,13 @@ import itertools import warnings -from collections.abc import Mapping import numpy as np import tlz as toolz -from .core import reverse_dict, keys_in_tasks +from .core import reverse_dict, flatten, keys_in_tasks, find_all_possible_keys from .delayed import unpack_collections -from .highlevelgraph import HighLevelGraph +from .highlevelgraph import BasicLayer, HighLevelGraph, Layer from .optimization import SubgraphCallable, fuse from .utils import ensure_dict, homogeneous_deepmap, apply @@ -132,7 +131,7 @@ def blockwise( return subgraph -class Blockwise(Mapping): +class Blockwise(Layer): """Tensor Operation This is a lazily constructed mapping for tensor operation graphs. @@ -206,21 +205,30 @@ def __repr__(self): @property def _dict(self): if hasattr(self, "_cached_dict"): - return self._cached_dict + return self._cached_dict["dsk"] else: keys = tuple(map(blockwise_token, range(len(self.indices)))) dsk, _ = fuse(self.dsk, [self.output]) func = SubgraphCallable(dsk, self.output, keys) - self._cached_dict = make_blockwise_graph( + + key_deps = {} + non_blockwise_keys = set() + dsk = make_blockwise_graph( func, self.output, self.output_indices, *list(toolz.concat(self.indices)), new_axes=self.new_axes, numblocks=self.numblocks, - concatenate=self.concatenate + concatenate=self.concatenate, + key_deps=key_deps, + non_blockwise_keys=non_blockwise_keys, ) - return self._cached_dict + self._cached_dict = { + "dsk": dsk, + "basic_layer": BasicLayer(dsk, key_deps, non_blockwise_keys), + } + return self._cached_dict["dsk"] def __getitem__(self, key): return self._dict[key] @@ -243,6 +251,14 @@ def _out_numblocks(self): return out_d + def get_dependencies(self, all_hlg_keys): + _ = self._dict # trigger materialization + return self._cached_dict["basic_layer"].get_dependencies(all_hlg_keys) + + def cull(self, keys, all_hlg_keys): + _ = self._dict # trigger materialization + return 
self._cached_dict["basic_layer"].cull(keys, all_hlg_keys) + def make_blockwise_graph(func, output, out_indices, *arrind_pairs, **kwargs): """Tensor operation @@ -353,6 +369,8 @@ def make_blockwise_graph(func, output, out_indices, *arrind_pairs, **kwargs): numblocks = kwargs.pop("numblocks") concatenate = kwargs.pop("concatenate", None) new_axes = kwargs.pop("new_axes", {}) + key_deps = kwargs.pop("key_deps", None) + non_blockwise_keys = kwargs.pop("non_blockwise_keys", None) argpairs = list(toolz.partition(2, arrind_pairs)) if concatenate is True: @@ -417,6 +435,7 @@ def make_blockwise_graph(func, output, out_indices, *arrind_pairs, **kwargs): else: coord_maps.append(None) concat_axes.append(None) + # Unpack delayed objects in kwargs dsk2 = {} if kwargs: @@ -425,33 +444,47 @@ def make_blockwise_graph(func, output, out_indices, *arrind_pairs, **kwargs): kwargs2 = task else: kwargs2 = kwargs + if non_blockwise_keys is not None: + non_blockwise_keys |= find_all_possible_keys([kwargs2]) + + # Find all non-blockwise keys in the input arguments + if non_blockwise_keys is not None: + for arg, ind in argpairs: + if ind is None: + non_blockwise_keys |= find_all_possible_keys([arg]) dsk = {} # Create argument lists for out_coords in itertools.product(*[range(dims[i]) for i in out_indices]): + deps = set() coords = out_coords + dummies args = [] - for cmap, axes, arg_ind in zip(coord_maps, concat_axes, argpairs): - arg, ind = arg_ind + for cmap, axes, (arg, ind) in zip(coord_maps, concat_axes, argpairs): if ind is None: args.append(arg) else: arg_coords = tuple(coords[c] for c in cmap) if axes: tups = lol_product((arg,), arg_coords) + deps.update(flatten(tups)) if concatenate: tups = (concatenate, tups, axes) else: tups = (arg,) + arg_coords + deps.add(tups) args.append(tups) + out_key = (output,) + out_coords + if kwargs: val = (apply, func, args, kwargs2) else: args.insert(0, func) val = tuple(args) - dsk[(output,) + out_coords] = val + dsk[out_key] = val + if key_deps is 
not None: + key_deps[out_key] = deps if dsk2: dsk.update(ensure_dict(dsk2)) diff --git a/dask/core.py b/dask/core.py index 533bba86d03..8fb998afeb4 100644 --- a/dask/core.py +++ b/dask/core.py @@ -191,6 +191,34 @@ def keys_in_tasks(keys, tasks, as_list=False): return ret if as_list else set(ret) +def find_all_possible_keys(tasks) -> set: + """Returns all possible keys in `tasks` including hashable literals. + + The definition of a key in a Dask graph is any hashable object + that is not a task. This function returns all such objects in + `tasks` even if the object is in fact a literal. + + """ + ret = set() + while tasks: + work = [] + for w in tasks: + typ = type(w) + if typ is tuple and w and callable(w[0]): # istask(w) + work.extend(w[1:]) + elif typ is list: + work.extend(w) + elif typ is dict: + work.extend(w.values()) + else: + try: + ret.add(w) + except TypeError: # not hashable + pass + tasks = work + return ret + + def get_dependencies(dsk, key=None, task=no_default, as_list=False): """Get the immediate tasks on which this task depends diff --git a/dask/dataframe/io/parquet/core.py b/dask/dataframe/io/parquet/core.py index faefcbc6806..e151fa361ab 100644 --- a/dask/dataframe/io/parquet/core.py +++ b/dask/dataframe/io/parquet/core.py @@ -9,8 +9,8 @@ from ...core import DataFrame, new_dd_object from ....base import tokenize from ....utils import import_required, natural_sort_key, parse_bytes -from collections.abc import Mapping from ...methods import concat +from ....highlevelgraph import Layer try: @@ -29,14 +29,16 @@ # User API -class ParquetSubgraph(Mapping): +class ParquetSubgraph(Layer): """ Subgraph for reading Parquet files. Enables optimizations (see optimize_read_parquet_getitem). 
""" - def __init__(self, name, engine, fs, meta, columns, index, parts, kwargs): + def __init__( + self, name, engine, fs, meta, columns, index, parts, kwargs, part_ids=None + ): self.name = name self.engine = engine self.fs = fs @@ -45,10 +47,11 @@ def __init__(self, name, engine, fs, meta, columns, index, parts, kwargs): self.index = index self.parts = parts self.kwargs = kwargs + self.part_ids = list(range(len(parts))) if part_ids is None else part_ids def __repr__(self): return "ParquetSubgraph".format( - self.name, len(self.parts), list(self.columns) + self.name, len(self.part_ids), list(self.columns) ) def __getitem__(self, key): @@ -61,7 +64,7 @@ def __getitem__(self, key): if name != self.name: raise KeyError(key) - if i < 0 or i >= len(self.parts): + if i not in self.part_ids: raise KeyError(key) part = self.parts[i] @@ -80,12 +83,29 @@ def __getitem__(self, key): ) def __len__(self): - return len(self.parts) + return len(self.part_ids) def __iter__(self): - for i in range(len(self)): + for i in self.part_ids: yield (self.name, i) + def get_dependencies(self, all_hlg_keys): + return {k: set() for k in self} + + def cull(self, keys, all_hlg_keys): + ret = ParquetSubgraph( + name=self.name, + engine=self.engine, + fs=self.fs, + meta=self.meta, + columns=self.columns, + index=self.index, + parts=self.parts, + kwargs=self.kwargs, + part_ids={i for i in self.part_ids if (self.name, i) in keys}, + ) + return ret, ret.get_dependencies(all_hlg_keys) + def read_parquet( path, diff --git a/dask/dataframe/optimize.py b/dask/dataframe/optimize.py index 6cae4688d86..e66dfca7640 100644 --- a/dask/dataframe/optimize.py +++ b/dask/dataframe/optimize.py @@ -10,25 +10,28 @@ def optimize(dsk, keys, **kwargs): + if not isinstance(keys, (list, set)): + keys = [keys] + keys = list(core.flatten(keys)) - if isinstance(dsk, HighLevelGraph): - # Think about an API for this. 
- flat_keys = list(core.flatten(keys)) - dsk = optimize_read_parquet_getitem(dsk, keys=flat_keys) - dsk = optimize_blockwise(dsk, keys=flat_keys) - dsk = fuse_roots(dsk, keys=flat_keys) + if not isinstance(dsk, HighLevelGraph): + dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=()) - dsk = ensure_dict(dsk) + dsk = optimize_read_parquet_getitem(dsk, keys=keys) + dsk = optimize_blockwise(dsk, keys=keys) + dsk = fuse_roots(dsk, keys=keys) + dsk = dsk.cull(set(keys)) + + if not config.get("optimization.fuse.active"): + return dsk - if isinstance(keys, list): - dsk, dependencies = cull(dsk, list(core.flatten(keys))) - else: - dsk, dependencies = cull(dsk, [keys]) + dependencies = dsk.get_dependencies() + dsk = ensure_dict(dsk) fuse_subgraphs = config.get("optimization.fuse.subgraphs") if fuse_subgraphs is None: fuse_subgraphs = True - dsk, dependencies = fuse( + dsk, _ = fuse( dsk, keys, dependencies=dependencies, diff --git a/dask/highlevelgraph.py b/dask/highlevelgraph.py index a0f772699a2..5a492a06cce 100644 --- a/dask/highlevelgraph.py +++ b/dask/highlevelgraph.py @@ -1,10 +1,13 @@ -from collections.abc import Mapping +import collections.abc +from typing import Hashable, Optional, Set, Mapping, Iterable, Tuple +import copy import tlz as toolz from .utils import ignoring from .base import is_dask_collection from .core import reverse_dict, keys_in_tasks +from .utils_test import add, inc # noqa: F401 def compute_layer_dependencies(layers): @@ -24,6 +27,109 @@ def _find_layer_containing_key(key): return ret +class Layer(collections.abc.Mapping): + """High level graph layer + + This abstract class establishes a protocol for high level graph layers. + """ + + def cull( + self, keys: Set, all_hlg_keys: Iterable + ) -> Tuple["Layer", Mapping[Hashable, Set]]: + """Return a new Layer with only the tasks required to calculate `keys`. + + In other words, remove unnecessary tasks from the layer. 
+ + Examples + -------- + >>> d = Layer({'x': 1, 'y': (inc, 'x'), 'out': (add, 'x', 10)}) # doctest: +SKIP + >>> d.cull({'out'}) # doctest: +SKIP + {'x': 1, 'out': (add, 'x', 10)} + + Returns + ------- + layer: Layer + Culled layer + """ + deps = self.get_dependencies(all_hlg_keys) + + if len(keys) == len(self): + return self, deps # Nothing to cull if preserving all existing keys + + ret_deps = {} + seen = set() + out = {} + work = keys.copy() + while work: + k = work.pop() + out[k] = self[k] + ret_deps[k] = deps[k] + for d in deps[k]: + if d not in seen: + if d in self: + seen.add(d) + work.add(d) + + return BasicLayer(out), ret_deps + + def get_dependencies(self, all_hlg_keys: Iterable) -> Mapping[Hashable, Set]: + """Get dependencies of all keys in the layer + + Parameters + ---------- + all_hlg_keys : Iterable + All keys in the high level graph. + + Returns + ------- + map: Mapping + A map that maps each key in the layer to its dependencies + """ + return {k: keys_in_tasks(all_hlg_keys, [v]) for k, v in self.items()} + + +class BasicLayer(Layer): + """Basic implementation of `Layer` + + Parameters + ---------- + mapping : Mapping + The mapping between keys and tasks, typically a dask graph. + dependencies : Mapping[Hashable, Set], optional + Mapping between keys and their dependencies + global_dependencies: Set, optional + Set of dependencies that all keys in the layer depend on. Notice, + the set might also contain literals that will be ignored. 
+ """ + + def __init__(self, mapping, dependencies=None, global_dependencies=None): + self.mapping = mapping + self.dependencies = dependencies + self.global_dependencies = global_dependencies + + def __contains__(self, k): + return k in self.mapping + + def __getitem__(self, k): + return self.mapping[k] + + def __iter__(self): + return iter(self.mapping) + + def __len__(self): + return len(self.mapping) + + def get_dependencies(self, all_hlg_keys): + if self.dependencies is None or self.global_dependencies is None: + return super().get_dependencies(all_hlg_keys) + + global_deps = self.global_dependencies.intersection(all_hlg_keys) + ret = self.dependencies.copy() + for v in ret.values(): + v |= global_deps + return ret + + class HighLevelGraph(Mapping): """Task graph composed of layers of dependent subgraphs @@ -88,9 +194,23 @@ class HighLevelGraph(Mapping): typically used by developers to make new HighLevelGraphs """ - def __init__(self, layers, dependencies): + def __init__( + self, + layers: Mapping[str, Mapping], + dependencies: Mapping[str, Set], + key_dependencies: Optional[Mapping[Hashable, Set]] = None, + ): + self.__keys = None self.layers = layers self.dependencies = dependencies + self.key_dependencies = key_dependencies + + def keyset(self): + if self.__keys is None: + self.__keys = set() + for layer in self.layers.values(): + self.__keys.update(layer.keys()) + return self.__keys @property def dependents(self): @@ -162,8 +282,7 @@ def from_collections(cls, name, layer, dependencies=()): if len(dependencies) == 1: return cls._from_collection(name, layer, dependencies[0]) layers = {name: layer} - deps = {} - deps[name] = set() + deps = {name: set()} for collection in toolz.unique(dependencies, key=id): if is_dask_collection(collection): graph = collection.__dask_graph__() @@ -192,7 +311,10 @@ def __getitem__(self, key): raise KeyError(key) def __len__(self): - return sum(1 for _ in self) + return len(self.keyset()) + + def __iter__(self): + return 
toolz.unique(toolz.concat(self.layers.values())) def items(self): items = [] @@ -204,15 +326,15 @@ def items(self): items.append((key, d[key])) return items - def __iter__(self): - return toolz.unique(toolz.concat(self.layers.values())) - def keys(self): return [key for key, _ in self.items()] def values(self): return [value for _, value in self.items()] + def copy(self): + return HighLevelGraph(self.layers.copy(), self.dependencies.copy()) + @classmethod def merge(cls, *graphs): layers = {} @@ -234,6 +356,97 @@ def visualize(self, filename="dask.pdf", format=None, **kwargs): g = to_graphviz(self, **kwargs) return graphviz_to_file(g, filename, format) + def get_dependencies(self) -> Mapping[Hashable, Set]: + """Get dependencies of all keys in the HLG + + Returns + ------- + map: Mapping + A map that maps each key to its dependencies + """ + if self.key_dependencies is None: + all_keys = self.keyset() + self.key_dependencies = {} + for layer in self.layers.values(): + self.key_dependencies.update(layer.get_dependencies(all_keys)) + + return self.key_dependencies + + def _fix_hlg_layers_inplace(self): + """Makes sure that all layers in hlg are `Layer`""" + new_layers = {} + for k, v in self.layers.items(): + if not isinstance(v, Layer): + new_layers[k] = BasicLayer(v) + self.layers.update(new_layers) + + def _toposort_layers(self): + """Sort the layers in a high level graph topologically + + Parameters + ---------- + hlg : HighLevelGraph + The high level graph's layers to sort + + Returns + ------- + sorted: list + List of layer names sorted topologically + """ + dependencies = copy.deepcopy(self.dependencies) + ready = {k for k, v in dependencies.items() if len(v) == 0} + ret = [] + while len(ready) > 0: + layer = ready.pop() + ret.append(layer) + del dependencies[layer] + for k, v in dependencies.items(): + v.discard(layer) + if len(v) == 0: + ready.add(k) + return ret + + def cull(self, keys: Set): + """Return new high level graph with only the tasks required to 
calculate keys. + + In other words, remove unnecessary tasks from dask. + ``keys`` must be a set of keys; it is updated in place while culling. + + Returns + ------- + hlg: HighLevelGraph + Culled high level graph + """ + + self._fix_hlg_layers_inplace() + layers = self._toposort_layers() + all_keys = self.keyset() + + ret_layers = {} + ret_key_deps = {} + for layer_name in reversed(layers): + layer = self.layers[layer_name] + key_deps = keys.intersection(layer) + if len(key_deps) > 0: + culled_layer, culled_deps = layer.cull(key_deps, all_keys) + + external_deps = set() + for k in culled_layer.keys(): + external_deps |= culled_deps[k] + external_deps.difference_update(culled_layer.keys()) + + keys.update(external_deps) + ret_layers[layer_name] = culled_layer + ret_key_deps.update(culled_deps) + + ret_dependencies = {} + for layer_name in ret_layers: + ret_dependencies[layer_name] = { + d for d in self.dependencies[layer_name] if d in ret_layers + } + + return HighLevelGraph(ret_layers, ret_dependencies, ret_key_deps) + def validate(self): # Check dependencies for layer_name, deps in self.dependencies.items(): diff --git a/dask/tests/test_highgraph.py b/dask/tests/test_highgraph.py index 33027c4d128..0c04bf1bec5 100644 --- a/dask/tests/test_highgraph.py +++ b/dask/tests/test_highgraph.py @@ -4,7 +4,7 @@ import dask.array as da from dask.utils_test import inc -from dask.highlevelgraph import HighLevelGraph +from dask.highlevelgraph import HighLevelGraph, BasicLayer def test_visualize(tmpdir): @@ -40,3 +40,20 @@ def test_keys_values_items_methods(): assert keys == [i for i in hg] assert values == [hg[i] for i in hg] assert items == [(k, v) for k, v in zip(keys, values)] + + +def test_cull(): + a = {"x": 1, "y": (inc, "x")} + layers = { + "a": BasicLayer( + a, dependencies={"x": set(), "y": {"x"}}, global_dependencies=set() + ) + } + dependencies = {"a": set()} + hg = HighLevelGraph(layers, dependencies) + + culled_by_x = hg.cull({"x"}) + assert dict(culled_by_x) == {"x": 1} + + culled_by_y 
= hg.cull({"y"}) + assert dict(culled_by_y) == a