99
1010import tlz as toolz
1111
12+ import dask
1213from dask .base import clone_key , get_name_from_key , tokenize
1314from dask .core import flatten , keys_in_tasks , reverse_dict
1415from dask .delayed import unpack_collections
@@ -1351,10 +1352,9 @@ def _optimize_blockwise(full_graph, keys=()):
13511352 ):
13521353 stack .append (dep )
13531354 continue
1354- if (
1355- blockwise_layers
1356- and layers [next (iter (blockwise_layers ))].annotations
1357- != layers [dep ].annotations
1355+ if blockwise_layers and not _can_fuse_annotations (
1356+ layers [next (iter (blockwise_layers ))].annotations ,
1357+ layers [dep ].annotations ,
13581358 ):
13591359 stack .append (dep )
13601360 continue
@@ -1412,6 +1412,60 @@ def _unique_dep(dep, ind):
14121412 return dep + "_" + "_" .join (str (i ) for i in list (ind ))
14131413
14141414
1415+ def _can_fuse_annotations (a : dict | None , b : dict | None ) -> bool :
1416+ """
1417+ Treat the special annotation keys, as fusable since we can apply simple
1418+ rules to capture their intent in a fused layer.
1419+ """
1420+ if a == b :
1421+ return True
1422+
1423+ if dask .config .get ("optimization.annotations.fuse" ) is False :
1424+ return False
1425+
1426+ fusable = {"retries" , "priority" , "resources" , "workers" , "allow_other_workers" }
1427+ if (not a or all (k in fusable for k in a )) and (
1428+ not b or all (k in fusable for k in b )
1429+ ):
1430+ return True
1431+
1432+ return False
1433+
1434+
1435+ def _fuse_annotations (* args : dict ) -> dict :
1436+ """
1437+ Given an iterable of annotations dictionaries, fuse them according
1438+ to some simple rules.
1439+ """
1440+ # First, do a basic dict merge -- we are presuming that these have already
1441+ # been gated by `_can_fuse_annotations`.
1442+ annotations = toolz .merge (* args )
1443+ # Max of layer retries
1444+ retries = [a ["retries" ] for a in args if "retries" in a ]
1445+ if retries :
1446+ annotations ["retries" ] = max (retries )
1447+ # Max of layer priorities
1448+ priorities = [a ["priority" ] for a in args if "priority" in a ]
1449+ if priorities :
1450+ annotations ["priority" ] = max (priorities )
1451+ # Max of all the layer resources
1452+ resources = [a ["resources" ] for a in args if "resources" in a ]
1453+ if resources :
1454+ annotations ["resources" ] = toolz .merge_with (max , * resources )
1455+ # Intersection of all the worker restrictions
1456+ workers = [a ["workers" ] for a in args if "workers" in a ]
1457+ if workers :
1458+ annotations ["workers" ] = list (set .intersection (* [set (w ) for w in workers ]))
1459+ # More restrictive of allow_other_workers
1460+ allow_other_workers = [
1461+ a ["allow_other_workers" ] for a in args if "allow_other_workers" in a
1462+ ]
1463+ if allow_other_workers :
1464+ annotations ["allow_other_workers" ] = all (allow_other_workers )
1465+
1466+ return annotations
1467+
1468+
14151469def rewrite_blockwise (inputs ):
14161470 """Rewrite a stack of Blockwise expressions into a single blockwise expression
14171471
@@ -1435,6 +1489,9 @@ def rewrite_blockwise(inputs):
14351489 # Fast path: if there's only one input we can just use it as-is.
14361490 return inputs [0 ]
14371491
1492+ fused_annotations = _fuse_annotations (
1493+ * [i .annotations for i in inputs if i .annotations ]
1494+ )
14381495 inputs = {inp .output : inp for inp in inputs }
14391496 dependencies = {
14401497 inp .output : {d for d , v in inp .indices if v is not None and d in inputs }
@@ -1560,7 +1617,7 @@ def rewrite_blockwise(inputs):
15601617 numblocks = numblocks ,
15611618 new_axes = new_axes ,
15621619 concatenate = concatenate ,
1563- annotations = inputs [ root ]. annotations ,
1620+ annotations = fused_annotations ,
15641621 io_deps = io_deps ,
15651622 )
15661623
0 commit comments