Skip to content

Commit e736c0b

Browse files
authored
Use worker comm pool in Semaphore (#4195)
* Semaphore uses worker comm pool * Switch semaphore logging to debug level * Align usage of loop and scheduler attribute names in Semaphore
1 parent 9442d9b commit e736c0b

2 files changed

Lines changed: 134 additions & 88 deletions

File tree

distributed/semaphore.py

Lines changed: 97 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
from collections import defaultdict, deque
77

88
import dask
9-
from tornado.ioloop import PeriodicCallback
9+
from tornado.ioloop import IOLoop, PeriodicCallback
1010

1111
from distributed.utils_comm import retry_operation
12+
1213
from .metrics import time
13-
from .utils import log_errors, parse_timedelta
14-
from .worker import get_client
14+
from .utils import log_errors, parse_timedelta, sync, thread_state
15+
from .worker import get_client, get_worker
1516

1617
logger = logging.getLogger(__name__)
1718

@@ -130,7 +131,7 @@ def _get_lease(self, name, lease_id):
130131
or len(self.leases[name]) < self.max_leases[name]
131132
):
132133
now = time()
133-
logger.info("Acquire lease %s for %s at %s", lease_id, name, now)
134+
logger.debug("Acquire lease %s for %s at %s", lease_id, name, now)
134135
self.leases[name][lease_id] = now
135136
self.metrics["acquire_total"][name] += 1
136137
else:
@@ -154,8 +155,8 @@ async def acquire(self, comm=None, name=None, timeout=None, lease_id=None):
154155

155156
self.metrics["pending"][name] += 1
156157
while True:
157-
logger.info(
158-
"Trying to acquire %s for %s with %ss left.",
158+
logger.debug(
159+
"Trying to acquire %s for %s with %s seconds left.",
159160
lease_id,
160161
name,
161162
w.leftover(),
@@ -177,7 +178,7 @@ async def acquire(self, comm=None, name=None, timeout=None, lease_id=None):
177178
continue
178179
except TimeoutError:
179180
result = False
180-
logger.info(
181+
logger.debug(
181182
"Acquisition of lease %s for %s is %s after waiting for %ss.",
182183
lease_id,
183184
name,
@@ -210,7 +211,7 @@ def release(self, comm=None, name=None, lease_id=None):
210211
)
211212

212213
def _release_value(self, name, lease_id):
213-
logger.info("Releasing %s for %s", lease_id, name)
214+
logger.debug("Releasing %s for %s", lease_id, name)
214215
# Everything needs to be atomic here.
215216
del self.leases[name][lease_id]
216217
self.events[name].set()
@@ -230,7 +231,7 @@ def _check_lease_timeout(self):
230231
for _id in ids:
231232
time_since_refresh = now - self.leases[name][_id]
232233
if time_since_refresh > self.lease_timeout:
233-
logger.info(
234+
logger.debug(
234235
"Lease %s for %s timed out after %ss.",
235236
_id,
236237
name,
@@ -311,15 +312,19 @@ class Semaphore:
311312
Name of the semaphore to acquire. Choosing the same name allows two
312313
disconnected processes to coordinate. If not given, a random
313314
name will be generated.
314-
client: Client (optional)
315-
Client to use for communication with the scheduler. If not given, the
316-
default global client will be used.
317315
register: bool
318316
If True, register the semaphore with the scheduler. This needs to be
319317
done before any leases can be acquired. If not done during
320318
initialization, this can also be done by calling the register method of
321319
this class.
322320
When registering, this needs to be awaited.
321+
scheduler_rpc: ConnectionPool
322+
The ConnectionPool to connect to the scheduler. If None is provided, it
323+
uses the worker or client pool. This paramter is mostly used for
324+
testing.
325+
loop: IOLoop
326+
The event loop this instance is using. If None is provided, reuse the
327+
loop of the active worker or client.
323328
324329
Examples
325330
--------
@@ -355,8 +360,25 @@ class Semaphore:
355360
356361
"""
357362

358-
def __init__(self, max_leases=1, name=None, client=None, register=True):
359-
self.client = client or get_client()
363+
def __init__(
364+
self,
365+
max_leases=1,
366+
name=None,
367+
register=True,
368+
scheduler_rpc=None,
369+
loop=None,
370+
):
371+
372+
try:
373+
worker = get_worker()
374+
self.scheduler = scheduler_rpc or worker.scheduler
375+
self.loop = loop or worker.loop
376+
377+
except ValueError:
378+
client = get_client()
379+
self.scheduler = scheduler_rpc or client.scheduler
380+
self.loop = loop or client.io_loop
381+
360382
self.name = name or "semaphore-" + uuid.uuid4().hex
361383
self.max_leases = max_leases
362384
self.id = uuid.uuid4().hex
@@ -381,27 +403,25 @@ def __init__(self, max_leases=1, name=None, client=None, register=True):
381403
self._refresh_leases, callback_time=refresh_leases_interval * 1000
382404
)
383405
self.refresh_callback = pc
384-
# Registering the pc to the client here is important for proper cleanup
385-
self._periodic_callback_name = f"refresh_semaphores_{self.id}"
386-
self.client._periodic_callbacks[self._periodic_callback_name] = pc
387406

388407
# Need to start the callback using IOLoop.add_callback to ensure that the
389408
# PC uses the correct event loop.
390-
self.client.io_loop.add_callback(pc.start)
409+
self.loop.add_callback(pc.start)
391410

392-
def register(self):
393-
"""
394-
Register the semaphore on scheduler side
411+
@property
412+
def asynchronous(self):
413+
return self.loop is IOLoop.current()
395414

396-
This will register the semaphore on scheduler side and ensure that all necessary data structures exist.
397-
"""
398-
if self._registered is None:
399-
self._registered = self.client.sync(
400-
self.client.scheduler.semaphore_register,
401-
name=self.name,
402-
max_leases=self.max_leases,
403-
)
404-
return self._registered
415+
async def _register(self):
416+
await retry_operation(
417+
self.scheduler.semaphore_register,
418+
name=self.name,
419+
max_leases=self.max_leases,
420+
operation=f"semaphore register id={self.id} name={self.name}",
421+
)
422+
423+
def register(self, **kwargs):
424+
return self.sync(self._register)
405425

406426
def __await__(self):
407427
async def create_semaphore():
@@ -411,34 +431,53 @@ async def create_semaphore():
411431

412432
return create_semaphore().__await__()
413433

434+
def sync(self, func, *args, asynchronous=None, callback_timeout=None, **kwargs):
435+
callback_timeout = parse_timedelta(callback_timeout)
436+
if (
437+
asynchronous
438+
or self.asynchronous
439+
or getattr(thread_state, "asynchronous", False)
440+
):
441+
future = func(*args, **kwargs)
442+
if callback_timeout is not None:
443+
future = asyncio.wait_for(future, callback_timeout)
444+
return future
445+
else:
446+
return sync(
447+
self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
448+
)
449+
414450
async def _refresh_leases(self):
415451
if self.refresh_leases and self._leases:
416452
logger.debug(
417453
"%s refreshing leases for %s with IDs %s",
418-
self.client.id,
454+
self.id,
419455
self.name,
420456
self._leases,
421457
)
422-
await self.client.scheduler.semaphore_refresh_leases(
423-
lease_ids=list(self._leases), name=self.name
458+
await retry_operation(
459+
self.scheduler.semaphore_refresh_leases,
460+
lease_ids=list(self._leases),
461+
name=self.name,
462+
operation="semaphore refresh leases: id=%s, lease_ids=%s, name=%s"
463+
% (self.id, list(self._leases), self.name),
424464
)
425465

426466
async def _acquire(self, timeout=None):
427467
lease_id = uuid.uuid4().hex
428-
logger.info(
429-
"%s requests lease for %s with ID %s", self.client.id, self.name, lease_id
468+
logger.debug(
469+
"%s requests lease for %s with ID %s", self.id, self.name, lease_id
430470
)
431471

432472
# Using a unique lease id generated here allows us to retry since the
433473
# server handle is idempotent
434-
435474
result = await retry_operation(
436-
self.client.scheduler.semaphore_acquire,
475+
self.scheduler.semaphore_acquire,
437476
name=self.name,
438477
timeout=timeout,
439478
lease_id=lease_id,
440-
operation="semaphore acquire: client=%s, lease_id=%s, name=%s"
441-
% (self.client.id, lease_id, self.name),
479+
operation="semaphore acquire: id=%s, lease_id=%s, name=%s"
480+
% (self.id, lease_id, self.name),
442481
)
443482
if result:
444483
self._leases.append(lease_id)
@@ -460,26 +499,22 @@ def acquire(self, timeout=None):
460499
a timedelta in string format, e.g. "200ms".
461500
"""
462501
timeout = parse_timedelta(timeout)
463-
return self.client.sync(self._acquire, timeout=timeout)
464-
465-
async def _release(self):
466-
# popleft to release the oldest lease first
467-
lease_id = self._leases.popleft()
468-
logger.info("%s releases %s for %s", self.client.id, lease_id, self.name)
502+
return self.sync(self._acquire, timeout=timeout)
469503

504+
async def _release(self, lease_id):
470505
try:
471506
await retry_operation(
472-
self.client.scheduler.semaphore_release,
507+
self.scheduler.semaphore_release,
473508
name=self.name,
474509
lease_id=lease_id,
475-
operation="semaphore release: client=%s, lease_id=%s, name=%s"
476-
% (self.client.id, lease_id, self.name),
510+
operation="semaphore release: id=%s, lease_id=%s, name=%s"
511+
% (self.id, lease_id, self.name),
477512
)
478513
return True
479514
except Exception: # Release fails for whatever reason
480515
logger.error(
481-
"Release failed for client=%s, lease_id=%s, name=%s. Cluster network might be unstable?"
482-
% (self.client.id, lease_id, self.name),
516+
"Release failed for id=%s, lease_id=%s, name=%s. Cluster network might be unstable?"
517+
% (self.id, lease_id, self.name),
483518
exc_info=True,
484519
)
485520
return False
@@ -499,13 +534,16 @@ def release(self):
499534
if not self._leases:
500535
raise RuntimeError("Released too often")
501536

502-
return self.client.sync(self._release)
537+
# popleft to release the oldest lease first
538+
lease_id = self._leases.popleft()
539+
logger.debug("%s releases %s for %s", self.id, lease_id, self.name)
540+
return self.sync(self._release, lease_id=lease_id)
503541

504542
def get_value(self):
505543
"""
506544
Return the number of currently registered leases.
507545
"""
508-
return self.client.sync(self.client.scheduler.semaphore_value, name=self.name)
546+
return self.sync(self.scheduler.semaphore_value, name=self.name)
509547

510548
def __enter__(self):
511549
self.acquire()
@@ -528,13 +566,14 @@ def __getstate__(self):
528566

529567
def __setstate__(self, state):
530568
name, max_leases = state
531-
client = get_client()
532-
self.__init__(name=name, client=client, max_leases=max_leases, register=False)
569+
self.__init__(
570+
name=name,
571+
max_leases=max_leases,
572+
register=False,
573+
)
533574

534575
def close(self):
535-
return self.client.sync(self.client.scheduler.semaphore_close, name=self.name)
576+
return self.sync(self.scheduler.semaphore_close, name=self.name)
536577

537578
def __del__(self):
538-
if self._periodic_callback_name in self.client._periodic_callbacks:
539-
self.client._periodic_callbacks[self._periodic_callback_name].stop()
540-
del self.client._periodic_callbacks[self._periodic_callback_name]
579+
self.refresh_callback.stop()

0 commit comments

Comments
 (0)