Skip to content

Commit ade4266

Browse files
Add WorkerState.all_running_tasks (#6690)
1 parent bec621e commit ade4266

2 files changed

Lines changed: 77 additions & 6 deletions

File tree

distributed/tests/test_worker_state_machine.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ComputeTaskEvent,
2929
ExecuteFailureEvent,
3030
ExecuteSuccessEvent,
31+
FreeKeysEvent,
3132
GatherDep,
3233
Instruction,
3334
PauseEvent,
@@ -1037,3 +1038,58 @@ async def test_clean_log(s, a, b):
10371038
"""Test that brand new workers start with a clean log"""
10381039
assert not a.state.log
10391040
assert not a.state.stimulus_log
1041+
1042+
1043+
def test_running_task_in_all_running_tasks(ws_with_running_task):
1044+
ws = ws_with_running_task
1045+
ws2 = "127.0.0.1:2"
1046+
ts = ws.tasks["x"]
1047+
assert ts in ws.all_running_tasks
1048+
1049+
ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s1"))
1050+
assert ts.state == "cancelled"
1051+
assert ts in ws.all_running_tasks
1052+
1053+
ws.handle_stimulus(
1054+
ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2")
1055+
)
1056+
assert ts.state == "resumed"
1057+
assert ts in ws.all_running_tasks
1058+
1059+
1060+
@pytest.mark.xfail(reason="distributed#6565, distributed#6692")
1061+
@pytest.mark.parametrize(
1062+
"done_ev_cls,done_status",
1063+
[(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "error")],
1064+
)
1065+
def test_done_task_not_in_all_running_tasks(
1066+
ws_with_running_task, done_ev_cls, done_status
1067+
):
1068+
ws = ws_with_running_task
1069+
ts = ws.tasks["x"]
1070+
assert ts in ws.all_running_tasks
1071+
1072+
ws.handle_stimulus(done_ev_cls.dummy("x", stimulus_id="s1"))
1073+
assert ts.state == done_status
1074+
assert ts not in ws.all_running_tasks
1075+
1076+
1077+
@pytest.mark.xfail(reason="distributed#6565, distributed#6689, distributed#6692")
1078+
@pytest.mark.parametrize(
1079+
"done_ev_cls,done_status",
1080+
[(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "error")],
1081+
)
1082+
def test_done_resumed_task_not_in_all_running_tasks(
1083+
ws_with_running_task, done_ev_cls, done_status
1084+
):
1085+
ws = ws_with_running_task
1086+
ws2 = "127.0.0.1:2"
1087+
1088+
ws.handle_stimulus(
1089+
FreeKeysEvent(keys=["x"], stimulus_id="s1"),
1090+
ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2"),
1091+
done_ev_cls.dummy("x", stimulus_id="s3"),
1092+
)
1093+
ts = ws.tasks["x"]
1094+
assert ts.state == done_status
1095+
assert ts not in ws.all_running_tasks

distributed/worker_state_machine.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,7 @@ def handle_stimulus(self, *stims: StateMachineEvent) -> Instructions:
12031203
@property
12041204
def executing_count(self) -> int:
12051205
"""Count of tasks currently executing on this worker.
1206+
Does not include long running (a.k.a. seceded) and cancelled tasks.
12061207
12071208
See also
12081209
--------
@@ -1212,6 +1213,17 @@ def executing_count(self) -> int:
12121213
"""
12131214
return len(self.executing)
12141215

1216+
@property
1217+
def all_running_tasks(self) -> set[TaskState]:
1218+
"""All tasks that are currently occupying a thread.
1219+
These are:
1220+
1221+
- ``ts.status in ("executing", "long-running", "cancelled")``
1222+
- ``ts.status == "resumed" and ts._previous in ("executing", "long-running")``
1223+
"""
1224+
# Note: cancelled and resumed tasks are still in either of these sets
1225+
return self.executing | {self.tasks[key] for key in self.long_running}
1226+
12151227
@property
12161228
def in_flight_tasks_count(self) -> int:
12171229
"""Count of tasks currently being replicated from other workers to this one.
@@ -1981,7 +1993,7 @@ def _transition_cancelled_fetch(
19811993
ts.state = ts._previous
19821994
return {}, []
19831995
else:
1984-
assert ts._previous == "executing"
1996+
assert ts._previous in {"executing", "long-running"}
19851997
ts.state = "resumed"
19861998
ts._next = "fetch"
19871999
return {}, []
@@ -3119,11 +3131,14 @@ def validate_state(self) -> None:
31193131
waiting_for_data_count += 1
31203132
for ts_wait in ts.waiting_for_data:
31213133
assert ts_wait.key in self.tasks
3122-
assert (
3123-
ts_wait.state in READY | {"executing", "flight", "fetch", "missing"}
3124-
or ts_wait in self.missing_dep_flight
3125-
or ts_wait.who_has.issubset(self.in_flight_workers)
3126-
), (ts, ts_wait, self.story(ts), self.story(ts_wait))
3134+
assert ts_wait.state in READY | {
3135+
"executing",
3136+
"long-running",
3137+
"resumed",
3138+
"flight",
3139+
"fetch",
3140+
"missing",
3141+
}, (ts, ts_wait, self.story(ts), self.story(ts_wait))
31273142
# FIXME https://github.com/dask/distributed/issues/6319
31283143
# assert self.waiting_for_data_count == waiting_for_data_count
31293144
for worker, keys in self.has_what.items():

0 commit comments

Comments
 (0)