Skip to content

Commit cb08d50

Browse files
petrprikrylpetr.prikryl
andauthored
reliable prefork detection (#10023)
* reliable prefork detection * copilot feedback * better coverage --------- Co-authored-by: petr.prikryl <petr.prikryl@olc.cz>
1 parent cc3350e commit cb08d50

2 files changed

Lines changed: 53 additions & 6 deletions

File tree

celery/fixups/django.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from celery import _state, signals
1414
from celery.exceptions import FixupWarning, ImproperlyConfigured
15+
from celery.worker import WorkController
1516

1617
if TYPE_CHECKING:
1718
from types import ModuleType
@@ -102,6 +103,16 @@ def on_import_modules(self, **kwargs: Any) -> None:
102103
self.worker_fixup.validate_models()
103104

104105
def on_worker_init(self, **kwargs: Any) -> None:
106+
worker: Optional["WorkController"] = kwargs.get("sender")
107+
if worker:
108+
self.worker_fixup.worker = worker
109+
else:
110+
warnings.warn(
111+
"DjangoFixup.on_worker_init called without a sender (worker instance). "
112+
"This may indicate a misconfiguration or an internal error.",
113+
FixupWarning,
114+
stacklevel=2,
115+
)
105116
self.worker_fixup.install()
106117

107118
def now(self, utc: bool = False) -> datetime:
@@ -119,8 +130,9 @@ def _now(self) -> datetime:
119130
class DjangoWorkerFixup:
120131
_db_recycles = 0
121132

122-
def __init__(self, app: "Celery") -> None:
133+
def __init__(self, app: "Celery", worker: Optional["WorkController"] = None) -> None:
123134
self.app = app
135+
self.worker = worker or WorkController(app)
124136
self.db_reuse_max = self.app.conf.get('CELERY_DB_REUSE_MAX', None)
125137
self._db = cast("DjangoDBModule", import_module('django.db'))
126138
self._cache = import_module('django.core.cache')
@@ -205,7 +217,7 @@ def _close_database(self) -> None:
205217
# Support Django < 4.1
206218
connections = self._db.connections.all()
207219

208-
is_prefork = self.app.conf.get('worker_pool', 'prefork') == "prefork"
220+
is_prefork = "prefork" in self.worker.pool_cls.__module__
209221

210222
for conn in connections:
211223
try:

t/unit/fixups/test_django.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55

6+
from celery.concurrency.thread import TaskPool as ThreadTaskPool
67
from celery.fixups.django import DjangoFixup, DjangoWorkerFixup, FixupWarning, _maybe_close_fd, fixup
78
from t.unit import conftest
89

@@ -11,11 +12,11 @@ class FixupCase:
1112
Fixup = None
1213

1314
@contextmanager
14-
def fixup_context(self, app):
15+
def fixup_context(self, app, **kwargs):
1516
with patch('celery.fixups.django.DjangoWorkerFixup.validate_models'):
1617
with patch('celery.fixups.django.symbol_by_name') as symbyname:
1718
with patch('celery.fixups.django.import_module') as impmod:
18-
f = self.Fixup(app)
19+
f = self.Fixup(app, **kwargs)
1920
yield f, impmod, symbyname
2021

2122

@@ -150,11 +151,20 @@ def test_now(self):
150151
def test_on_worker_init(self):
151152
with self.fixup_context(self.app) as (f, _, _):
152153
with patch('celery.fixups.django.DjangoWorkerFixup') as DWF:
153-
f.on_worker_init()
154+
mock_worker = Mock(name="worker")
155+
f.on_worker_init(sender=mock_worker)
156+
assert DWF.return_value.worker == mock_worker
157+
154158
DWF.assert_called_with(f.app)
155159
DWF.return_value.install.assert_called_with()
156160
assert f._worker_fixup is DWF.return_value
157161

162+
def test_on_worker_init_warns_without_sender(self):
163+
with self.fixup_context(self.app) as (f, _, _):
164+
with patch("celery.fixups.django.DjangoWorkerFixup"):
165+
with pytest.warns(FixupWarning, match="called without a sender"):
166+
f.on_worker_init(sender=None)
167+
158168

159169
class InterfaceError(Exception):
160170
pass
@@ -168,9 +178,11 @@ def test_init(self):
168178
assert f
169179

170180
def test_install(self):
181+
worker = Mock()
182+
worker.pool_cls = Mock(__module__='celery.concurrency.prefork')
171183
self.app.conf = {'CELERY_DB_REUSE_MAX': None}
172184
self.app.loader = Mock()
173-
with self.fixup_context(self.app) as (f, _, _):
185+
with self.fixup_context(self.app, worker=worker) as (f, _, _):
174186
with patch('celery.fixups.django.signals') as sigs:
175187
f.install()
176188
sigs.beat_embedded_init.connect.assert_called_with(
@@ -336,6 +348,29 @@ class DJSettings:
336348
conn.close.assert_called_once_with()
337349
conn.close_pool.assert_not_called()
338350

351+
def test_close_database_conn_pool_thread_pool(self):
352+
class DJSettings:
353+
DATABASES = {}
354+
355+
with self.fixup_context(self.app) as (f, _, _):
356+
conn = Mock()
357+
conn.alias = "default"
358+
conn.close_pool = Mock()
359+
f._db.connections.all = Mock(return_value=[conn])
360+
f._settings = DJSettings
361+
362+
f._settings.DATABASES["default"] = {"OPTIONS": {"pool": True}}
363+
f.close_database()
364+
conn.close.assert_called_once_with()
365+
conn.close_pool.assert_called_once_with()
366+
367+
conn.reset_mock()
368+
f.worker.pool_cls = ThreadTaskPool
369+
assert "prefork" not in ThreadTaskPool.__module__
370+
f.close_database()
371+
conn.close.assert_called_once_with()
372+
conn.close_pool.assert_not_called()
373+
339374
def test_close_cache_raises_error(self):
340375
with self.fixup_context(self.app) as (f, _, _):
341376
f._cache.close_caches.side_effect = AttributeError

0 commit comments

Comments
 (0)