Skip to content

Commit b9ebe23

Browse files
committed
cancelled_long_running
1 parent 817ead3 commit b9ebe23

4 files changed

Lines changed: 201 additions & 21 deletions

File tree

distributed/scheduler.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4854,7 +4854,7 @@ def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None:
48544854
self.transitions({key: "released"}, stimulus_id)
48554855

48564856
def handle_long_running(
4857-
self, key: str, worker: str, compute_duration: float, stimulus_id: str
4857+
self, key: str, worker: str, compute_duration: float | None, stimulus_id: str
48584858
) -> None:
48594859
"""A task has seceded from the thread pool
48604860
@@ -4874,11 +4874,12 @@ def handle_long_running(
48744874
logger.debug("Received long-running signal from duplicate task. Ignoring.")
48754875
return
48764876

4877-
old_duration = ts.prefix.duration_average
4878-
if old_duration < 0:
4879-
ts.prefix.duration_average = compute_duration
4880-
else:
4881-
ts.prefix.duration_average = (old_duration + compute_duration) / 2
4877+
if compute_duration is not None:
4878+
old_duration = ts.prefix.duration_average
4879+
if old_duration < 0:
4880+
ts.prefix.duration_average = compute_duration
4881+
else:
4882+
ts.prefix.duration_average = (old_duration + compute_duration) / 2
48824883

48834884
occ = ws.processing[ts]
48844885
ws.occupancy -= occ

distributed/tests/test_cancelled_state.py

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
GatherDepFailureEvent,
3131
GatherDepNetworkFailureEvent,
3232
GatherDepSuccessEvent,
33+
LongRunningMsg,
3334
RescheduleEvent,
35+
SecedeEvent,
3436
TaskFinishedMsg,
3537
UpdateDataEvent,
3638
)
@@ -640,7 +642,12 @@ def test_workerstate_executing_to_executing(ws_with_running_task):
640642
FreeKeysEvent(keys=["x"], stimulus_id="s1"),
641643
ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s2"),
642644
)
643-
assert not instructions
645+
if prev_state == "executing":
646+
assert not instructions
647+
else:
648+
assert instructions == [
649+
LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2")
650+
]
644651
assert ws.tasks["x"] is ts
645652
assert ts.state == prev_state
646653

@@ -821,7 +828,12 @@ def test_workerstate_resumed_fetch_to_executing(ws_with_running_task):
821828
FreeKeysEvent(keys=["y", "x"], stimulus_id="s3"),
822829
ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s4"),
823830
)
824-
assert not instructions
831+
if prev_state == "executing":
832+
assert not instructions
833+
else:
834+
assert instructions == [
835+
LongRunningMsg(key="x", compute_duration=None, stimulus_id="s4")
836+
]
825837
assert ws.tasks["x"].state == prev_state
826838

827839

@@ -946,3 +958,102 @@ def test_cancel_with_dependencies_in_memory(ws, release_dep, done_ev_cls):
946958
ws.handle_stimulus(done_ev_cls.dummy("y", stimulus_id="s5"))
947959
assert "y" not in ws.tasks
948960
assert ws.tasks["x"].state == "memory"
961+
962+
963+
@pytest.mark.parametrize("resume_to_fetch", [False, True])
964+
@pytest.mark.parametrize("resume_to_executing", [False, True])
965+
@pytest.mark.parametrize(
966+
"done_ev_cls", [ExecuteSuccessEvent, ExecuteFailureEvent, RescheduleEvent]
967+
)
968+
def test_secede_cancelled_or_resumed_workerstate(
969+
ws, resume_to_fetch, resume_to_executing, done_ev_cls
970+
):
971+
"""Test what happens when a cancelled or resumed(fetch) task calls secede().
972+
See also test_secede_cancelled_or_resumed_scheduler
973+
"""
974+
ws2 = "127.0.0.1:2"
975+
ws.handle_stimulus(
976+
ComputeTaskEvent.dummy("x", stimulus_id="s1"),
977+
FreeKeysEvent(keys=["x"], stimulus_id="s2"),
978+
)
979+
if resume_to_fetch:
980+
ws.handle_stimulus(
981+
ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s3"),
982+
)
983+
ts = ws.tasks["x"]
984+
assert ts.previous == "executing"
985+
assert ts in ws.executing
986+
assert ts not in ws.long_running
987+
988+
instructions = ws.handle_stimulus(
989+
SecedeEvent(key="x", compute_duration=1, stimulus_id="s4")
990+
)
991+
assert not instructions # Do not send RescheduleMsg
992+
assert ts.previous == "long-running"
993+
assert ts not in ws.executing
994+
assert ts in ws.long_running
995+
996+
if resume_to_executing:
997+
instructions = ws.handle_stimulus(
998+
FreeKeysEvent(keys=["y"], stimulus_id="s5"),
999+
ComputeTaskEvent.dummy("x", stimulus_id="s6"),
1000+
)
1001+
# Inform the scheduler of the SecedeEvent that happened in the past
1002+
assert instructions == [
1003+
LongRunningMsg(key="x", compute_duration=None, stimulus_id="s6")
1004+
]
1005+
assert ts.state == "long-running"
1006+
assert ts not in ws.executing
1007+
assert ts in ws.long_running
1008+
1009+
ws.handle_stimulus(done_ev_cls.dummy(key="x", stimulus_id="s7"))
1010+
assert ts not in ws.executing
1011+
assert ts not in ws.long_running
1012+
1013+
1014+
@gen_cluster(client=True, nthreads=[("", 1)], timeout=2)
1015+
async def test_secede_cancelled_or_resumed_scheduler(c, s, a):
1016+
"""Same as test_secede_cancelled_or_resumed_workerstate, but testing the interaction
1017+
with the scheduler
1018+
"""
1019+
ws = s.workers[a.address]
1020+
ev1 = Event()
1021+
ev2 = Event()
1022+
ev3 = Event()
1023+
ev4 = Event()
1024+
1025+
def f(ev1, ev2, ev3, ev4):
1026+
ev1.set()
1027+
ev2.wait()
1028+
distributed.secede()
1029+
ev3.set()
1030+
ev4.wait()
1031+
return 123
1032+
1033+
x = c.submit(f, ev1, ev2, ev3, ev4, key="x")
1034+
await ev1.wait()
1035+
ts = a.state.tasks["x"]
1036+
assert ts.state == "executing"
1037+
assert sum(ws.processing.values()) > 0
1038+
1039+
x.release()
1040+
await wait_for_state("x", "cancelled", a)
1041+
assert not ws.processing
1042+
1043+
await ev2.set()
1044+
await ev3.wait()
1045+
assert ts.previous == "long-running"
1046+
assert not ws.processing
1047+
1048+
x = c.submit(inc, 1, key="x")
1049+
await wait_for_state("x", "long-running", a)
1050+
1051+
# Test that the scheduler receives a delayed {op: long-running}
1052+
assert ws.processing
1053+
while sum(ws.processing.values()):
1054+
await asyncio.sleep(0.1)
1055+
assert ws.processing
1056+
1057+
await ev4.set()
1058+
assert await x == 123
1059+
assert not ws.processing

distributed/tests/test_resources.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ExecuteFailureEvent,
1818
ExecuteSuccessEvent,
1919
FreeKeysEvent,
20+
LongRunningMsg,
2021
RescheduleEvent,
2122
TaskFinishedMsg,
2223
)
@@ -565,14 +566,21 @@ def test_resumed_with_different_resources(ws_with_running_task, done_ev_cls):
565566
"""
566567
ws = ws_with_running_task
567568
assert ws.available_resources == {"R": 0}
569+
ts = ws.tasks["x"]
570+
prev_state = ts.state
568571

569572
ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s1"))
570573
assert ws.available_resources == {"R": 0}
571574

572575
instructions = ws.handle_stimulus(
573576
ComputeTaskEvent.dummy("x", stimulus_id="s2", resource_restrictions={"R": 0.4})
574577
)
575-
assert not instructions
578+
if prev_state == "long-running":
579+
assert instructions == [
580+
LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2")
581+
]
582+
else:
583+
assert not instructions
576584
assert ws.available_resources == {"R": 0}
577585

578586
ws.handle_stimulus(done_ev_cls.dummy(key="x", stimulus_id="s3"))

distributed/worker_state_machine.py

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ class LongRunningMsg(SendMessageToScheduler):
517517

518518
__slots__ = ("key", "compute_duration")
519519
key: str
520-
compute_duration: float
520+
compute_duration: float | None
521521

522522

523523
@dataclass
@@ -2077,24 +2077,39 @@ def _transition_resumed_waiting(
20772077
See also
20782078
--------
20792079
_transition_cancelled_fetch
2080+
_transition_cancelled_or_resumed_long_running
20802081
_transition_cancelled_waiting
20812082
_transition_resumed_fetch
20822083
"""
20832084
# None of the exit events of execute or gather_dep recommend a transition to
20842085
# waiting
20852086
assert not ts.done
2086-
if ts.previous in ("executing", "long-running"):
2087+
if ts.previous == "executing":
20872088
assert ts.next == "fetch"
20882089
# We're back where we started. We should forget about the entire
20892090
# cancellation attempt
2090-
ts.state = ts.previous
2091+
ts.state = "executing"
20912092
ts.next = None
20922093
ts.previous = None
2093-
elif self.validate:
2094+
return {}, []
2095+
2096+
elif ts.previous == "long-running":
2097+
assert ts.next == "fetch"
2098+
# Same as executing, and in addition send the LongRunningMsg in arrears
2099+
# Note that, if the task seceded before it was cancelled, this will cause
2100+
# the message to be sent twice.
2101+
ts.state = "long-running"
2102+
ts.next = None
2103+
ts.previous = None
2104+
smsg = LongRunningMsg(
2105+
key=ts.key, compute_duration=None, stimulus_id=stimulus_id
2106+
)
2107+
return {}, [smsg]
2108+
2109+
else:
20942110
assert ts.previous == "flight"
20952111
assert ts.next == "waiting"
2096-
2097-
return {}, []
2112+
return {}, []
20982113

20992114
def _transition_cancelled_fetch(
21002115
self, ts: TaskState, *, stimulus_id: str
@@ -2131,17 +2146,29 @@ def _transition_cancelled_waiting(
21312146
See also
21322147
--------
21332148
_transition_cancelled_fetch
2149+
_transition_cancelled_or_resumed_long_running
21342150
_transition_resumed_fetch
21352151
_transition_resumed_waiting
21362152
"""
21372153
# None of the exit events of gather_dep or execute recommend a transition to
21382154
# waiting
21392155
assert not ts.done
2140-
if ts.previous in ("executing", "long-running"):
2156+
if ts.previous == "executing":
21412157
# Forget the task was cancelled to begin with
2142-
ts.state = ts.previous
2158+
ts.state = "executing"
21432159
ts.previous = None
21442160
return {}, []
2161+
elif ts.previous == "long-running":
2162+
# Forget the task was cancelled to begin with, and inform the scheduler
2163+
# in arrears that it has seceded.
2164+
# Note that, if the task seceded before it was cancelled, this will cause
2165+
# the message to be sent twice.
2166+
ts.state = "long-running"
2167+
ts.previous = None
2168+
smsg = LongRunningMsg(
2169+
key=ts.key, compute_duration=None, stimulus_id=stimulus_id
2170+
)
2171+
return {}, [smsg]
21452172
else:
21462173
assert ts.previous == "flight"
21472174
ts.state = "resumed"
@@ -2234,6 +2261,11 @@ def _transition_flight_released(
22342261
def _transition_executing_long_running(
22352262
self, ts: TaskState, compute_duration: float, *, stimulus_id: str
22362263
) -> RecsInstrs:
2264+
"""
2265+
See also
2266+
--------
2267+
_transition_cancelled_or_resumed_long_running
2268+
"""
22372269
ts.state = "long-running"
22382270
self.executing.discard(ts)
22392271
self.long_running.add(ts)
@@ -2246,6 +2278,34 @@ def _transition_executing_long_running(
22462278
self._ensure_computing(),
22472279
)
22482280

2281+
def _transition_cancelled_or_resumed_long_running(
2282+
self, ts: TaskState, compute_duration: float, *, stimulus_id: str
2283+
) -> RecsInstrs:
2284+
"""Handles transitions:
2285+
2286+
- cancelled(executing) -> long-running
2287+
- cancelled(long-running) -> long-running (user called secede() twice)
2288+
- resumed(executing->fetch) -> long-running
2289+
- resumed(long-running->fetch) -> long-running (user called secede() twice)
2290+
2291+
Unlike in the executing->long_running transition, do not send LongRunningMsg.
2292+
From the scheduler's perspective, this task no longer exists (cancelled) or is
2293+
in memory on another worker (resumed). So it shouldn't hear about it.
2294+
Instead, we're going to send the LongRunningMsg when and if the task
2295+
transitions back to waiting.
2296+
2297+
See also
2298+
--------
2299+
_transition_executing_long_running
2300+
_transition_cancelled_waiting
2301+
_transition_resumed_waiting
2302+
"""
2303+
assert ts.previous in ("executing", "long-running")
2304+
ts.previous = "long-running"
2305+
self.executing.discard(ts)
2306+
self.long_running.add(ts)
2307+
return self._ensure_computing()
2308+
22492309
def _transition_executing_memory(
22502310
self, ts: TaskState, value: object, *, stimulus_id: str
22512311
) -> RecsInstrs:
@@ -2352,15 +2412,16 @@ def _transition_released_forgotten(
23522412
] = {
23532413
("cancelled", "error"): _transition_cancelled_released,
23542414
("cancelled", "fetch"): _transition_cancelled_fetch,
2415+
("cancelled", "long-running"): _transition_cancelled_or_resumed_long_running,
23552416
("cancelled", "memory"): _transition_cancelled_released,
23562417
("cancelled", "missing"): _transition_cancelled_released,
23572418
("cancelled", "released"): _transition_cancelled_released,
23582419
("cancelled", "rescheduled"): _transition_cancelled_released,
23592420
("cancelled", "waiting"): _transition_cancelled_waiting,
23602421
("resumed", "error"): _transition_resumed_error,
23612422
("resumed", "fetch"): _transition_resumed_fetch,
2423+
("resumed", "long-running"): _transition_cancelled_or_resumed_long_running,
23622424
("resumed", "memory"): _transition_resumed_memory,
2363-
("resumed", "missing"): _transition_resumed_missing,
23642425
("resumed", "released"): _transition_resumed_released,
23652426
("resumed", "rescheduled"): _transition_resumed_rescheduled,
23662427
("resumed", "waiting"): _transition_resumed_waiting,
@@ -2898,10 +2959,9 @@ def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
28982959
@_handle_event.register
28992960
def _handle_secede(self, ev: SecedeEvent) -> RecsInstrs:
29002961
ts = self.tasks.get(ev.key)
2901-
if ts and ts.state == "executing":
2902-
return {ts: ("long-running", ev.compute_duration)}, []
2903-
else:
2962+
if not ts:
29042963
return {}, []
2964+
return {ts: ("long-running", ev.compute_duration)}, []
29052965

29062966
@_handle_event.register
29072967
def _handle_steal_request(self, ev: StealRequestEvent) -> RecsInstrs:

0 commit comments

Comments
 (0)