
Commit baae70c

Automatically reschedule stalled queued tasks in CeleryExecutor (v2) (#23690)
Celery can lose tasks on worker shutdown, causing Airflow to wait on them indefinitely (may be related to celery/celery#7266). This PR expands the "stalled tasks" handling already in place for adopted tasks so that it can be applied to all queued tasks, allowing these lost or hung tasks to be automatically recovered and queued again.
1 parent 888bc2e commit baae70c
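In outline, the recovery mechanism this commit adds works as follows. Below is a simplified sketch distilled from the diff further down; the real executor keys these dicts by TaskInstanceKey and resets timed-out tasks through the metadata database, and the helper names here (track, timed_out_keys) are illustrative only:

import datetime
from typing import Dict, List

def utcnow() -> datetime.datetime:
    return datetime.datetime.now(datetime.timezone.utc)

# key -> deadline. Because every entry is recorded as "now + a fixed delta",
# insertion order doubles as deadline order: the oldest entry times out first.
stalled_task_timeouts: Dict[str, datetime.datetime] = {}

def track(key: str, timeout: datetime.timedelta) -> None:
    """Start (or restart) the stall clock for a queued task."""
    stalled_task_timeouts.pop(key, None)  # re-inserting moves the key to the end
    if timeout:  # timedelta(0) is falsy, i.e. the feature is disabled
        stalled_task_timeouts[key] = utcnow() + timeout

def timed_out_keys() -> List[str]:
    """Collect tasks whose deadline has passed; called from the sync loop."""
    now = utcnow()
    keys: List[str] = []
    for key, deadline in stalled_task_timeouts.items():
        if deadline > now:
            break  # entries are deadline-ordered, so the rest are still fresh
        keys.append(key)
    return keys

Each key collected this way is then reset to SCHEDULED in the database and revoked from celery, so the normal scheduler loop queues it again.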

4 files changed

Lines changed: 245 additions & 75 deletions

File tree

airflow/config_templates/config.yml

Lines changed: 13 additions & 2 deletions
@@ -1768,12 +1768,23 @@
       default: "True"
     - name: task_adoption_timeout
       description: |
-        Time in seconds after which Adopted tasks are cleared by CeleryExecutor. This is helpful to clear
-        stalled tasks.
+        Time in seconds after which adopted tasks which are queued in celery are assumed to be stalled,
+        and are automatically rescheduled. This setting does the same thing as ``stalled_task_timeout`` but
+        applies specifically to adopted tasks only. When set to 0, the ``stalled_task_timeout`` setting
+        also applies to adopted tasks.
       version_added: 2.0.0
       type: integer
       example: ~
       default: "600"
+    - name: stalled_task_timeout
+      description: |
+        Time in seconds after which tasks queued in celery are assumed to be stalled, and are automatically
+        rescheduled. Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified.
+        When set to 0, automatic clearing of stalled tasks is disabled.
+      version_added: 2.3.1
+      type: integer
+      example: ~
+      default: "0"
     - name: task_publish_max_retries
       description: |
         The Maximum number of retries for publishing task messages to the broker when failing
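To make the interaction between the two settings concrete, here is a small worked example of the fallback rule described above. The expression mirrors the executor's __init__ in the diff further down; the helper function itself is just for illustration:

import datetime

def effective_adopted_timeout(task_adoption_timeout: int, stalled_task_timeout: int) -> datetime.timedelta:
    # timedelta(seconds=0) is falsy, so a task_adoption_timeout of 0 falls
    # through to stalled_task_timeout, exactly as the description says.
    return (
        datetime.timedelta(seconds=task_adoption_timeout)
        or datetime.timedelta(seconds=stalled_task_timeout)
    )

assert effective_adopted_timeout(600, 300) == datetime.timedelta(seconds=600)
assert effective_adopted_timeout(0, 300) == datetime.timedelta(seconds=300)
assert effective_adopted_timeout(0, 0) == datetime.timedelta(0)  # both features disabled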

airflow/config_templates/default_airflow.cfg

Lines changed: 9 additions & 2 deletions
@@ -888,10 +888,17 @@ operation_timeout = 1.0
 # or run in HA mode, it can adopt the orphan tasks launched by previous SchedulerJob.
 task_track_started = True
 
-# Time in seconds after which Adopted tasks are cleared by CeleryExecutor. This is helpful to clear
-# stalled tasks.
+# Time in seconds after which adopted tasks which are queued in celery are assumed to be stalled,
+# and are automatically rescheduled. This setting does the same thing as ``stalled_task_timeout`` but
+# applies specifically to adopted tasks only. When set to 0, the ``stalled_task_timeout`` setting
+# also applies to adopted tasks.
 task_adoption_timeout = 600
 
+# Time in seconds after which tasks queued in celery are assumed to be stalled, and are automatically
+# rescheduled. Adopted tasks will instead use the ``task_adoption_timeout`` setting if specified.
+# When set to 0, automatic clearing of stalled tasks is disabled.
+stalled_task_timeout = 0
+
 # The Maximum number of retries for publishing task messages to the broker when failing
 # due to ``AirflowTaskTimeout`` error before giving up and marking Task as failed.
 task_publish_max_retries = 3
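As with any option in this file, the new setting can also be supplied through Airflow's standard AIRFLOW__{SECTION}__{KEY} environment-variable convention rather than by editing airflow.cfg. A quick sketch (the value 600 is an arbitrary example; it must be set before the config is loaded):

import os

# Reschedule tasks stuck in the celery queue after 10 minutes;
# "0" (the default) leaves the feature disabled.
os.environ["AIRFLOW__CELERY__STALLED_TASK_TIMEOUT"] = "600"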

airflow/executors/celery_executor.py

Lines changed: 112 additions & 31 deletions
@@ -29,8 +29,9 @@
 import subprocess
 import time
 import traceback
-from collections import Counter, OrderedDict
+from collections import Counter
 from concurrent.futures import ProcessPoolExecutor
+from enum import Enum
 from multiprocessing import cpu_count
 from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
 
@@ -40,6 +41,7 @@
 from celery.result import AsyncResult
 from celery.signals import import_modules as celery_import_modules
 from setproctitle import setproctitle
+from sqlalchemy.orm.session import Session
 
 import airflow.settings as settings
 from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
 
@@ -50,6 +52,7 @@
 from airflow.stats import Stats
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 from airflow.utils.timezone import utcnow
@@ -207,6 +210,11 @@ def on_celery_import_modules(*args, **kwargs):
     pass
 
 
+class _CeleryPendingTaskTimeoutType(Enum):
+    ADOPTED = 1
+    STALLED = 2
+
+
 class CeleryExecutor(BaseExecutor):
     """
     CeleryExecutor is recommended for production use of Airflow. It allows
@@ -230,10 +238,14 @@ def __init__(self):
         self._sync_parallelism = max(1, cpu_count() - 1)
         self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism)
         self.tasks = {}
-        # Mapping of tasks we've adopted, ordered by the earliest date they timeout
-        self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = OrderedDict()
-        self.task_adoption_timeout = datetime.timedelta(
-            seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600)
+        self.stalled_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {}
+        self.stalled_task_timeout = datetime.timedelta(
+            seconds=conf.getint('celery', 'stalled_task_timeout', fallback=0)
+        )
+        self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {}
+        self.task_adoption_timeout = (
+            datetime.timedelta(seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600))
+            or self.stalled_task_timeout
         )
         self.task_publish_retries: Counter[TaskInstanceKey] = Counter()
         self.task_publish_max_retries = conf.getint('celery', 'task_publish_max_retries', fallback=3)
@@ -285,6 +297,7 @@ def _process_tasks(self, task_tuples: List[TaskTuple]) -> None:
                 result.backend = cached_celery_backend
                 self.running.add(key)
                 self.tasks[key] = result
+                self._set_celery_pending_task_timeout(key, _CeleryPendingTaskTimeoutType.STALLED)
 
                 # Store the Celery task_id in the event buffer. This will get "overwritten" if the task
                 # has another event, but that is fine, because the only other events are success/failed at
@@ -315,25 +328,47 @@ def sync(self) -> None:
             self.log.debug("No task to query celery, skipping sync")
             return
         self.update_all_task_states()
+        self._check_for_timedout_adopted_tasks()
+        self._check_for_stalled_tasks()
+
+    def _check_for_timedout_adopted_tasks(self) -> None:
+        timedout_keys = self._get_timedout_ti_keys(self.adopted_task_timeouts)
+        if timedout_keys:
+            self.log.error(
+                "Adopted tasks were still pending after %s, assuming they never made it to celery "
+                "and sending back to the scheduler:\n\t%s",
+                self.task_adoption_timeout,
+                "\n\t".join(repr(x) for x in timedout_keys),
+            )
+            self._send_stalled_tis_back_to_scheduler(timedout_keys)
 
-        if self.adopted_task_timeouts:
-            self._check_for_stalled_adopted_tasks()
+    def _check_for_stalled_tasks(self) -> None:
+        timedout_keys = self._get_timedout_ti_keys(self.stalled_task_timeouts)
+        if timedout_keys:
+            self.log.error(
+                "Tasks were still pending after %s, assuming they never made it to celery "
+                "and sending back to the scheduler:\n\t%s",
+                self.stalled_task_timeout,
+                "\n\t".join(repr(x) for x in timedout_keys),
+            )
+            self._send_stalled_tis_back_to_scheduler(timedout_keys)
 
-    def _check_for_stalled_adopted_tasks(self):
+    def _get_timedout_ti_keys(
+        self, task_timeouts: Dict[TaskInstanceKey, datetime.datetime]
+    ) -> List[TaskInstanceKey]:
         """
-        See if any of the tasks we adopted from another Executor run have not
-        progressed after the configured timeout.
-
-        If they haven't, they likely never made it to Celery, and we should
-        just resend them. We do that by clearing the state and letting the
-        normal scheduler loop deal with that
+        These timeouts exist to check to see if any of our tasks have not progressed
+        in the expected time. This can happen for a few different reasons, usually related
+        to race conditions while shutting down schedulers and celery workers.
+
+        It is, of course, always possible that these tasks are not actually
+        stalled - they could just be waiting in a long celery queue.
+        Unfortunately there's no way for us to know for sure, so we'll just
+        reschedule them and let the normal scheduler loop requeue them.
         """
         now = utcnow()
-
-        sorted_adopted_task_timeouts = sorted(self.adopted_task_timeouts.items(), key=lambda k: k[1])
-
         timedout_keys = []
-        for key, stalled_after in sorted_adopted_task_timeouts:
+        for key, stalled_after in task_timeouts.items():
             if stalled_after > now:
                 # Since items are stored sorted, if we get to a stalled_after
                 # in the future then we can stop
@@ -343,20 +378,46 @@ def _check_for_stalled_adopted_tasks(self):
             # already finished, then it will be removed from this list -- so
             # the only time it's still in this list is when it a) never made it
             # to celery in the first place (i.e. race condition somewhere in
-            # the dying executor) or b) a really long celery queue and it just
+            # the dying executor), b) celery lost the task before execution
+            # started, or c) a really long celery queue and it just
             # hasn't started yet -- better cancel it and let the scheduler
             # re-queue rather than have this task risk stalling for ever
             timedout_keys.append(key)
+        return timedout_keys
 
-        if timedout_keys:
-            self.log.error(
-                "Adopted tasks were still pending after %s, assuming they never made it to celery and "
-                "clearing:\n\t%s",
-                self.task_adoption_timeout,
-                "\n\t".join(repr(x) for x in timedout_keys),
+    @provide_session
+    def _send_stalled_tis_back_to_scheduler(
+        self, keys: List[TaskInstanceKey], session: Session = NEW_SESSION
+    ) -> None:
+        try:
+            session.query(TaskInstance).filter(
+                TaskInstance.filter_for_tis(keys),
+                TaskInstance.state == State.QUEUED,
+                TaskInstance.queued_by_job_id == self.job_id,
+            ).update(
+                {
+                    TaskInstance.state: State.SCHEDULED,
+                    TaskInstance.queued_dttm: None,
+                    TaskInstance.queued_by_job_id: None,
+                    TaskInstance.external_executor_id: None,
+                },
+                synchronize_session=False,
             )
-            for key in timedout_keys:
-                self.change_state(key, State.FAILED)
+            session.commit()
+        except Exception:
+            self.log.exception("Error sending tasks back to scheduler")
+            session.rollback()
+            return
+
+        for key in keys:
+            self._set_celery_pending_task_timeout(key, None)
+            self.running.discard(key)
+            celery_async_result = self.tasks.pop(key, None)
+            if celery_async_result:
+                try:
+                    app.control.revoke(celery_async_result.task_id)
+                except Exception as ex:
+                    self.log.error("Error revoking task instance %s from celery: %s", key, ex)
 
     def debug_dump(self) -> None:
         """Called in response to SIGUSR2 by the scheduler"""
@@ -369,6 +430,11 @@ def debug_dump(self) -> None:
             len(self.adopted_task_timeouts),
             "\n\t".join(map(repr, self.adopted_task_timeouts.items())),
         )
+        self.log.info(
+            "executor.stalled_task_timeouts (%d)\n\t%s",
+            len(self.stalled_task_timeouts),
+            "\n\t".join(map(repr, self.stalled_task_timeouts.items())),
+        )
 
     def update_all_task_states(self) -> None:
         """Updates states of the tasks."""
@@ -384,7 +450,7 @@ def update_all_task_states(self) -> None:
     def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
         super().change_state(key, state, info)
         self.tasks.pop(key, None)
-        self.adopted_task_timeouts.pop(key, None)
+        self._set_celery_pending_task_timeout(key, None)
 
     def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:
         """Updates state of a single task."""
@@ -394,8 +460,8 @@ def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:
         elif state in (celery_states.FAILURE, celery_states.REVOKED):
             self.fail(key, info)
         elif state == celery_states.STARTED:
-            # It's now actually running, so know it made it to celery okay!
-            self.adopted_task_timeouts.pop(key, None)
+            # It's now actually running, so we know it made it to celery okay!
+            self._set_celery_pending_task_timeout(key, None)
         elif state == celery_states.PENDING:
             pass
         else:
@@ -455,7 +521,7 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
 
             # Set the correct elements of the state dicts, then update this
             # like we just queried it.
-            self.adopted_task_timeouts[ti.key] = ti.queued_dttm + self.task_adoption_timeout
+            self._set_celery_pending_task_timeout(ti.key, _CeleryPendingTaskTimeoutType.ADOPTED)
             self.tasks[ti.key] = result
             self.running.add(ti.key)
             self.update_task_state(ti.key, state, info)
@@ -469,6 +535,21 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
 
         return not_adopted_tis
 
+    def _set_celery_pending_task_timeout(
+        self, key: TaskInstanceKey, timeout_type: Optional[_CeleryPendingTaskTimeoutType]
+    ) -> None:
+        """
+        We use the fact that dicts maintain insertion order, and that the timeout for a
+        task is always "now + delta", to maintain the property that oldest item = first to
+        time out.
+        """
+        self.adopted_task_timeouts.pop(key, None)
+        self.stalled_task_timeouts.pop(key, None)
+        if timeout_type == _CeleryPendingTaskTimeoutType.ADOPTED and self.task_adoption_timeout:
+            self.adopted_task_timeouts[key] = utcnow() + self.task_adoption_timeout
+        elif timeout_type == _CeleryPendingTaskTimeoutType.STALLED and self.stalled_task_timeout:
+            self.stalled_task_timeouts[key] = utcnow() + self.stalled_task_timeout
+
 
 def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, ExceptionWithTraceback], Any]:
     """
