66from collections import defaultdict , deque
77
88import dask
9- from tornado .ioloop import PeriodicCallback
9+ from tornado .ioloop import IOLoop , PeriodicCallback
1010
1111from distributed .utils_comm import retry_operation
12+
1213from .metrics import time
13- from .utils import log_errors , parse_timedelta
14- from .worker import get_client
14+ from .utils import log_errors , parse_timedelta , sync , thread_state
15+ from .worker import get_client , get_worker
1516
1617logger = logging .getLogger (__name__ )
1718
@@ -130,7 +131,7 @@ def _get_lease(self, name, lease_id):
130131 or len (self .leases [name ]) < self .max_leases [name ]
131132 ):
132133 now = time ()
133- logger .info ("Acquire lease %s for %s at %s" , lease_id , name , now )
134+ logger .debug ("Acquire lease %s for %s at %s" , lease_id , name , now )
134135 self .leases [name ][lease_id ] = now
135136 self .metrics ["acquire_total" ][name ] += 1
136137 else :
@@ -154,8 +155,8 @@ async def acquire(self, comm=None, name=None, timeout=None, lease_id=None):
154155
155156 self .metrics ["pending" ][name ] += 1
156157 while True :
157- logger .info (
158- "Trying to acquire %s for %s with %ss left." ,
158+ logger .debug (
159+ "Trying to acquire %s for %s with %s seconds left." ,
159160 lease_id ,
160161 name ,
161162 w .leftover (),
@@ -177,7 +178,7 @@ async def acquire(self, comm=None, name=None, timeout=None, lease_id=None):
177178 continue
178179 except TimeoutError :
179180 result = False
180- logger .info (
181+ logger .debug (
181182 "Acquisition of lease %s for %s is %s after waiting for %ss." ,
182183 lease_id ,
183184 name ,
@@ -210,7 +211,7 @@ def release(self, comm=None, name=None, lease_id=None):
210211 )
211212
212213 def _release_value (self , name , lease_id ):
213- logger .info ("Releasing %s for %s" , lease_id , name )
214+ logger .debug ("Releasing %s for %s" , lease_id , name )
214215 # Everything needs to be atomic here.
215216 del self .leases [name ][lease_id ]
216217 self .events [name ].set ()
@@ -230,7 +231,7 @@ def _check_lease_timeout(self):
230231 for _id in ids :
231232 time_since_refresh = now - self .leases [name ][_id ]
232233 if time_since_refresh > self .lease_timeout :
233- logger .info (
234+ logger .debug (
234235 "Lease %s for %s timed out after %ss." ,
235236 _id ,
236237 name ,
@@ -311,15 +312,19 @@ class Semaphore:
311312 Name of the semaphore to acquire. Choosing the same name allows two
312313 disconnected processes to coordinate. If not given, a random
313314 name will be generated.
314- client: Client (optional)
315- Client to use for communication with the scheduler. If not given, the
316- default global client will be used.
317315 register: bool
318316 If True, register the semaphore with the scheduler. This needs to be
319317 done before any leases can be acquired. If not done during
320318 initialization, this can also be done by calling the register method of
321319 this class.
322320 When registering, this needs to be awaited.
321+ scheduler_rpc: ConnectionPool
322+ The ConnectionPool to connect to the scheduler. If None is provided, it
323+ uses the worker or client pool. This paramter is mostly used for
324+ testing.
325+ loop: IOLoop
326+ The event loop this instance is using. If None is provided, reuse the
327+ loop of the active worker or client.
323328
324329 Examples
325330 --------
@@ -355,8 +360,25 @@ class Semaphore:
355360
356361 """
357362
358- def __init__ (self , max_leases = 1 , name = None , client = None , register = True ):
359- self .client = client or get_client ()
363+ def __init__ (
364+ self ,
365+ max_leases = 1 ,
366+ name = None ,
367+ register = True ,
368+ scheduler_rpc = None ,
369+ loop = None ,
370+ ):
371+
372+ try :
373+ worker = get_worker ()
374+ self .scheduler = scheduler_rpc or worker .scheduler
375+ self .loop = loop or worker .loop
376+
377+ except ValueError :
378+ client = get_client ()
379+ self .scheduler = scheduler_rpc or client .scheduler
380+ self .loop = loop or client .io_loop
381+
360382 self .name = name or "semaphore-" + uuid .uuid4 ().hex
361383 self .max_leases = max_leases
362384 self .id = uuid .uuid4 ().hex
@@ -381,27 +403,25 @@ def __init__(self, max_leases=1, name=None, client=None, register=True):
381403 self ._refresh_leases , callback_time = refresh_leases_interval * 1000
382404 )
383405 self .refresh_callback = pc
384- # Registering the pc to the client here is important for proper cleanup
385- self ._periodic_callback_name = f"refresh_semaphores_{ self .id } "
386- self .client ._periodic_callbacks [self ._periodic_callback_name ] = pc
387406
388407 # Need to start the callback using IOLoop.add_callback to ensure that the
389408 # PC uses the correct event loop.
390- self .client . io_loop .add_callback (pc .start )
409+ self .loop .add_callback (pc .start )
391410
392- def register ( self ):
393- """
394- Register the semaphore on scheduler side
411+ @ property
412+ def asynchronous ( self ):
413+ return self . loop is IOLoop . current ()
395414
396- This will register the semaphore on scheduler side and ensure that all necessary data structures exist.
397- """
398- if self ._registered is None :
399- self ._registered = self .client .sync (
400- self .client .scheduler .semaphore_register ,
401- name = self .name ,
402- max_leases = self .max_leases ,
403- )
404- return self ._registered
415+ async def _register (self ):
416+ await retry_operation (
417+ self .scheduler .semaphore_register ,
418+ name = self .name ,
419+ max_leases = self .max_leases ,
420+ operation = f"semaphore register id={ self .id } name={ self .name } " ,
421+ )
422+
423+ def register (self , ** kwargs ):
424+ return self .sync (self ._register )
405425
406426 def __await__ (self ):
407427 async def create_semaphore ():
@@ -411,34 +431,53 @@ async def create_semaphore():
411431
412432 return create_semaphore ().__await__ ()
413433
434+ def sync (self , func , * args , asynchronous = None , callback_timeout = None , ** kwargs ):
435+ callback_timeout = parse_timedelta (callback_timeout )
436+ if (
437+ asynchronous
438+ or self .asynchronous
439+ or getattr (thread_state , "asynchronous" , False )
440+ ):
441+ future = func (* args , ** kwargs )
442+ if callback_timeout is not None :
443+ future = asyncio .wait_for (future , callback_timeout )
444+ return future
445+ else :
446+ return sync (
447+ self .loop , func , * args , callback_timeout = callback_timeout , ** kwargs
448+ )
449+
414450 async def _refresh_leases (self ):
415451 if self .refresh_leases and self ._leases :
416452 logger .debug (
417453 "%s refreshing leases for %s with IDs %s" ,
418- self .client . id ,
454+ self .id ,
419455 self .name ,
420456 self ._leases ,
421457 )
422- await self .client .scheduler .semaphore_refresh_leases (
423- lease_ids = list (self ._leases ), name = self .name
458+ await retry_operation (
459+ self .scheduler .semaphore_refresh_leases ,
460+ lease_ids = list (self ._leases ),
461+ name = self .name ,
462+ operation = "semaphore refresh leases: id=%s, lease_ids=%s, name=%s"
463+ % (self .id , list (self ._leases ), self .name ),
424464 )
425465
426466 async def _acquire (self , timeout = None ):
427467 lease_id = uuid .uuid4 ().hex
428- logger .info (
429- "%s requests lease for %s with ID %s" , self .client . id , self .name , lease_id
468+ logger .debug (
469+ "%s requests lease for %s with ID %s" , self .id , self .name , lease_id
430470 )
431471
432472 # Using a unique lease id generated here allows us to retry since the
433473 # server handle is idempotent
434-
435474 result = await retry_operation (
436- self .client . scheduler .semaphore_acquire ,
475+ self .scheduler .semaphore_acquire ,
437476 name = self .name ,
438477 timeout = timeout ,
439478 lease_id = lease_id ,
440- operation = "semaphore acquire: client =%s, lease_id=%s, name=%s"
441- % (self .client . id , lease_id , self .name ),
479+ operation = "semaphore acquire: id =%s, lease_id=%s, name=%s"
480+ % (self .id , lease_id , self .name ),
442481 )
443482 if result :
444483 self ._leases .append (lease_id )
@@ -460,26 +499,22 @@ def acquire(self, timeout=None):
460499 a timedelta in string format, e.g. "200ms".
461500 """
462501 timeout = parse_timedelta (timeout )
463- return self .client .sync (self ._acquire , timeout = timeout )
464-
465- async def _release (self ):
466- # popleft to release the oldest lease first
467- lease_id = self ._leases .popleft ()
468- logger .info ("%s releases %s for %s" , self .client .id , lease_id , self .name )
502+ return self .sync (self ._acquire , timeout = timeout )
469503
504+ async def _release (self , lease_id ):
470505 try :
471506 await retry_operation (
472- self .client . scheduler .semaphore_release ,
507+ self .scheduler .semaphore_release ,
473508 name = self .name ,
474509 lease_id = lease_id ,
475- operation = "semaphore release: client =%s, lease_id=%s, name=%s"
476- % (self .client . id , lease_id , self .name ),
510+ operation = "semaphore release: id =%s, lease_id=%s, name=%s"
511+ % (self .id , lease_id , self .name ),
477512 )
478513 return True
479514 except Exception : # Release fails for whatever reason
480515 logger .error (
481- "Release failed for client =%s, lease_id=%s, name=%s. Cluster network might be unstable?"
482- % (self .client . id , lease_id , self .name ),
516+ "Release failed for id =%s, lease_id=%s, name=%s. Cluster network might be unstable?"
517+ % (self .id , lease_id , self .name ),
483518 exc_info = True ,
484519 )
485520 return False
@@ -499,13 +534,16 @@ def release(self):
499534 if not self ._leases :
500535 raise RuntimeError ("Released too often" )
501536
502- return self .client .sync (self ._release )
537+ # popleft to release the oldest lease first
538+ lease_id = self ._leases .popleft ()
539+ logger .debug ("%s releases %s for %s" , self .id , lease_id , self .name )
540+ return self .sync (self ._release , lease_id = lease_id )
503541
504542 def get_value (self ):
505543 """
506544 Return the number of currently registered leases.
507545 """
508- return self .client . sync (self . client .scheduler .semaphore_value , name = self .name )
546+ return self .sync (self .scheduler .semaphore_value , name = self .name )
509547
510548 def __enter__ (self ):
511549 self .acquire ()
@@ -528,13 +566,14 @@ def __getstate__(self):
528566
529567 def __setstate__ (self , state ):
530568 name , max_leases = state
531- client = get_client ()
532- self .__init__ (name = name , client = client , max_leases = max_leases , register = False )
569+ self .__init__ (
570+ name = name ,
571+ max_leases = max_leases ,
572+ register = False ,
573+ )
533574
534575 def close (self ):
535- return self .client . sync (self . client .scheduler .semaphore_close , name = self .name )
576+ return self .sync (self .scheduler .semaphore_close , name = self .name )
536577
537578 def __del__ (self ):
538- if self ._periodic_callback_name in self .client ._periodic_callbacks :
539- self .client ._periodic_callbacks [self ._periodic_callback_name ].stop ()
540- del self .client ._periodic_callbacks [self ._periodic_callback_name ]
579+ self .refresh_callback .stop ()
0 commit comments