2929import subprocess
3030import time
3131import traceback
32- from collections import Counter , OrderedDict
32+ from collections import Counter
3333from concurrent .futures import ProcessPoolExecutor
34+ from enum import Enum
3435from multiprocessing import cpu_count
3536from typing import Any , Dict , List , Mapping , MutableMapping , Optional , Set , Tuple , Union
3637
4041from celery .result import AsyncResult
4142from celery .signals import import_modules as celery_import_modules
4243from setproctitle import setproctitle
44+ from sqlalchemy .orm .session import Session
4345
4446import airflow .settings as settings
4547from airflow .config_templates .default_celery import DEFAULT_CELERY_CONFIG
5052from airflow .stats import Stats
5153from airflow .utils .log .logging_mixin import LoggingMixin
5254from airflow .utils .net import get_hostname
55+ from airflow .utils .session import NEW_SESSION , provide_session
5356from airflow .utils .state import State
5457from airflow .utils .timeout import timeout
5558from airflow .utils .timezone import utcnow
@@ -207,6 +210,11 @@ def on_celery_import_modules(*args, **kwargs):
207210 pass
208211
209212
213+ class _CeleryPendingTaskTimeoutType (Enum ):
214+ ADOPTED = 1
215+ STALLED = 2
216+
217+
210218class CeleryExecutor (BaseExecutor ):
211219 """
212220 CeleryExecutor is recommended for production use of Airflow. It allows
@@ -230,10 +238,14 @@ def __init__(self):
230238 self ._sync_parallelism = max (1 , cpu_count () - 1 )
231239 self .bulk_state_fetcher = BulkStateFetcher (self ._sync_parallelism )
232240 self .tasks = {}
233- # Mapping of tasks we've adopted, ordered by the earliest date they timeout
234- self .adopted_task_timeouts : Dict [TaskInstanceKey , datetime .datetime ] = OrderedDict ()
235- self .task_adoption_timeout = datetime .timedelta (
236- seconds = conf .getint ('celery' , 'task_adoption_timeout' , fallback = 600 )
241+ self .stalled_task_timeouts : Dict [TaskInstanceKey , datetime .datetime ] = {}
242+ self .stalled_task_timeout = datetime .timedelta (
243+ seconds = conf .getint ('celery' , 'stalled_task_timeout' , fallback = 0 )
244+ )
245+ self .adopted_task_timeouts : Dict [TaskInstanceKey , datetime .datetime ] = {}
246+ self .task_adoption_timeout = (
247+ datetime .timedelta (seconds = conf .getint ('celery' , 'task_adoption_timeout' , fallback = 600 ))
248+ or self .stalled_task_timeout
237249 )
238250 self .task_publish_retries : Counter [TaskInstanceKey ] = Counter ()
239251 self .task_publish_max_retries = conf .getint ('celery' , 'task_publish_max_retries' , fallback = 3 )
@@ -285,6 +297,7 @@ def _process_tasks(self, task_tuples: List[TaskTuple]) -> None:
285297 result .backend = cached_celery_backend
286298 self .running .add (key )
287299 self .tasks [key ] = result
300+ self ._set_celery_pending_task_timeout (key , _CeleryPendingTaskTimeoutType .STALLED )
288301
289302 # Store the Celery task_id in the event buffer. This will get "overwritten" if the task
290303 # has another event, but that is fine, because the only other events are success/failed at
@@ -315,25 +328,47 @@ def sync(self) -> None:
315328 self .log .debug ("No task to query celery, skipping sync" )
316329 return
317330 self .update_all_task_states ()
331+ self ._check_for_timedout_adopted_tasks ()
332+ self ._check_for_stalled_tasks ()
333+
334+ def _check_for_timedout_adopted_tasks (self ) -> None :
335+ timedout_keys = self ._get_timedout_ti_keys (self .adopted_task_timeouts )
336+ if timedout_keys :
337+ self .log .error (
338+ "Adopted tasks were still pending after %s, assuming they never made it to celery "
339+ "and sending back to the scheduler:\n \t %s" ,
340+ self .task_adoption_timeout ,
341+ "\n \t " .join (repr (x ) for x in timedout_keys ),
342+ )
343+ self ._send_stalled_tis_back_to_scheduler (timedout_keys )
318344
319- if self .adopted_task_timeouts :
320- self ._check_for_stalled_adopted_tasks ()
345+ def _check_for_stalled_tasks (self ) -> None :
346+ timedout_keys = self ._get_timedout_ti_keys (self .stalled_task_timeouts )
347+ if timedout_keys :
348+ self .log .error (
349+ "Tasks were still pending after %s, assuming they never made it to celery "
350+ "and sending back to the scheduler:\n \t %s" ,
351+ self .stalled_task_timeout ,
352+ "\n \t " .join (repr (x ) for x in timedout_keys ),
353+ )
354+ self ._send_stalled_tis_back_to_scheduler (timedout_keys )
321355
322- def _check_for_stalled_adopted_tasks (self ):
356+ def _get_timedout_ti_keys (
357+ self , task_timeouts : Dict [TaskInstanceKey , datetime .datetime ]
358+ ) -> List [TaskInstanceKey ]:
323359 """
324- See if any of the tasks we adopted from another Executor run have not
325- progressed after the configured timeout.
326-
327- If they haven't, they likely never made it to Celery, and we should
328- just resend them. We do that by clearing the state and letting the
329- normal scheduler loop deal with that
360+ These timeouts exist to check to see if any of our tasks have not progressed
361+ in the expected time. This can happen for few different reasons, usually related
362+ to race conditions while shutting down schedulers and celery workers.
363+
364+ It is, of course, always possible that these tasks are not actually
365+ stalled - they could just be waiting in a long celery queue.
366+ Unfortunately there's no way for us to know for sure, so we'll just
367+ reschedule them and let the normal scheduler loop requeue them.
330368 """
331369 now = utcnow ()
332-
333- sorted_adopted_task_timeouts = sorted (self .adopted_task_timeouts .items (), key = lambda k : k [1 ])
334-
335370 timedout_keys = []
336- for key , stalled_after in sorted_adopted_task_timeouts :
371+ for key , stalled_after in task_timeouts . items () :
337372 if stalled_after > now :
338373 # Since items are stored sorted, if we get to a stalled_after
339374 # in the future then we can stop
@@ -343,20 +378,46 @@ def _check_for_stalled_adopted_tasks(self):
343378 # already finished, then it will be removed from this list -- so
344379 # the only time it's still in this list is when it a) never made it
345380 # to celery in the first place (i.e. race condition somewhere in
346- # the dying executor) or b) a really long celery queue and it just
381+ # the dying executor), b) celery lost the task before execution
382+ # started, or c) a really long celery queue and it just
347383 # hasn't started yet -- better cancel it and let the scheduler
348384 # re-queue rather than have this task risk stalling for ever
349385 timedout_keys .append (key )
386+ return timedout_keys
350387
351- if timedout_keys :
352- self .log .error (
353- "Adopted tasks were still pending after %s, assuming they never made it to celery and "
354- "clearing:\n \t %s" ,
355- self .task_adoption_timeout ,
356- "\n \t " .join (repr (x ) for x in timedout_keys ),
388+ @provide_session
389+ def _send_stalled_tis_back_to_scheduler (
390+ self , keys : List [TaskInstanceKey ], session : Session = NEW_SESSION
391+ ) -> None :
392+ try :
393+ session .query (TaskInstance ).filter (
394+ TaskInstance .filter_for_tis (keys ),
395+ TaskInstance .state == State .QUEUED ,
396+ TaskInstance .queued_by_job_id == self .job_id ,
397+ ).update (
398+ {
399+ TaskInstance .state : State .SCHEDULED ,
400+ TaskInstance .queued_dttm : None ,
401+ TaskInstance .queued_by_job_id : None ,
402+ TaskInstance .external_executor_id : None ,
403+ },
404+ synchronize_session = False ,
357405 )
358- for key in timedout_keys :
359- self .change_state (key , State .FAILED )
406+ session .commit ()
407+ except Exception :
408+ self .log .exception ("Error sending tasks back to scheduler" )
409+ session .rollback ()
410+ return
411+
412+ for key in keys :
413+ self ._set_celery_pending_task_timeout (key , None )
414+ self .running .discard (key )
415+ celery_async_result = self .tasks .pop (key , None )
416+ if celery_async_result :
417+ try :
418+ app .control .revoke (celery_async_result .task_id )
419+ except Exception as ex :
420+ self .log .error ("Error revoking task instance %s from celery: %s" , key , ex )
360421
361422 def debug_dump (self ) -> None :
362423 """Called in response to SIGUSR2 by the scheduler"""
@@ -369,6 +430,11 @@ def debug_dump(self) -> None:
369430 len (self .adopted_task_timeouts ),
370431 "\n \t " .join (map (repr , self .adopted_task_timeouts .items ())),
371432 )
433+ self .log .info (
434+ "executor.stalled_task_timeouts (%d)\n \t %s" ,
435+ len (self .stalled_task_timeouts ),
436+ "\n \t " .join (map (repr , self .stalled_task_timeouts .items ())),
437+ )
372438
373439 def update_all_task_states (self ) -> None :
374440 """Updates states of the tasks."""
@@ -384,7 +450,7 @@ def update_all_task_states(self) -> None:
384450 def change_state (self , key : TaskInstanceKey , state : str , info = None ) -> None :
385451 super ().change_state (key , state , info )
386452 self .tasks .pop (key , None )
387- self .adopted_task_timeouts . pop (key , None )
453+ self ._set_celery_pending_task_timeout (key , None )
388454
389455 def update_task_state (self , key : TaskInstanceKey , state : str , info : Any ) -> None :
390456 """Updates state of a single task."""
@@ -394,8 +460,8 @@ def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None
394460 elif state in (celery_states .FAILURE , celery_states .REVOKED ):
395461 self .fail (key , info )
396462 elif state == celery_states .STARTED :
397- # It's now actually running, so know it made it to celery okay!
398- self .adopted_task_timeouts . pop (key , None )
463+ # It's now actually running, so we know it made it to celery okay!
464+ self ._set_celery_pending_task_timeout (key , None )
399465 elif state == celery_states .PENDING :
400466 pass
401467 else :
@@ -455,7 +521,7 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance
455521
456522 # Set the correct elements of the state dicts, then update this
457523 # like we just queried it.
458- self .adopted_task_timeouts [ ti .key ] = ti . queued_dttm + self . task_adoption_timeout
524+ self ._set_celery_pending_task_timeout ( ti .key , _CeleryPendingTaskTimeoutType . ADOPTED )
459525 self .tasks [ti .key ] = result
460526 self .running .add (ti .key )
461527 self .update_task_state (ti .key , state , info )
@@ -469,6 +535,21 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance
469535
470536 return not_adopted_tis
471537
538+ def _set_celery_pending_task_timeout (
539+ self , key : TaskInstanceKey , timeout_type : Optional [_CeleryPendingTaskTimeoutType ]
540+ ) -> None :
541+ """
542+ We use the fact that dicts maintain insertion order, and the the timeout for a
543+ task is always "now + delta" to maintain the property that oldest item = first to
544+ time out.
545+ """
546+ self .adopted_task_timeouts .pop (key , None )
547+ self .stalled_task_timeouts .pop (key , None )
548+ if timeout_type == _CeleryPendingTaskTimeoutType .ADOPTED and self .task_adoption_timeout :
549+ self .adopted_task_timeouts [key ] = utcnow () + self .task_adoption_timeout
550+ elif timeout_type == _CeleryPendingTaskTimeoutType .STALLED and self .stalled_task_timeout :
551+ self .stalled_task_timeouts [key ] = utcnow () + self .stalled_task_timeout
552+
472553
473554def fetch_celery_task_state (async_result : AsyncResult ) -> Tuple [str , Union [str , ExceptionWithTraceback ], Any ]:
474555 """
0 commit comments