Skip to content

Commit 74599ee

Browse files
committed
Merge branch 'main' into shuffle-shutdown
2 parents f83bc40 + 7b21399 commit 74599ee

8 files changed

Lines changed: 211 additions & 27 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ repos:
5858
- types-psutil
5959
- types-setuptools
6060
# Typed libraries
61-
- click
61+
- click!=8.1.4 # https://github.com/pallets/click/issues/2558
6262
- numpy
6363
- pytest
6464
- tornado

distributed/diagnostics/progress.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import logging
5+
import warnings
56
from collections import defaultdict
67
from timeit import default_timer
78
from typing import ClassVar
@@ -140,9 +141,19 @@ def stop(self, exception=None, key=None):
140141
class MultiProgress(Progress):
141142
"""Progress variant that keeps track of different groups of keys
142143
143-
See Progress for most details. This only adds a function ``func=``
144-
that splits keys. This defaults to ``key_split`` which aligns with naming
145-
conventions chosen in the dask project (tuples, hyphens, etc..)
144+
See Progress for most details.
145+
146+
Parameters
147+
----------
148+
149+
func : Callable (deprecated)
150+
Function that splits keys. This defaults to ``key_split`` which
151+
aligns with naming conventions chosen in the dask project (tuples,
152+
hyphens, etc..)
153+
154+
group_by : Callable | Literal["spans"] | Literal["prefix"], default: "prefix"
155+
How to group keys to display multiple bars. Defaults to "prefix",
156+
which uses ``key_split`` from the dask project
146157
147158
State
148159
-----
@@ -161,10 +172,24 @@ class MultiProgress(Progress):
161172
"""
162173

163174
def __init__(
164-
self, keys, scheduler=None, func=key_split, minimum=0, dt=0.1, complete=False
175+
self,
176+
keys,
177+
scheduler=None,
178+
*,
179+
func=None,
180+
group_by="prefix",
181+
minimum=0,
182+
dt=0.1,
183+
complete=False,
165184
):
166-
self.func = func
167-
name = f"multi-progress-{tokenize(keys, func, minimum, dt, complete)}"
185+
if func is not None:
186+
warnings.warn(
187+
"`func` is deprecated, use `group_by`", category=DeprecationWarning
188+
)
189+
group_by = func
190+
self.group_by = key_split if group_by in (None, "prefix") else group_by
191+
self.func = None
192+
name = f"multi-progress-{tokenize(keys, group_by, minimum, dt, complete)}"
168193
super().__init__(
169194
keys, scheduler, minimum=minimum, dt=dt, complete=complete, name=name
170195
)
@@ -191,6 +216,22 @@ async def setup(self):
191216
if not self.keys:
192217
self.stop(exception=None, key=None)
193218

219+
if self.group_by == "spans":
220+
spans_ext = self.scheduler.extensions["spans"]
221+
span_defs = spans_ext.spans if spans_ext else None
222+
223+
def group_key(k):
224+
span_id = self.scheduler.tasks[k].group.span_id
225+
span_name = ", ".join(span_defs[span_id].name) if span_defs else span_id
226+
return span_name, span_id
227+
228+
group_keys = {k: group_key(k) for k in self.all_keys}
229+
self.func = group_keys.get
230+
elif self.group_by == "prefix":
231+
self.func = key_split
232+
else:
233+
self.func = self.group_by
234+
194235
# Group keys by func name
195236
self.keys = valmap(set, groupby(self.func, self.keys))
196237
self.all_keys = valmap(set, groupby(self.func, self.all_keys))

distributed/diagnostics/progressbar.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import weakref
88
from contextlib import suppress
99
from timeit import default_timer
10+
from typing import Callable
1011

1112
from tlz import valmap
1213
from tornado.ioloop import IOLoop
@@ -243,7 +244,9 @@ def __init__(
243244
self,
244245
keys,
245246
scheduler=None,
246-
func=key_split,
247+
*,
248+
func=None,
249+
group_by="prefix",
247250
interval="100ms",
248251
complete=False,
249252
**kwargs,
@@ -256,8 +259,17 @@ def __init__(
256259
self.client = weakref.ref(key.client)
257260
break
258261

262+
if func is not None:
263+
warnings.warn(
264+
"`func` is deprecated, use `group_by` instead",
265+
category=DeprecationWarning,
266+
)
267+
group_by = func
268+
elif group_by in (None, "prefix"):
269+
group_by = key_split
270+
259271
self.keys = {k.key if hasattr(k, "key") else k for k in keys}
260-
self.func = func
272+
self.group_by = group_by
261273
self.interval = interval
262274
self.complete = complete
263275
self._start_time = default_timer()
@@ -269,10 +281,15 @@ def elapsed(self):
269281
async def listen(self):
270282
complete = self.complete
271283
keys = self.keys
272-
func = self.func
284+
group_by = self.group_by
273285

274286
async def setup(scheduler):
275-
p = MultiProgress(keys, scheduler, complete=complete, func=func)
287+
p = MultiProgress(
288+
keys,
289+
scheduler,
290+
complete=complete,
291+
group_by=group_by,
292+
)
276293
await p.setup()
277294
return p
278295

@@ -339,29 +356,31 @@ def __init__(
339356
keys,
340357
scheduler=None,
341358
minimum=0,
342-
interval=0.1,
343-
func=key_split,
344-
complete=False,
345359
**kwargs,
346360
):
347-
super().__init__(keys, scheduler, func, interval, complete)
361+
super().__init__(keys, scheduler, **kwargs)
348362
from ipywidgets import VBox
349363

350364
self.widget = VBox([])
351365

352366
def make_widget(self, all):
353367
from ipywidgets import HTML, FloatProgress, HBox, VBox
354368

369+
def make_label(key):
370+
if isinstance(key, tuple):
371+
# tuple of (group_name, group_id)
372+
key = key[0]
373+
key = key.decode() if isinstance(key, bytes) else key
374+
return html.escape(key)
375+
355376
self.elapsed_time = HTML("")
356377
self.bars = {key: FloatProgress(min=0, max=1, description="") for key in all}
357378
self.bar_texts = {key: HTML("") for key in all}
358379
self.bar_labels = {
359380
key: HTML(
360381
'<div style="padding: 0px 10px 0px 10px;'
361382
" text-align:left; word-wrap: "
362-
'break-word;">'
363-
+ html.escape(key.decode() if isinstance(key, bytes) else key)
364-
+ "</div>"
383+
'break-word;">' + make_label(key) + "</div>"
365384
)
366385
for key in all
367386
}
@@ -429,7 +448,9 @@ def _draw_bar(self, remaining, all, status, **kwargs):
429448
)
430449

431450

432-
def progress(*futures, notebook=None, multi=True, complete=True, **kwargs):
451+
def progress(
452+
*futures, notebook=None, multi=True, complete=True, group_by="prefix", **kwargs
453+
):
433454
"""Track progress of futures
434455
435456
This operates differently in the notebook and the console
@@ -448,6 +469,9 @@ def progress(*futures, notebook=None, multi=True, complete=True, **kwargs):
448469
complete : bool (optional)
449470
Track all keys (True) or only keys that have not yet run (False)
450471
(defaults to True)
472+
group_by : Callable | Literal["spans"] | Literal["prefix"]
473+
How to group tasks into separate bars: by key prefix ("prefix"), by span ("spans"), or with a custom callable
474+
(defaults to "prefix")
451475
452476
Notes
453477
-----
@@ -465,9 +489,18 @@ def progress(*futures, notebook=None, multi=True, complete=True, **kwargs):
465489
futures = [futures]
466490
if notebook is None:
467491
notebook = is_kernel() # often but not always correct assumption
492+
if kwargs.get("func", None) is not None:
493+
warnings.warn(
494+
"`func` is deprecated, use `group_by` instead", category=DeprecationWarning
495+
)
496+
group_by = kwargs.pop("func")
497+
if group_by not in ("spans", "prefix") and not isinstance(group_by, Callable):
498+
raise ValueError("`group_by` should be 'spans', 'prefix', or a Callable")
468499
if notebook:
469500
if multi:
470-
bar = MultiProgressWidget(futures, complete=complete, **kwargs)
501+
bar = MultiProgressWidget(
502+
futures, complete=complete, group_by=group_by, **kwargs
503+
)
471504
else:
472505
bar = ProgressWidget(futures, complete=complete, **kwargs)
473506
return bar

distributed/diagnostics/tests/test_progress.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66

7+
import distributed
78
from distributed import Nanny
89
from distributed.client import wait
910
from distributed.compatibility import LINUX
@@ -14,7 +15,16 @@
1415
Progress,
1516
SchedulerPlugin,
1617
)
17-
from distributed.utils_test import dec, div, gen_cluster, inc, nodebug, slowdec, slowinc
18+
from distributed.utils_test import (
19+
dec,
20+
div,
21+
gen_cluster,
22+
inc,
23+
nodebug,
24+
slowdec,
25+
slowinc,
26+
wait_for_state,
27+
)
1828

1929

2030
def f(*args):
@@ -72,6 +82,49 @@ async def test_multiprogress(c, s, a, b):
7282
assert p.status == "finished"
7383

7484

85+
@gen_cluster(client=True)
86+
async def test_multiprogress_cancel(c, s, a, b):
87+
lock = distributed.Lock()
88+
await lock.acquire()
89+
90+
async def wait_and_raise(*args, **kwargs):
91+
async with lock:
92+
raise RuntimeError()
93+
94+
f = c.submit(wait_and_raise, key="cancel", workers=[a.address])
95+
p = MultiProgress([f], scheduler=s, complete=True)
96+
await p.setup()
97+
await wait_for_state(f.key, "executing", a)
98+
f.release()
99+
await wait_for_state(f.key, "cancelled", a)
100+
assert p.status == "error"
101+
assert p.all_keys.keys() == {"cancel"}
102+
103+
104+
@gen_cluster(client=True)
105+
async def test_multiprogress_with_spans(c, s, a, b):
106+
x = c.submit(inc, 1)
107+
p = MultiProgress([x], scheduler=s, complete=True, group_by="spans")
108+
await p.setup()
109+
group_names = {k[0] for k in p.all_keys}
110+
assert group_names == {"default"}
111+
112+
113+
@gen_cluster(client=True)
114+
async def test_multiprogress_with_prefix(c, s, a, b):
115+
x = c.submit(inc, 1)
116+
p = MultiProgress([x], scheduler=s, complete=True, group_by="prefix")
117+
await p.setup()
118+
group_names = {k for k in p.all_keys}
119+
assert group_names == {"inc"}
120+
121+
122+
def test_multiprogress_warns():
123+
with pytest.warns(DeprecationWarning, match="func` is deprecated, use `group_by"):
124+
p = MultiProgress([], complete=True, func="spans")
125+
assert p.group_by == "spans"
126+
127+
75128
@gen_cluster(client=True)
76129
async def test_robust_to_bad_plugin(c, s, a, b):
77130
class Bad(SchedulerPlugin):

distributed/diagnostics/tests/test_progressbar.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ def test_progress_function_w_kwargs(client, capsys):
7777
check_bar_completed(capsys)
7878

7979

80+
def test_progress_function_warns(client):
81+
with pytest.warns(DeprecationWarning, match="`func` is deprecated"):
82+
progress(None, func="prefix")
83+
84+
85+
def test_progress_function_raises():
86+
with pytest.raises(ValueError, match="`group_by` should be "):
87+
progress(None, group_by="incorrect")
88+
89+
8090
@gen_cluster(client=True, nthreads=[])
8191
async def test_deprecated_loop_properties(c, s):
8292
class ExampleTextProgressBar(TextProgressBar):

distributed/diagnostics/tests/test_widgets.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
import pytest
99
from packaging.version import parse as parse_version
1010

11+
from dask.utils import key_split
12+
1113
from distributed.client import wait
14+
from distributed.spans import span
1215
from distributed.utils_test import dec, gen_cluster, gen_tls_cluster, inc, throws
1316

1417
ipywidgets = pytest.importorskip("ipywidgets")
@@ -207,6 +210,7 @@ async def test_multibar_complete(c, s, a, b):
207210
y2 = c.submit(dec, y1, key="y-2")
208211
e = c.submit(throws, y2, key="e")
209212
other = c.submit(inc, 123, key="other")
213+
await other.cancel()
210214

211215
p = MultiProgressWidget([e.key], scheduler=s.address, complete=True)
212216
await p.listen()
@@ -227,6 +231,51 @@ def test_fast(client):
227231
assert set(p._last_response["all"]) == {"inc", "dec", "add"}
228232

229233

234+
@mock_widget()
235+
def test_multibar_with_spans(client):
236+
"""Test progress(group_by='spans')"""
237+
with span("span 1"):
238+
L = client.map(inc, range(100))
239+
with span("span 2"):
240+
L2 = client.map(dec, L)
241+
with span("span 3"):
242+
L3 = client.map(add, L, L2)
243+
with span("other span"):
244+
_ = client.submit(inc, 123)
245+
e = client.submit(throws, L3)
246+
247+
p = progress(e, complete=True, multi=True, notebook=True, group_by="spans")
248+
client.sync(p.listen)
249+
250+
# keys are tuples of (group_name, group_id), just get names
251+
bar_items = {k[0]: v.value for k, v in p.bars.items()}
252+
bar_texts = {k[0]: v.value for k, v in p.bar_texts.items()}
253+
bar_labels = {k[0]: v.value for k, v in p.bar_labels.items()}
254+
255+
assert bar_items == {"span 1": 1, "span 2": 1, "span 3": 1, "default": 0}
256+
assert bar_texts.keys() == {"span 1", "span 2", "span 3", "default"}
257+
assert all("100 / 100" in v for k, v in bar_texts.items() if k != "default")
258+
assert bar_labels.keys() == {"span 1", "span 2", "span 3", "default"}
259+
assert all(f">{k}<" in v for k, v in bar_labels.items())
260+
261+
262+
@mock_widget()
263+
def test_multibar_func_warns(client):
264+
"""Deprecate `func`, use `group_by`"""
265+
L = client.map(inc, range(100))
266+
L2 = client.map(dec, L)
267+
L3 = client.map(add, L, L2)
268+
269+
# ensure default value if nothing is set
270+
p = MultiProgressWidget(L3)
271+
assert p.group_by == key_split
272+
273+
with pytest.warns(
274+
DeprecationWarning, match="`func` is deprecated, use `group_by` instead"
275+
):
276+
MultiProgressWidget(L3, func="foo")
277+
278+
230279
@mock_widget()
231280
@gen_cluster(client=True, client_kwargs={"serializers": ["msgpack"]})
232281
async def test_serializers(c, s, a, b):

0 commit comments

Comments
 (0)