Skip to content

Commit aecd46a

Browse files
chrisguidry and claude authored
Add internal state invariant tests and reorganize test directories (#310)
Adds tests that verify Worker and StrikeList internal data structures are properly cleaned up after task completion and context exit. These tests would have caught the memory leak fixed in #309, where `_tasks_by_key` entries accumulated because we were using `task.get_name()` instead of `execution.key` for cleanup.

Worker invariant tests check:
- `_tasks_by_key` is empty after `run_until_finished()`
- `_tasks_by_key` doesn't grow across multiple batches
- `_execution_counts` is cleared after `run_at_most()`
- Internal attributes are deleted after `__aexit__`
- Cleanup works with failing tasks and varied task types
- SharedContext ContextVars are properly reset

StrikeList invariant tests check:
- `_conditions` only contains default after removing temp conditions
- No empty dicts remain in `task_strikes`/`parameter_strikes` after restore

Also reorganizes tests into subdirectories:
- `tests/worker/` for worker-related tests
- `tests/instrumentation/` for metrics and tracing tests
- Renames `test_synced_strikelist.py` to `test_strikelist.py`

Adds `--import-mode=importlib` to pytest config to support same-named test files in different directories.

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c5a2b6d commit aecd46a

11 files changed

Lines changed: 357 additions & 1 deletion

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ addopts = [
114114
"--cov-report=term-missing:skip-covered",
115115
"--cov-branch",
116116
"--timeout=30",
117+
"--import-mode=importlib",
117118
]
118119
asyncio_mode = "auto"
119120
asyncio_default_fixture_loop_scope = "function"
Lines changed: 155 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
"""Tests for StrikeList."""
22

3+
from __future__ import annotations
4+
35
# pyright: reportPrivateUsage=false
46

57
import asyncio
6-
from typing import Any, Callable
8+
from typing import TYPE_CHECKING, Any, Callable
79

810
import pytest
911

1012
from docket import StrikeList
1113
from docket.strikelist import Operator, Restore, Strike
1214

15+
if TYPE_CHECKING:
16+
from docket.execution import Execution
17+
1318

1419
@pytest.fixture
1520
def strike_name(make_docket_name: Callable[[], str]) -> str:
@@ -312,3 +317,152 @@ async def test_type_mismatch_handled_gracefully(
312317
result = strikes.is_stricken({"amount": "not a number"})
313318
assert result is False
314319
assert "Incompatible type" in caplog.text
320+
321+
322+
# Internal state invariant tests
323+
324+
325+
async def test_invariant_conditions_only_default_after_remove(
    redis_url: str, strike_name: str
):
    """Adding and then removing a condition leaves just the original default one."""
    async with StrikeList(url=redis_url, name=strike_name) as strikes:
        # On entry the list holds exactly one condition: the default.
        assert len(strikes._conditions) == 1
        baseline = strikes._conditions[0]

        # A lambda (rather than a def) is used so that there is no never-called
        # function body to register as a coverage gap.
        never_matches: Callable[[Execution], bool] = lambda _: False  # noqa: E731

        strikes.add_condition(never_matches)
        assert len(strikes._conditions) == 2

        # Take the temporary condition back out again.
        strikes.remove_condition(never_matches)

        # The list is back to one entry, and it is the very same default object.
        assert len(strikes._conditions) == 1
        assert strikes._conditions[0] is baseline
346+
347+
348+
async def test_invariant_no_empty_dicts_in_task_strikes_after_restore(
    redis_url: str, strike_name: str
):
    """Restoring a task's only strike must remove the task's entry altogether."""
    async with StrikeList(url=redis_url, name=strike_name) as strikes:
        # Strike one specific task + parameter pairing.
        await strikes.strike(
            function="my_task", parameter="user_id", operator="==", value=123
        )
        # Short pause before inspecting local state (presumably gives the
        # strike time to be reflected in this StrikeList -- TODO confirm).
        await asyncio.sleep(0.1)

        # The nested task -> parameter entry is now present.
        assert "my_task" in strikes.task_strikes
        assert "user_id" in strikes.task_strikes["my_task"]

        # Undo exactly that strike.
        await strikes.restore(
            function="my_task", parameter="user_id", operator="==", value=123
        )
        await asyncio.sleep(0.1)

        # No hollowed-out entry may be left behind for the task.
        assert "my_task" not in strikes.task_strikes, (
            "task_strikes should not contain empty task entries"
        )
373+
374+
375+
async def test_invariant_no_empty_dicts_in_parameter_strikes_after_restore(
    redis_url: str, strike_name: str
):
    """Restoring a parameter's only strike must remove the parameter's entry."""
    async with StrikeList(url=redis_url, name=strike_name) as strikes:
        # No function is given, so this strike is keyed by parameter alone.
        await strikes.strike(parameter="region", operator="==", value="us-west")
        # Short pause before inspecting local state (presumably gives the
        # strike time to be reflected in this StrikeList -- TODO confirm).
        await asyncio.sleep(0.1)

        # Exactly one strike is recorded under the parameter.
        assert "region" in strikes.parameter_strikes
        assert len(strikes.parameter_strikes["region"]) == 1

        # Undo exactly that strike.
        await strikes.restore(parameter="region", operator="==", value="us-west")
        await asyncio.sleep(0.1)

        # The now-empty container must be gone, not merely emptied.
        assert "region" not in strikes.parameter_strikes, (
            "parameter_strikes should not contain empty parameter entries"
        )
396+
397+
398+
async def test_invariant_multiple_strike_restore_cycles(
    redis_url: str, strike_name: str
):
    """Repeated strike/restore rounds must leave no orphaned entries behind."""
    async with StrikeList(url=redis_url, name=strike_name) as strikes:
        for cycle in range(5):
            # Each round uses its own values so rounds cannot cancel each other.
            customer = f"id-{cycle}"
            threshold = 100 + cycle

            # Apply one task-scoped strike and one parameter-only strike.
            await strikes.strike(
                function="task_a",
                parameter="customer_id",
                operator="==",
                value=customer,
            )
            await strikes.strike(
                parameter="global_param", operator=">=", value=threshold
            )
            await asyncio.sleep(0.05)

            # Both strikes are visible in their respective structures.
            assert "task_a" in strikes.task_strikes
            assert "global_param" in strikes.parameter_strikes

            # Undo both strikes with the same arguments.
            await strikes.restore(
                function="task_a",
                parameter="customer_id",
                operator="==",
                value=customer,
            )
            await strikes.restore(
                parameter="global_param", operator=">=", value=threshold
            )
            await asyncio.sleep(0.05)

            # Checked every round so a leak is pinned to the cycle that caused it.
            assert "task_a" not in strikes.task_strikes, (
                f"Orphan in task_strikes after cycle {cycle}"
            )
            assert "global_param" not in strikes.parameter_strikes, (
                f"Orphan in parameter_strikes after cycle {cycle}"
            )
439+
440+
441+
async def test_invariant_strikelist_state_persists_through_context(
    redis_url: str, strike_name: str
):
    """StrikeList data structures should persist (not be deleted) after context exit.

    The strike-tracking attributes must exist before entering the context and
    still exist after leaving it, while ``_monitor_task`` and
    ``_strikes_loaded`` are expected to be ``None`` again after exit.
    """
    strikes = StrikeList(url=redis_url, name=strike_name)
    persistent_attrs = ("task_strikes", "parameter_strikes", "_conditions")

    # Data structures exist before entering the context.
    for attr in persistent_attrs:
        assert hasattr(strikes, attr), f"{attr} missing before entering context"

    # `async with` (instead of manual __aenter__/__aexit__ calls) guarantees the
    # StrikeList is exited even if strike() raises, so a failure here cannot
    # leave the entered context dangling for the rest of the test session.
    async with strikes:
        # Add some data.
        await strikes.strike(parameter="test_param", operator="==", value="test_value")
        await asyncio.sleep(0.1)

    # Data structures should still exist (not deleted like Worker).
    # This is intentional - StrikeList maintains state.
    for attr in persistent_attrs:
        assert hasattr(strikes, attr), f"{attr} missing after context exit"

    # But connection-related state should be cleaned up.
    assert strikes._monitor_task is None
    assert strikes._strikes_loaded is None

tests/worker/test_invariants.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
"""Tests for worker internal state invariants.
2+
3+
These tests verify that internal data structures are properly cleaned up after
4+
task completion and context exit, catching memory leaks early in CI.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import asyncio
10+
import logging
11+
from datetime import timedelta
12+
13+
import pytest
14+
15+
from docket import CurrentExecution, Docket, Perpetual, TaskLogger, Timeout, Worker
16+
from docket.dependencies import SharedContext
17+
from docket.execution import Execution
18+
19+
20+
async def test_invariant_tasks_by_key_empty_after_completion(docket: Docket):
    """Once run_until_finished returns, the worker's _tasks_by_key holds nothing."""

    async def simple_task():
        pass

    docket.register(simple_task)

    async with Worker(docket, concurrency=4) as worker:
        # Queue 50 no-op tasks, one at a time.
        for _ in range(50):
            await docket.add(simple_task)()

        await worker.run_until_finished()

        # Every task has finished, so no entry should be left behind.
        leftover = len(worker._tasks_by_key)  # type: ignore[protected-access]
        assert leftover == 0
36+
37+
38+
async def test_invariant_tasks_by_key_no_growth_over_batches(docket: Docket):
    """Back-to-back batches on one worker must not pile up _tasks_by_key entries."""

    async def simple_task():
        pass

    docket.register(simple_task)

    async with Worker(docket, concurrency=4) as worker:
        for batch in range(5):
            # Queue 20 no-op tasks, then drain them with the same worker.
            for _ in range(20):
                await docket.add(simple_task)()
            await worker.run_until_finished()

            # Checked per batch so a leak is pinned to the batch that caused it.
            leftover = len(worker._tasks_by_key)  # type: ignore[protected-access]
            assert leftover == 0, f"Leak after batch {batch}"
54+
55+
56+
async def test_invariant_execution_counts_empty_after_completion(docket: Docket):
    """A plain run_until_finished (no run_at_most) leaves _execution_counts empty."""

    async def simple_task():
        pass

    docket.register(simple_task)

    async with Worker(docket, concurrency=4) as worker:
        # Queue 10 no-op tasks, one at a time.
        for _ in range(10):
            await docket.add(simple_task)()

        await worker.run_until_finished()

        # Nothing should be recorded after an ordinary run.
        leftover = len(worker._execution_counts)  # type: ignore[protected-access]
        assert leftover == 0
72+
73+
74+
async def test_invariant_execution_counts_cleared_after_run_at_most(docket: Docket):
    """Once run_at_most returns, the worker's _execution_counts holds nothing."""
    runs = 0

    async def perpetual_task(
        perpetual: Perpetual = Perpetual(every=timedelta(milliseconds=10)),
    ):
        nonlocal runs
        runs += 1

    docket.register(perpetual_task)
    await docket.add(perpetual_task, key="test-perpetual")()

    async with Worker(docket, concurrency=1) as worker:
        # Cap the task with key "test-perpetual" at three executions.
        await worker.run_at_most({"test-perpetual": 3})

        # run_at_most is expected to clear _execution_counts in its finally
        # block, so no counters should remain once it has returned.
        leftover = len(worker._execution_counts)  # type: ignore[protected-access]
        assert leftover == 0

    # The task body ran exactly as many times as the cap allowed.
    assert runs == 3
94+
95+
96+
async def test_invariant_worker_attributes_deleted_after_exit(docket: Docket):
    """Worker internal attributes should be deleted after context exit.

    The per-run attributes must be present while the worker context is active
    and absent (along with ``_stack``) once the context has been exited.
    """
    worker = Worker(docket)
    in_context_attrs = (
        "_tasks_by_key",
        "_execution_counts",
        "_worker_stopping",
        "_worker_done",
        "_cancellation_ready",
        "_heartbeat_task",
        "_shared_context",
    )

    # `async with` (instead of manual __aenter__/__aexit__ calls) guarantees the
    # worker is exited even when one of the in-context assertions fails, so a
    # failing test cannot leave an entered worker behind.
    async with worker:
        # Attributes exist during context.
        for attr in in_context_attrs:
            assert hasattr(worker, attr), f"{attr} missing during context"

    # Attributes cleaned up after exit; _stack is only checked here, as in the
    # original test.
    for attr in (*in_context_attrs, "_stack"):
        assert not hasattr(worker, attr), f"{attr} still present after exit"
121+
122+
123+
async def test_invariant_cleanup_after_task_exceptions(docket: Docket):
    """Tasks that raise must still have their _tasks_by_key entries removed."""

    async def failing_task():
        raise ValueError("intentional failure")

    docket.register(failing_task)

    async with Worker(docket, concurrency=4) as worker:
        # Queue 10 tasks that each raise as soon as they run.
        for _ in range(10):
            await docket.add(failing_task)()

        await worker.run_until_finished()

        # Failure must not change cleanup: nothing may be left behind.
        leftover = len(worker._tasks_by_key)  # type: ignore[protected-access]
        assert leftover == 0
139+
140+
141+
async def test_invariant_cleanup_with_varied_tasks(docket: Docket):
    """_tasks_by_key is emptied whatever the task shape: deps, timeouts, returns, kwargs."""

    async def simple_task():
        pass

    async def task_with_deps(
        execution: Execution = CurrentExecution(),
        logger: logging.LoggerAdapter[logging.Logger] = TaskLogger(),
    ):
        logger.info(f"Running {execution.key}")

    async def task_with_timeout(
        timeout: Timeout = Timeout(timedelta(seconds=5)),
    ):
        await asyncio.sleep(0.01)

    async def task_with_return() -> str:
        return "result"

    async def task_with_kwargs(a: int, b: str = "default"):
        pass

    # Register every flavour up front, in the same order they are scheduled.
    every_flavour = (
        simple_task,
        task_with_deps,
        task_with_timeout,
        task_with_return,
        task_with_kwargs,
    )
    for task in every_flavour:
        docket.register(task)

    async with Worker(docket, concurrency=4) as worker:
        # Five rounds, each scheduling one task of every flavour.
        for _ in range(5):
            await docket.add(simple_task)()
            await docket.add(task_with_deps)()
            await docket.add(task_with_timeout)()
            await docket.add(task_with_return)()
            await docket.add(task_with_kwargs)(a=1, b="test")

        await worker.run_until_finished()

        # Cleanup must not depend on what kind of task ran.
        leftover = len(worker._tasks_by_key)  # type: ignore[protected-access]
        assert leftover == 0
186+
187+
188+
async def test_invariant_shared_context_reset_after_worker_exit(docket: Docket):
    """SharedContext ContextVars should be reset after worker exits.

    While the worker context is active ``SharedContext.resolved`` must hold a
    dict; after exit, reading it must raise ``LookupError`` again.
    """
    worker = Worker(docket)

    # `async with` (instead of manual __aenter__/__aexit__ calls) guarantees the
    # worker is exited even if the in-context assertion fails, so a failure here
    # cannot leave the ContextVar set for whatever runs next.
    async with worker:
        # During context, resolved should have a value.
        resolved = SharedContext.resolved.get()
        assert isinstance(resolved, dict)

    # After exit, resolved should be reset (LookupError since no value was set before).
    with pytest.raises(LookupError):
        SharedContext.resolved.get()

0 commit comments

Comments
 (0)