Skip to content

Commit 67e073f

Browse files
Commit 67e073f, authored — "Don't pile up context_meter callbacks (#7961)"
1 parent: fca4b35

4 files changed

Lines changed: 126 additions & 24 deletions

File tree

distributed/metrics.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,10 @@ class ContextMeter:
189189
A->B comms: network-write 0.567 seconds
190190
"""
191191

192-
_callbacks: ContextVar[list[Callable[[Hashable, float, str], None]]]
192+
_callbacks: ContextVar[dict[Hashable, Callable[[Hashable, float, str], None]]]
193193

194194
def __init__(self):
195-
self._callbacks = ContextVar(f"MetricHook<{id(self)}>._callbacks", default=[])
195+
self._callbacks = ContextVar(f"MetricHook<{id(self)}>._callbacks", default={})
196196

197197
def __reduce__(self):
198198
assert self is context_meter, "Found copy of singleton"
@@ -204,13 +204,28 @@ def _unpickle_singleton():
204204

205205
@contextmanager
206206
def add_callback(
207-
self, callback: Callable[[Hashable, float, str], None]
207+
self,
208+
callback: Callable[[Hashable, float, str], None],
209+
*,
210+
key: Hashable | None = None,
208211
) -> Iterator[None]:
209212
"""Add a callback when entering the context and remove it when exiting it.
210213
The callback must accept the same parameters as :meth:`digest_metric`.
214+
215+
Parameters
216+
----------
217+
callback: Callable
218+
``f(label, value, unit)`` to be executed
219+
key: Hashable, optional
220+
Unique key for the callback. If two nested calls to ``add_callback`` use the
221+
same key, suppress the outermost callback.
211222
"""
223+
if key is None:
224+
key = object()
212225
cbs = self._callbacks.get()
213-
tok = self._callbacks.set(cbs + [callback])
226+
cbs = cbs.copy()
227+
cbs[key] = callback
228+
tok = self._callbacks.set(cbs)
214229
try:
215230
yield
216231
finally:
@@ -221,7 +236,7 @@ def digest_metric(self, label: Hashable, value: float, unit: str) -> None:
221236
metric.
222237
"""
223238
cbs = self._callbacks.get()
224-
for cb in cbs:
239+
for cb in cbs.values():
225240
cb(label, value, unit)
226241

227242
@contextmanager
@@ -234,9 +249,10 @@ def meter(
234249
) -> Iterator[MeterOutput]:
235250
"""Convenience context manager or decorator which calls func() before and after
236251
the wrapped code, calculates the delta, and finally calls :meth:`digest_metric`.
237-
It also subtracts any other calls to :meth:`meter` or :meth:`digest_metric` with
238-
the same unit performed within the context, so that the total is strictly
239-
additive.
252+
253+
If unit=='seconds', it also subtracts any other calls to :meth:`meter` or
254+
:meth:`digest_metric` with the same unit performed within the context, so that
255+
the total is strictly additive.
240256
241257
Parameters
242258
----------
@@ -256,10 +272,19 @@ def meter(
256272
nested calls to :meth:`meter`, then delta (for seconds only) is reduced by the
257273
inner metrics, to a minimum of ``floor``.
258274
"""
275+
if unit != "seconds":
276+
try:
277+
with meter(func, floor=floor) as m:
278+
yield m
279+
finally:
280+
self.digest_metric(label, m.delta, unit)
281+
return
282+
283+
# If unit=="seconds", subtract time metered from the sub-contexts
259284
offsets = []
260285

261286
def callback(label2: Hashable, value2: float, unit2: str) -> None:
262-
if unit2 == unit == "seconds":
287+
if unit2 == unit:
263288
# This must be threadsafe to support callbacks invoked from
264289
# distributed.utils.offload; '+=' on a float would not be threadsafe!
265290
offsets.append(value2)
@@ -316,14 +341,20 @@ def __init__(self, func: Callable[[], float] = timemod.perf_counter):
316341
self.start = func()
317342
self.metrics = []
318343

344+
def _callback(self, label: Hashable, value: float, unit: str) -> None:
345+
self.metrics.append((label, value, unit))
346+
319347
@contextmanager
320-
def record(self) -> Iterator[None]:
348+
def record(self, *, key: Hashable | None = None) -> Iterator[None]:
321349
"""Ingest metrics logged with :meth:`ContextMeter.digest_metric` or
322350
:meth:`ContextMeter.meter` and temporarily store them in :ivar:`metrics`.
351+
352+
Parameters
353+
----------
354+
key: Hashable, optional
355+
See :meth:`ContextMeter.add_callback`
323356
"""
324-
with context_meter.add_callback(
325-
lambda label, value, unit: self.metrics.append((label, value, unit))
326-
):
357+
with context_meter.add_callback(self._callback, key=key):
327358
yield
328359

329360
def finalize(

distributed/tests/test_metrics.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,25 +77,34 @@ def test_meter_floor(kwargs, delta):
7777

7878

7979
def test_context_meter():
80-
it = iter([123, 124])
80+
it = iter([123, 124, 125, 126])
8181
cbs = []
8282

8383
with metrics.context_meter.add_callback(lambda l, v, u: cbs.append((l, v, u))):
84-
with metrics.context_meter.meter("m1", func=lambda: next(it)) as m:
85-
assert m.start == 123
86-
assert math.isnan(m.stop)
87-
assert math.isnan(m.delta)
84+
with metrics.context_meter.meter("m1", func=lambda: next(it)) as m1:
85+
assert m1.start == 123
86+
assert math.isnan(m1.stop)
87+
assert math.isnan(m1.delta)
88+
with metrics.context_meter.meter("m2", func=lambda: next(it), unit="foo") as m2:
89+
assert m2.start == 125
90+
assert math.isnan(m2.stop)
91+
assert math.isnan(m2.delta)
92+
8893
metrics.context_meter.digest_metric("m1", 2, "seconds")
8994
metrics.context_meter.digest_metric("m1", 1, "foo")
9095

9196
# Not recorded - out of context
9297
metrics.context_meter.digest_metric("m1", 123, "foo")
9398

94-
assert m.start == 123
95-
assert m.stop == 124
96-
assert m.delta == 1
99+
assert m1.start == 123
100+
assert m1.stop == 124
101+
assert m1.delta == 1
102+
assert m2.start == 125
103+
assert m2.stop == 126
104+
assert m2.delta == 1
97105
assert cbs == [
98106
("m1", 1, "seconds"),
107+
("m2", 1, "foo"),
99108
("m1", 2, "seconds"),
100109
("m1", 1, "foo"),
101110
]
@@ -199,3 +208,43 @@ def test_delayed_metrics_ledger():
199208
("foo", 10, "bytes"),
200209
("other", 20, "seconds"),
201210
]
211+
212+
213+
def test_context_meter_keyed():
214+
cbs = []
215+
216+
def cb(tag, key):
217+
return metrics.context_meter.add_callback(
218+
lambda l, v, u: cbs.append((tag, l)), key=key
219+
)
220+
221+
with cb("x", key="x"), cb("y", key="y"):
222+
metrics.context_meter.digest_metric("l1", 1, "u")
223+
with cb("z", key="x"):
224+
metrics.context_meter.digest_metric("l2", 2, "u")
225+
metrics.context_meter.digest_metric("l3", 3, "u")
226+
227+
assert cbs == [
228+
("x", "l1"),
229+
("y", "l1"),
230+
("z", "l2"),
231+
("y", "l2"),
232+
("x", "l3"),
233+
("y", "l3"),
234+
]
235+
236+
237+
def test_delayed_metrics_ledger_keyed():
238+
l1 = metrics.DelayedMetricsLedger()
239+
l2 = metrics.DelayedMetricsLedger()
240+
l3 = metrics.DelayedMetricsLedger()
241+
242+
with l1.record(key="x"), l2.record(key="y"):
243+
metrics.context_meter.digest_metric("l1", 1, "u")
244+
with l3.record(key="x"):
245+
metrics.context_meter.digest_metric("l2", 2, "u")
246+
metrics.context_meter.digest_metric("l3", 3, "u")
247+
248+
assert l1.metrics == [("l1", 1, "u"), ("l3", 3, "u")]
249+
assert l2.metrics == [("l1", 1, "u"), ("l2", 2, "u"), ("l3", 3, "u")]
250+
assert l3.metrics == [("l2", 2, "u")]

distributed/tests/test_worker_metrics.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,3 +618,22 @@ async def test_new_metrics_during_heartbeat(c, s, a):
618618
assert a.digests_total["execute", span.id, "x", "test", "test"] == n
619619
assert s.cumulative_worker_metrics["execute", "x", "test", "test"] == n
620620
assert span.cumulative_worker_metrics["execute", "x", "test", "test"] == n
621+
622+
623+
@gen_cluster(
624+
client=True,
625+
nthreads=[("", 1)],
626+
config={"distributed.scheduler.worker-saturation": float("inf")},
627+
)
628+
async def test_delayed_ledger_is_not_reentrant(c, s, a):
629+
"""https://github.com/dask/distributed/issues/7949
630+
631+
Test that, when there's a long chain of task done -> task start events,
632+
the callbacks added by the delayed ledger don't pile up on top of each other.
633+
"""
634+
635+
def f(_):
636+
return len(context_meter._callbacks.get())
637+
638+
out = await c.gather(c.map(f, range(1000)))
639+
assert max(out) < 10

distributed/worker_state_machine.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3635,7 +3635,7 @@ def _start_async_instruction( # type: ignore[valid-type]
36353635

36363636
@wraps(func)
36373637
async def wrapper() -> StateMachineEvent:
3638-
with ledger.record():
3638+
with ledger.record(key="async-instruction"):
36393639
return await func(*args, **kwargs)
36403640

36413641
task = asyncio.create_task(wrapper(), name=task_name)
@@ -3664,8 +3664,11 @@ def _finish_async_instruction(
36643664
logger.exception("async instruction handlers should never raise!")
36653665
raise
36663666

3667-
with ledger.record():
3668-
# Capture metric events in _transition_to_memory()
3667+
# Capture metric events in _transition_to_memory()
3668+
# As this may trigger calls to _start_async_instruction for more tasks,
3669+
# make sure we don't endlessly pile up context_meter callbacks by specifying
3670+
# the same key as in _start_async_instruction.
3671+
with ledger.record(key="async-instruction"):
36693672
self.handle_stimulus(stim)
36703673

36713674
self._finalize_metrics(stim, ledger, span_id)

Comments (0)