diff --git a/dask/__init__.py b/dask/__init__.py index 41f65daa415..aec1a776f62 100644 --- a/dask/__init__.py +++ b/dask/__init__.py @@ -1,4 +1,5 @@ from . import config, datasets + from .core import istask from .local import get_sync as get @@ -7,7 +8,14 @@ except ImportError: pass try: - from .base import visualize, compute, persist, optimize, is_dask_collection + from .base import ( + visualize, + annotate, + compute, + persist, + optimize, + is_dask_collection, + ) except ImportError: pass diff --git a/dask/base.py b/dask/base.py index 49961d70537..8c05d575c64 100644 --- a/dask/base.py +++ b/dask/base.py @@ -1,5 +1,6 @@ from collections import OrderedDict from collections.abc import Mapping, Iterator +from contextlib import contextmanager from functools import partial from hashlib import md5 from operator import getitem @@ -23,6 +24,7 @@ __all__ = ( "DaskMethodsMixin", + "annotate", "is_dask_collection", "compute", "persist", @@ -33,6 +35,52 @@ ) + +@contextmanager +def annotate(**annotations): + """Context Manager for setting HighLevelGraph Layer annotations. + + Annotations are metadata or soft constraints associated with + tasks that dask schedulers may choose to respect: They signal intent + without enforcing hard constraints. As such, they are + primarily designed for use with the distributed scheduler. + + Almost any object can serve as an annotation, but small Python objects + are preferred, while large objects such as NumPy arrays are discouraged. + + Callables supplied as an annotation should take a single *key* argument and + produce the appropriate annotation. Individual task keys in the annotated collection + are supplied to the callable. + + Parameters + ---------- + **annotations : key-value pairs + + Examples + -------- + + All tasks within array A should have priority 100 and be retried 3 times + on failure. 
+ + >>> with dask.annotate(priority=100, retries=3): + ... A = da.ones((10000, 10000)) + + Prioritise tasks within Array A on flattened block ID. + + >>> nblocks = (10, 10) + >>> with dask.annotate(priority=lambda k: k[1]*nblocks[1] + k[2]): + ... A = da.ones((1000, 1000), chunks=(100, 100)) + """ + + prev_annotations = config.get("annotations", {}) + new_annotations = { + **prev_annotations, + **{f"annotations.{k}": v for k, v in annotations.items()}, + } + + with config.set(new_annotations): + yield + + def is_dask_collection(x): """Returns ``True`` if ``x`` is a dask collection""" try: @@ -95,7 +143,7 @@ def visualize(self, filename="mydask", format=None, optimize_graph=False, **kwar filename=filename, format=format, optimize_graph=optimize_graph, - **kwargs + **kwargs, ) def persist(self, **kwargs): diff --git a/dask/blockwise.py b/dask/blockwise.py index 10c10083ab3..44c2ca3546c 100644 --- a/dask/blockwise.py +++ b/dask/blockwise.py @@ -210,6 +210,7 @@ def __init__( new_axes=None, io_subgraph=None, ): + super().__init__() self.output = output self.output_indices = tuple(output_indices) self.io_subgraph = io_subgraph[1] if io_subgraph else None diff --git a/dask/highlevelgraph.py b/dask/highlevelgraph.py index 1d3da3906fe..9dc2a6e3053 100644 --- a/dask/highlevelgraph.py +++ b/dask/highlevelgraph.py @@ -15,6 +15,7 @@ import tlz as toolz +from . import config from .utils import ignoring from .base import is_dask_collection from .core import reverse_dict, keys_in_tasks @@ -54,6 +55,9 @@ class Layer(collections.abc.Mapping): implementations. 
""" + def __init__(self): + self.annotations = copy.copy(config.get("annotations", None)) + @abc.abstractmethod def is_materialized(self) -> bool: """Return whether the layer is materialized or not""" @@ -246,6 +250,7 @@ class BasicLayer(Layer): """ def __init__(self, mapping, dependencies=None, global_dependencies=None): + super().__init__() self.mapping = mapping self.dependencies = dependencies self.global_dependencies = global_dependencies diff --git a/dask/tests/test_highgraph.py b/dask/tests/test_highgraph.py index 504ccd82c10..08fa457de4b 100644 --- a/dask/tests/test_highgraph.py +++ b/dask/tests/test_highgraph.py @@ -3,6 +3,7 @@ import pytest +import dask import dask.array as da from dask.utils_test import inc from dask.highlevelgraph import HighLevelGraph, BasicLayer, Layer @@ -110,3 +111,42 @@ def plus_one(tasks): y.dask = dsk.map_tasks(plus_one) assert_eq(y, [42] * 3) + + +def annot_map_fn(key): + return key[1:] + + +@pytest.mark.parametrize( + "annotation", + [ + {"worker": "alice"}, + {"block_id": annot_map_fn}, + ], +) +def test_single_annotation(annotation): + with dask.annotate(**annotation): + A = da.ones((10, 10), chunks=(5, 5)) + + alayer = A.__dask_graph__().layers[A.name] + assert alayer.annotations == annotation + assert dask.config.get("annotations", None) is None + + +def test_multiple_annotations(): + with dask.annotate(block_id=annot_map_fn): + with dask.annotate(resource="GPU"): + A = da.ones((10, 10), chunks=(5, 5)) + + B = A + 1 + + C = B + 1 + + assert dask.config.get("annotations", None) is None + + alayer = A.__dask_graph__().layers[A.name] + blayer = B.__dask_graph__().layers[B.name] + clayer = C.__dask_graph__().layers[C.name] + assert alayer.annotations == {"resource": "GPU", "block_id": annot_map_fn} + assert blayer.annotations == {"block_id": annot_map_fn} + assert clayer.annotations is None diff --git a/docs/source/api.rst b/docs/source/api.rst index af0a6cce425..c4be7735f8a 100644 --- a/docs/source/api.rst +++ 
b/docs/source/api.rst @@ -46,6 +46,7 @@ real-time or advanced operation. This more advanced API is available in the `Dask distributed documentation `_ +.. autofunction:: annotate .. autofunction:: compute .. autofunction:: is_dask_collection .. autofunction:: optimize