Skip to content

Commit aaf3f1f

Browse files
committed
Prototype for assignment groups
1 parent f297ad4 commit aaf3f1f

3 files changed

Lines changed: 113 additions & 14 deletions

File tree

dask/base.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -719,9 +719,10 @@ def visualize(
719719
720720
https://docs.dask.org/en/latest/optimize.html
721721
"""
722-
args, _ = unpack_collections(*args, traverse=traverse)
722+
dsk = args[0]
723+
# args, _ = unpack_collections(*args, traverse=traverse)
723724

724-
dsk = dict(collections_to_dsk(args, optimize_graph=optimize_graph))
725+
# dsk = dict(collections_to_dsk(args, optimize_graph=optimize_graph))
725726

726727
color = kwargs.get("color")
727728

@@ -737,13 +738,17 @@ def visualize(
737738
"memoryincreases",
738739
"memorydecreases",
739740
"memorypressure",
741+
"group",
742+
"order-group",
740743
}:
741744
import matplotlib.pyplot as plt
742745

743746
from dask.order import diagnostics, order
744747

745-
if o is None:
746-
o = order(dsk)
748+
if "group" in color:
749+
o, groups = order(dsk, group=True)
750+
elif o is None:
751+
o = order(dsk, group=False)
747752
try:
748753
cmap = kwargs.pop("cmap")
749754
except KeyError:
@@ -773,6 +778,10 @@ def label(x):
773778
key: max(0, val.num_data_when_released - val.num_data_when_run)
774779
for key, val in info.items()
775780
}
781+
elif color.endswith("group"):
782+
values = {
783+
key: group_ix for group_ix, keys in groups.items() for key in keys
784+
}
776785
else: # memorydecreases
777786
values = {
778787
key: max(0, val.num_data_when_run - val.num_data_when_released)

dask/order.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,19 +80,40 @@
8080
from collections import defaultdict, namedtuple
8181
from collections.abc import Mapping, MutableMapping
8282
from heapq import heappop, heappush
83-
from typing import Any, cast
83+
from typing import Any, Literal, cast, overload
8484

8585
from dask.core import get_dependencies, get_deps, getcycle, istask, reverse_dict
8686
from dask.typing import Key
8787

8888

89+
@overload
8990
def order(
9091
dsk: MutableMapping[Key, Any],
91-
dependencies: MutableMapping[Key, set[Key]] | None = None,
92+
dependencies: MutableMapping[Key, set[Key]] | None,
93+
group: Literal[False],
9294
) -> dict[Key, int]:
95+
...
96+
97+
98+
@overload
99+
def order(
100+
dsk: MutableMapping[Key, Any],
101+
dependencies: MutableMapping[Key, set[Key]] | None,
102+
group: Literal[True],
103+
) -> tuple[dict[Key, int], dict[int, list[Key]]]:
104+
...
105+
106+
107+
def order(
108+
dsk: MutableMapping[Key, Any],
109+
dependencies: MutableMapping[Key, set[Key]] | None = None,
110+
group: bool = False,
111+
) -> dict[Key, int] | tuple[dict[Key, int], dict[int, list[Key]]]:
93112
if not dsk:
94113
return {}
95-
114+
groups = defaultdict(list)
115+
groups_by_key = dict()
116+
group_ix = 0
96117
dsk = dict(dsk)
97118

98119
if dependencies is None:
@@ -126,9 +147,14 @@ def _f(*args: Any, **kwargs: Any) -> None:
126147

127148
dsk[root] = (_f, *root_nodes)
128149
dependencies[root] = root_nodes
129-
o = order(dsk, dependencies)
130-
del o[root]
131-
return o
150+
if not group:
151+
o = order(dsk, dependencies, group=group)
152+
del o[root]
153+
return o
154+
else:
155+
o, g = order(dsk, dependencies, group=group)
156+
del o[root]
157+
return o, g
132158
root = list(root_nodes)[0]
133159
init_stack: dict[Key, tuple] | set[Key] | list[Key]
134160
# Leaf nodes. We choose one--the initial node--for each weakly connected subgraph.
@@ -243,6 +269,7 @@ def process_runnables(layers_loaded: int = 0) -> None:
243269
nonlocal i
244270
runnable_candidates = set_difference(set(runnable), seen)
245271
runnable_sorted = sorted(runnable_candidates, key=pkey_getitem, reverse=True)
272+
prev = None
246273
while runnable_sorted:
247274
task = runnable_sorted.pop()
248275
if task in runnable:
@@ -255,6 +282,11 @@ def process_runnables(layers_loaded: int = 0) -> None:
255282
next_nodes[pkey].add(task)
256283
continue
257284
result[task] = i
285+
# groups[group_ix].append(task)
286+
group_ix = groups_by_key[runnable.get(task, prev)]
287+
groups[group_ix].append(task)
288+
groups_by_key[task] = group_ix
289+
prev = task
258290
runnable.pop(task, None)
259291
i += 1
260292
deps = dependents[task]
@@ -287,6 +319,8 @@ def process_runnables(layers_loaded: int = 0) -> None:
287319
layers_loaded += 1
288320
continue
289321
result[item] = i
322+
groups[group_ix].append(item)
323+
groups_by_key[item] = group_ix
290324
runnable.pop(item, None)
291325
i += 1
292326
deps = dependents[item]
@@ -318,6 +352,8 @@ def process_runnables(layers_loaded: int = 0) -> None:
318352
inner_stack = list(dep_pools[pkey])
319353
inner_stack_pop = inner_stack.pop
320354
seen_update(inner_stack)
355+
if group_ix in groups and len(groups[group_ix]) > 1:
356+
group_ix += 1
321357
continue
322358
next_nodes[pkey].update(dep_pools[pkey])
323359
heappush(min_key_next_nodes, pkey)
@@ -336,6 +372,8 @@ def process_runnables(layers_loaded: int = 0) -> None:
336372
next_nodes.pop(min_key), key=dependents_key, reverse=True
337373
)
338374
inner_stack_pop = inner_stack.pop
375+
if group_ix in groups and len(groups[group_ix]) > 1:
376+
group_ix += 1
339377
seen_update(inner_stack)
340378
continue
341379

@@ -344,7 +382,9 @@ def process_runnables(layers_loaded: int = 0) -> None:
344382

345383
if len(result) == len(dsk):
346384
break
347-
385+
# Increasing here is very conservative
386+
if group_ix in groups and len(groups[group_ix]) > 1:
387+
group_ix += 1
348388
if not is_init_sorted:
349389
init_stack = set(init_stack)
350390
init_stack = set_difference(init_stack, result)
@@ -356,8 +396,10 @@ def process_runnables(layers_loaded: int = 0) -> None:
356396

357397
inner_stack = [init_stack.pop()] # type: ignore[call-overload]
358398
inner_stack_pop = inner_stack.pop
359-
360-
return result
399+
if group:
400+
return result, dict(groups)
401+
else:
402+
return result
361403

362404

363405
def graph_metrics(

dask/tests/test_order.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
@pytest.fixture(
1414
params=[
1515
"abcde",
16-
"edcba",
16+
# "edcba",
1717
]
1818
)
1919
def abcde(request):
@@ -751,6 +751,12 @@ def test_order_with_equal_dependents(abcde):
751751
}
752752
)
753753
o = order(dsk)
754+
755+
import inspect
756+
757+
from dask.base import visualize
758+
759+
visualize(dsk, filename=inspect.stack()[0][3], color="group")
754760
total = 0
755761
for x in abc:
756762
for i in range(len(abc)):
@@ -857,6 +863,11 @@ def test_terminal_node_backtrack():
857863
),
858864
}
859865
o = order(dsk)
866+
import inspect
867+
868+
from dask.base import visualize
869+
870+
visualize(dsk, filename=inspect.stack()[0][3], color="group")
860871
assert o[("a", 2)] < o[("a", 3)]
861872

862873

@@ -876,6 +887,11 @@ def test_array_store_final_order(tmpdir):
876887
dest = root.empty_like(name="dest", data=x, chunks=x.chunksize, overwrite=True)
877888
d = x.store(dest, lock=False, compute=False)
878889
o = order(d.dask)
890+
import inspect
891+
892+
from dask.base import visualize
893+
894+
visualize(d.dask, filename=inspect.stack()[0][3], color="group")
879895
# Find the lowest store. Dask starts here.
880896
stores = [k for k in o if isinstance(k, tuple) and k[0].startswith("store-map-")]
881897
first_store = min(stores, key=lambda k: o[k])
@@ -979,6 +995,11 @@ def test_eager_to_compute_dependent_to_free_parent():
979995
}
980996
dependencies, dependents = get_deps(dsk)
981997
o = order(dsk)
998+
import inspect
999+
1000+
from dask.base import visualize
1001+
1002+
visualize(dsk, filename=inspect.stack()[0][3], color="group")
9821003
parents = {deps.pop() for key, deps in dependents.items() if not dependencies[key]}
9831004

9841005
def cost(deps):
@@ -1013,7 +1034,11 @@ def test_diagnostics(abcde):
10131034
(e, 1): (f, (e, 0)),
10141035
}
10151036
o = order(dsk)
1037+
import inspect
10161038

1039+
from dask.base import visualize
1040+
1041+
visualize(dsk, filename=inspect.stack()[0][3], color="group")
10171042
assert o[(e, 1)] == len(dsk) - 1
10181043
assert o[(d, 1)] == len(dsk) - 2
10191044
assert o[(c, 1)] == len(dsk) - 3
@@ -1135,10 +1160,23 @@ def test_array_vs_dataframe(optimize):
11351160
quad = ds**2
11361161
quad["uv"] = ds.anom_u * ds.anom_v
11371162
mean = quad.mean("time")
1163+
dsk = collections_to_dsk([mean], optimize_graph=optimize)
1164+
o, g = order(dsk, group=True)
1165+
print(len(g))
1166+
print({ix: len(keys) for ix, keys in g.items()})
11381167
diag_array = diagnostics(collections_to_dsk([mean], optimize_graph=optimize))
11391168
diag_df = diagnostics(
11401169
collections_to_dsk([mean.to_dask_dataframe()], optimize_graph=optimize)
11411170
)
1171+
import inspect
1172+
1173+
from dask.base import visualize
1174+
1175+
visualize(
1176+
collections_to_dsk([mean], optimize_graph=optimize),
1177+
filename=inspect.stack()[0][3],
1178+
color="group",
1179+
)
11421180
assert max(diag_df[1]) == max(diag_array[1])
11431181
assert max(diag_array[1]) < 50
11441182

@@ -1218,6 +1256,11 @@ def test_anom_mean_raw():
12181256
}
12191257

12201258
o = order(dsk)
1259+
import inspect
1260+
1261+
from dask.base import visualize
1262+
1263+
visualize(dsk, filename=inspect.stack()[0][3], color="group")
12211264
# The left hand computation branch should complete before we start loading
12221265
# more data
12231266
nodes_to_finish_before_loading_more_data = [
@@ -1701,6 +1744,11 @@ def random(**kwargs):
17011744
dependencies, dependents = get_deps(dsk)
17021745
# Verify assumptions
17031746
o = order(dsk)
1747+
import inspect
1748+
1749+
from dask.base import visualize
1750+
1751+
visualize(dsk, filename=inspect.stack()[0][3], color="group")
17041752
# Verify assumptions (specifically that the reducers are sum-aggregate)
17051753
assert {key_split(k) for k in o} == {"object", "sum", "sum-aggregate"}
17061754

0 commit comments

Comments (0)