@@ -517,7 +517,7 @@ class LongRunningMsg(SendMessageToScheduler):
517517
518518 __slots__ = ("key" , "compute_duration" )
519519 key : str
520- compute_duration : float
520+ compute_duration : float | None
521521
522522
523523@dataclass
@@ -2077,24 +2077,39 @@ def _transition_resumed_waiting(
20772077 See also
20782078 --------
20792079 _transition_cancelled_fetch
2080+ _transition_cancelled_or_resumed_long_running
20802081 _transition_cancelled_waiting
20812082 _transition_resumed_fetch
20822083 """
20832084 # None of the exit events of execute or gather_dep recommend a transition to
20842085 # waiting
20852086 assert not ts .done
2086- if ts .previous in ( "executing" , "long-running" ) :
2087+ if ts .previous == "executing" :
20872088 assert ts .next == "fetch"
20882089 # We're back where we started. We should forget about the entire
20892090 # cancellation attempt
2090- ts .state = ts . previous
2091+ ts .state = "executing"
20912092 ts .next = None
20922093 ts .previous = None
2093- elif self .validate :
2094+ return {}, []
2095+
2096+ elif ts .previous == "long-running" :
2097+ assert ts .next == "fetch"
2098+ # Same as executing, and in addition send the LongRunningMsg in arrears
2099+ # Note that, if the task seceded before it was cancelled, this will cause
2100+ # the message to be sent twice.
2101+ ts .state = "long-running"
2102+ ts .next = None
2103+ ts .previous = None
2104+ smsg = LongRunningMsg (
2105+ key = ts .key , compute_duration = None , stimulus_id = stimulus_id
2106+ )
2107+ return {}, [smsg ]
2108+
2109+ else :
20942110 assert ts .previous == "flight"
20952111 assert ts .next == "waiting"
2096-
2097- return {}, []
2112+ return {}, []
20982113
20992114 def _transition_cancelled_fetch (
21002115 self , ts : TaskState , * , stimulus_id : str
@@ -2131,17 +2146,29 @@ def _transition_cancelled_waiting(
21312146 See also
21322147 --------
21332148 _transition_cancelled_fetch
2149+ _transition_cancelled_or_resumed_long_running
21342150 _transition_resumed_fetch
21352151 _transition_resumed_waiting
21362152 """
21372153 # None of the exit events of gather_dep or execute recommend a transition to
21382154 # waiting
21392155 assert not ts .done
2140- if ts .previous in ( "executing" , "long-running" ) :
2156+ if ts .previous == "executing" :
21412157 # Forget the task was cancelled to begin with
2142- ts .state = ts . previous
2158+ ts .state = "executing"
21432159 ts .previous = None
21442160 return {}, []
2161+ elif ts .previous == "long-running" :
2162+ # Forget the task was cancelled to begin with, and inform the scheduler
2163+ # in arrears that it has seceded.
2164+ # Note that, if the task seceded before it was cancelled, this will cause
2165+ # the message to be sent twice.
2166+ ts .state = "long-running"
2167+ ts .previous = None
2168+ smsg = LongRunningMsg (
2169+ key = ts .key , compute_duration = None , stimulus_id = stimulus_id
2170+ )
2171+ return {}, [smsg ]
21452172 else :
21462173 assert ts .previous == "flight"
21472174 ts .state = "resumed"
@@ -2234,6 +2261,11 @@ def _transition_flight_released(
22342261 def _transition_executing_long_running (
22352262 self , ts : TaskState , compute_duration : float , * , stimulus_id : str
22362263 ) -> RecsInstrs :
2264+ """
2265+ See also
2266+ --------
2267+ _transition_cancelled_or_resumed_long_running
2268+ """
22372269 ts .state = "long-running"
22382270 self .executing .discard (ts )
22392271 self .long_running .add (ts )
@@ -2246,6 +2278,34 @@ def _transition_executing_long_running(
22462278 self ._ensure_computing (),
22472279 )
22482280
2281+ def _transition_cancelled_or_resumed_long_running (
2282+ self , ts : TaskState , compute_duration : float , * , stimulus_id : str
2283+ ) -> RecsInstrs :
2284+ """Handles transitions:
2285+
2286+ - cancelled(executing) -> long-running
2287+ - cancelled(long-running) -> long-running (user called secede() twice)
2288+ - resumed(executing->fetch) -> long-running
2289+ - resumed(long-running->fetch) -> long-running (user called secede() twice)
2290+
2291+ Unlike in the executing->long_running transition, do not send LongRunningMsg.
2292+ From the scheduler's perspective, this task no longer exists (cancelled) or is
2293+ in memory on another worker (resumed). So it shouldn't hear about it.
2294+ Instead, we're going to send the LongRunningMsg when and if the task
2295+ transitions back to waiting.
2296+
2297+ See also
2298+ --------
2299+ _transition_executing_long_running
2300+ _transition_cancelled_waiting
2301+ _transition_resumed_waiting
2302+ """
2303+ assert ts .previous in ("executing" , "long-running" )
2304+ ts .previous = "long-running"
2305+ self .executing .discard (ts )
2306+ self .long_running .add (ts )
2307+ return self ._ensure_computing ()
2308+
22492309 def _transition_executing_memory (
22502310 self , ts : TaskState , value : object , * , stimulus_id : str
22512311 ) -> RecsInstrs :
@@ -2352,15 +2412,16 @@ def _transition_released_forgotten(
23522412 ] = {
23532413 ("cancelled" , "error" ): _transition_cancelled_released ,
23542414 ("cancelled" , "fetch" ): _transition_cancelled_fetch ,
2415+ ("cancelled" , "long-running" ): _transition_cancelled_or_resumed_long_running ,
23552416 ("cancelled" , "memory" ): _transition_cancelled_released ,
23562417 ("cancelled" , "missing" ): _transition_cancelled_released ,
23572418 ("cancelled" , "released" ): _transition_cancelled_released ,
23582419 ("cancelled" , "rescheduled" ): _transition_cancelled_released ,
23592420 ("cancelled" , "waiting" ): _transition_cancelled_waiting ,
23602421 ("resumed" , "error" ): _transition_resumed_error ,
23612422 ("resumed" , "fetch" ): _transition_resumed_fetch ,
2423+ ("resumed" , "long-running" ): _transition_cancelled_or_resumed_long_running ,
23622424 ("resumed" , "memory" ): _transition_resumed_memory ,
2363- ("resumed" , "missing" ): _transition_resumed_missing ,
23642425 ("resumed" , "released" ): _transition_resumed_released ,
23652426 ("resumed" , "rescheduled" ): _transition_resumed_rescheduled ,
23662427 ("resumed" , "waiting" ): _transition_resumed_waiting ,
@@ -2898,10 +2959,9 @@ def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
28982959 @_handle_event .register
28992960 def _handle_secede (self , ev : SecedeEvent ) -> RecsInstrs :
29002961 ts = self .tasks .get (ev .key )
2901- if ts and ts .state == "executing" :
2902- return {ts : ("long-running" , ev .compute_duration )}, []
2903- else :
2962+ if not ts :
29042963 return {}, []
2964+ return {ts : ("long-running" , ev .compute_duration )}, []
29052965
29062966 @_handle_event .register
29072967 def _handle_steal_request (self , ev : StealRequestEvent ) -> RecsInstrs :
0 commit comments