33
44import pytest
55
6+ from celery .concurrency .thread import TaskPool as ThreadTaskPool
67from celery .fixups .django import DjangoFixup , DjangoWorkerFixup , FixupWarning , _maybe_close_fd , fixup
78from t .unit import conftest
89
@@ -11,11 +12,11 @@ class FixupCase:
1112 Fixup = None
1213
1314 @contextmanager
14- def fixup_context (self , app ):
15+ def fixup_context (self , app , ** kwargs ):
1516 with patch ('celery.fixups.django.DjangoWorkerFixup.validate_models' ):
1617 with patch ('celery.fixups.django.symbol_by_name' ) as symbyname :
1718 with patch ('celery.fixups.django.import_module' ) as impmod :
18- f = self .Fixup (app )
19+ f = self .Fixup (app , ** kwargs )
1920 yield f , impmod , symbyname
2021
2122
@@ -150,11 +151,20 @@ def test_now(self):
150151 def test_on_worker_init (self ):
151152 with self .fixup_context (self .app ) as (f , _ , _ ):
152153 with patch ('celery.fixups.django.DjangoWorkerFixup' ) as DWF :
153- f .on_worker_init ()
154+ mock_worker = Mock (name = "worker" )
155+ f .on_worker_init (sender = mock_worker )
156+ assert DWF .return_value .worker == mock_worker
157+
154158 DWF .assert_called_with (f .app )
155159 DWF .return_value .install .assert_called_with ()
156160 assert f ._worker_fixup is DWF .return_value
157161
162+ def test_on_worker_init_warns_without_sender (self ):
163+ with self .fixup_context (self .app ) as (f , _ , _ ):
164+ with patch ("celery.fixups.django.DjangoWorkerFixup" ):
165+ with pytest .warns (FixupWarning , match = "called without a sender" ):
166+ f .on_worker_init (sender = None )
167+
158168
159169class InterfaceError (Exception ):
160170 pass
@@ -168,9 +178,11 @@ def test_init(self):
168178 assert f
169179
170180 def test_install (self ):
181+ worker = Mock ()
182+ worker .pool_cls = Mock (__module__ = 'celery.concurrency.prefork' )
171183 self .app .conf = {'CELERY_DB_REUSE_MAX' : None }
172184 self .app .loader = Mock ()
173- with self .fixup_context (self .app ) as (f , _ , _ ):
185+ with self .fixup_context (self .app , worker = worker ) as (f , _ , _ ):
174186 with patch ('celery.fixups.django.signals' ) as sigs :
175187 f .install ()
176188 sigs .beat_embedded_init .connect .assert_called_with (
@@ -336,6 +348,29 @@ class DJSettings:
336348 conn .close .assert_called_once_with ()
337349 conn .close_pool .assert_not_called ()
338350
351+ def test_close_database_conn_pool_thread_pool (self ):
352+ class DJSettings :
353+ DATABASES = {}
354+
355+ with self .fixup_context (self .app ) as (f , _ , _ ):
356+ conn = Mock ()
357+ conn .alias = "default"
358+ conn .close_pool = Mock ()
359+ f ._db .connections .all = Mock (return_value = [conn ])
360+ f ._settings = DJSettings
361+
362+ f ._settings .DATABASES ["default" ] = {"OPTIONS" : {"pool" : True }}
363+ f .close_database ()
364+ conn .close .assert_called_once_with ()
365+ conn .close_pool .assert_called_once_with ()
366+
367+ conn .reset_mock ()
368+ f .worker .pool_cls = ThreadTaskPool
369+ assert "prefork" not in ThreadTaskPool .__module__
370+ f .close_database ()
371+ conn .close .assert_called_once_with ()
372+ conn .close_pool .assert_not_called ()
373+
339374 def test_close_cache_raises_error (self ):
340375 with self .fixup_context (self .app ) as (f , _ , _ ):
341376 f ._cache .close_caches .side_effect = AttributeError
0 commit comments