Skip to content

Commit 73fe489

Browse files
committed
fix: address review findings in shutdown strategies
1 parent 94813cb commit 73fe489

2 files changed

Lines changed: 69 additions & 61 deletions

File tree

src/synthorg/engine/shutdown.py

Lines changed: 42 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ async def _wait_and_cancel(
242242
if exc is not None:
243243
logger.warning(
244244
EXECUTION_SHUTDOWN_TASK_ERROR,
245-
error=(f"Task raised during shutdown: {type(exc).__name__}"),
245+
error=f"Task raised during shutdown: {type(exc).__name__}: {exc}",
246246
)
247247
else:
248248
tasks_completed += 1
@@ -260,41 +260,10 @@ async def _wait_and_cancel(
260260
pending,
261261
timeout=self._CANCEL_PROPAGATION_TIMEOUT,
262262
)
263-
self._log_post_cancel_exceptions(cancel_done)
263+
_log_post_cancel_exceptions(cancel_done)
264264

265265
return tasks_completed, len(pending)
266266

267-
def _log_post_cancel_exceptions(
268-
self,
269-
tasks: set[asyncio.Task[Any]],
270-
) -> None:
271-
"""Retrieve and log exceptions from post-cancel tasks.
272-
273-
Retrieving the exception prevents asyncio's "Task exception was
274-
never retrieved" warning. Non-cancelled tasks with exceptions
275-
are logged at DEBUG.
276-
"""
277-
for task in tasks:
278-
if task.cancelled():
279-
continue
280-
try:
281-
exc = task.exception()
282-
except asyncio.InvalidStateError:
283-
logger.debug(
284-
EXECUTION_SHUTDOWN_TASK_ERROR,
285-
error="Failed to inspect post-cancel task: InvalidStateError",
286-
task_name=task.get_name(),
287-
)
288-
else:
289-
if exc is not None:
290-
logger.debug(
291-
EXECUTION_SHUTDOWN_TASK_ERROR,
292-
error=(
293-
f"Post-cancel task exception: {type(exc).__name__}: {exc}"
294-
),
295-
task_name=task.get_name(),
296-
)
297-
298267

299268
class ImmediateCancelStrategy:
300269
"""Immediate cancel shutdown strategy.
@@ -349,11 +318,7 @@ async def execute_shutdown(
349318
task_set,
350319
timeout=self._CANCEL_PROPAGATION_TIMEOUT,
351320
)
352-
# Retrieve exceptions to suppress "never retrieved" warnings.
353-
for task in cancel_done:
354-
if not task.cancelled():
355-
with contextlib.suppress(Exception):
356-
task.exception()
321+
_log_post_cancel_exceptions(cancel_done)
357322

358323
cleanup_completed = await _run_cleanup(cleanup_callbacks, self._cleanup_seconds)
359324

@@ -448,7 +413,7 @@ async def execute_shutdown(
448413
if exc is not None:
449414
logger.warning(
450415
EXECUTION_SHUTDOWN_TASK_ERROR,
451-
error=f"Task raised during shutdown: {type(exc).__name__}",
416+
error=f"Task raised during shutdown: {type(exc).__name__}: {exc}",
452417
)
453418
else:
454419
tasks_completed += 1
@@ -465,10 +430,7 @@ async def execute_shutdown(
465430
pending,
466431
timeout=self._CANCEL_PROPAGATION_TIMEOUT,
467432
)
468-
for task in cancel_done:
469-
if not task.cancelled():
470-
with contextlib.suppress(Exception):
471-
task.exception()
433+
_log_post_cancel_exceptions(cancel_done)
472434

473435
cleanup_completed = await _run_cleanup(cleanup_callbacks, self._cleanup_seconds)
474436

@@ -568,7 +530,7 @@ async def execute_shutdown(
568530
if exc is not None:
569531
logger.warning(
570532
EXECUTION_SHUTDOWN_TASK_ERROR,
571-
error=f"Task raised during shutdown: {type(exc).__name__}",
533+
error=f"Task raised during shutdown: {type(exc).__name__}: {exc}",
572534
)
573535
else:
574536
tasks_suspended += 1
@@ -580,6 +542,11 @@ async def execute_shutdown(
580542
tasks_interrupted = 0
581543
for task in pending:
582544
task_id = task_to_id.get(task, "unknown")
545+
if task_id == "unknown":
546+
logger.debug(
547+
EXECUTION_SHUTDOWN_TASK_ERROR,
548+
error="Task not found in reverse map during checkpoint",
549+
)
583550
saved = await self._try_checkpoint(task_id)
584551
if saved:
585552
tasks_suspended += 1
@@ -593,10 +560,7 @@ async def execute_shutdown(
593560
pending,
594561
timeout=self._CANCEL_PROPAGATION_TIMEOUT,
595562
)
596-
for task in cancel_done:
597-
if not task.cancelled():
598-
with contextlib.suppress(Exception):
599-
task.exception()
563+
_log_post_cancel_exceptions(cancel_done)
600564

601565
cleanup_completed = await _run_cleanup(cleanup_callbacks, self._cleanup_seconds)
602566

@@ -628,10 +592,11 @@ async def _try_checkpoint(self, task_id: str) -> bool:
628592
return False
629593
try:
630594
saved = await self._checkpoint_saver(task_id)
631-
except Exception:
595+
except Exception as exc:
632596
logger.exception(
633597
EXECUTION_SHUTDOWN_CHECKPOINT_FAILED,
634598
task_id=task_id,
599+
error_type=type(exc).__name__,
635600
)
636601
return False
637602
if saved:
@@ -648,7 +613,34 @@ async def _try_checkpoint(self, task_id: str) -> bool:
648613
return saved
649614

650615

651-
# ── Shared cleanup helper ────────────────────────────────────────
616+
# ── Shared helpers ───────────────────────────────────────────────
617+
618+
619+
def _log_post_cancel_exceptions(tasks: set[asyncio.Task[Any]]) -> None:
    """Drain and report exceptions held by tasks after cancellation.

    Calling ``Task.exception()`` marks the exception as retrieved, which
    keeps asyncio from emitting its "Task exception was never retrieved"
    warning. Cancelled tasks are skipped; any other task that carries an
    exception, or that cannot be inspected, is logged at DEBUG.

    Args:
        tasks: Tasks returned from the post-cancel wait.
    """
    for finished in tasks:
        # A cancelled task would raise CancelledError from exception().
        if finished.cancelled():
            continue

        try:
            failure = finished.exception()
        except asyncio.InvalidStateError:
            logger.debug(
                EXECUTION_SHUTDOWN_TASK_ERROR,
                error="Failed to inspect post-cancel task: InvalidStateError",
                task_name=finished.get_name(),
            )
            continue

        if failure is None:
            continue

        logger.debug(
            EXECUTION_SHUTDOWN_TASK_ERROR,
            error=f"Post-cancel task exception: {type(failure).__name__}: {failure}",
            task_name=finished.get_name(),
        )
652644

653645

654646
async def _run_cleanup(

tests/unit/engine/test_shutdown.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ShutdownManager,
1919
ShutdownResult,
2020
ShutdownStrategy,
21+
_log_post_cancel_exceptions,
2122
build_shutdown_strategy,
2223
)
2324

@@ -424,40 +425,33 @@ class TestLogPostCancelExceptions:
424425
"""Extracted helper retrieves exceptions without swallowing them."""
425426

426427
def test_skips_cancelled_tasks(self) -> None:
427-
strategy = CooperativeTimeoutStrategy()
428428
task = MagicMock(spec=asyncio.Task)
429429
task.cancelled.return_value = True
430-
# Should not call task.exception()
431-
strategy._log_post_cancel_exceptions({task})
430+
_log_post_cancel_exceptions({task})
432431
task.exception.assert_not_called()
433432

434433
def test_logs_task_exception(self) -> None:
435-
strategy = CooperativeTimeoutStrategy()
436434
task = MagicMock(spec=asyncio.Task)
437435
task.cancelled.return_value = False
438436
task.exception.return_value = RuntimeError("boom")
439437
task.get_name.return_value = "test-task"
440-
# Should not raise
441-
strategy._log_post_cancel_exceptions({task})
438+
_log_post_cancel_exceptions({task})
442439
task.exception.assert_called_once()
443440

444441
def test_handles_no_exception(self) -> None:
445-
strategy = CooperativeTimeoutStrategy()
446442
task = MagicMock(spec=asyncio.Task)
447443
task.cancelled.return_value = False
448444
task.exception.return_value = None
449445
task.get_name.return_value = "test-task"
450-
strategy._log_post_cancel_exceptions({task})
446+
_log_post_cancel_exceptions({task})
451447
task.exception.assert_called_once()
452448

453449
def test_handles_invalid_state_error(self) -> None:
454-
strategy = CooperativeTimeoutStrategy()
455450
task = MagicMock(spec=asyncio.Task)
456451
task.cancelled.return_value = False
457452
task.exception.side_effect = asyncio.InvalidStateError
458453
task.get_name.return_value = "test-task"
459-
# Should not raise -- logs at DEBUG instead of silent pass
460-
strategy._log_post_cancel_exceptions({task})
454+
_log_post_cancel_exceptions({task})
461455

462456

463457
# ── Signal handler recovery ──────────────────────────────────────
@@ -827,6 +821,28 @@ async def stubborn() -> None:
827821
assert result.tasks_suspended == 0
828822
assert result.tasks_interrupted == 1
829823

824+
async def test_checkpoint_saver_raises_counted_as_interrupted(self) -> None:
    """A straggler whose checkpoint saver raises is counted as interrupted."""

    async def failing_saver(task_id: str) -> bool:
        msg = "Storage unavailable"
        raise OSError(msg)

    async def stubborn() -> None:
        # Never completes on its own, so it outlives the grace period.
        await asyncio.Event().wait()

    strategy = CheckpointAndStopStrategy(
        grace_seconds=0.1,
        checkpoint_saver=failing_saver,
    )
    straggler = asyncio.create_task(stubborn())

    result = await strategy.execute_shutdown(
        running_tasks={"t1": straggler},
        cleanup_callbacks=[],
    )

    assert result.tasks_suspended == 0
    assert result.tasks_interrupted == 1
845+
830846
async def test_mixed_cooperative_and_straggler(self) -> None:
831847
saver_calls: list[str] = []
832848

0 commit comments

Comments
 (0)