@@ -1207,38 +1207,38 @@ async def _start(self, timeout=no_default, **kwargs):
12071207
12081208 return self
12091209
1210+ @log_errors
12101211 async def _reconnect (self ):
1211- with log_errors ():
1212- assert self .scheduler_comm .comm .closed ()
1213-
1214- self .status = "connecting"
1215- self .scheduler_comm = None
1212+ assert self .scheduler_comm .comm .closed ()
12161213
1217- for st in self .futures .values ():
1218- st .cancel ()
1219- self .futures .clear ()
1214+ self .status = "connecting"
1215+ self .scheduler_comm = None
12201216
1221- timeout = self ._timeout
1222- deadline = time () + timeout
1223- while timeout > 0 and self .status == "connecting" :
1224- try :
1225- await self ._ensure_connected (timeout = timeout )
1226- break
1227- except OSError :
1228- # Wait a bit before retrying
1229- await asyncio .sleep (0.1 )
1230- timeout = deadline - time ()
1231- except ImportError :
1232- await self ._close ()
1233- break
1217+ for st in self .futures .values ():
1218+ st .cancel ()
1219+ self .futures .clear ()
12341220
1235- else :
1236- logger .error (
1237- "Failed to reconnect to scheduler after %.2f "
1238- "seconds, closing client" ,
1239- self ._timeout ,
1240- )
1221+ timeout = self ._timeout
1222+ deadline = time () + timeout
1223+ while timeout > 0 and self .status == "connecting" :
1224+ try :
1225+ await self ._ensure_connected (timeout = timeout )
1226+ break
1227+ except OSError :
1228+ # Wait a bit before retrying
1229+ await asyncio .sleep (0.1 )
1230+ timeout = deadline - time ()
1231+ except ImportError :
12411232 await self ._close ()
1233+ break
1234+
1235+ else :
1236+ logger .error (
1237+ "Failed to reconnect to scheduler after %.2f "
1238+ "seconds, closing client" ,
1239+ self ._timeout ,
1240+ )
1241+ await self ._close ()
12421242
12431243 async def _ensure_connected (self , timeout = None ):
12441244 if (
@@ -1351,10 +1351,10 @@ async def __aenter__(self):
13511351 await self
13521352 return self
13531353
1354- async def __aexit__ (self , typ , value , traceback ):
1354+ async def __aexit__ (self , exc_type , exc_value , traceback ):
13551355 await self ._close ()
13561356
1357- def __exit__ (self , type , value , traceback ):
1357+ def __exit__ (self , exc_type , exc_value , traceback ):
13581358 self .close ()
13591359
13601360 def __del__ (self ):
@@ -1385,54 +1385,54 @@ def _release_key(self, key):
13851385 {"op" : "client-releases-keys" , "keys" : [key ], "client" : self .id }
13861386 )
13871387
1388+ @log_errors
13881389 async def _handle_report (self ):
13891390 """Listen to scheduler"""
1390- with log_errors ():
1391- try :
1392- while True :
1393- if self .scheduler_comm is None :
1391+ try :
1392+ while True :
1393+ if self .scheduler_comm is None :
1394+ break
1395+ try :
1396+ msgs = await self .scheduler_comm .comm .read ()
1397+ except CommClosedError :
1398+ if is_python_shutting_down ():
1399+ return
1400+ if self .status == "running" :
1401+ logger .info ("Client report stream closed to scheduler" )
1402+ logger .info ("Reconnecting..." )
1403+ self .status = "connecting"
1404+ await self ._reconnect ()
1405+ continue
1406+ else :
13941407 break
1395- try :
1396- msgs = await self .scheduler_comm .comm .read ()
1397- except CommClosedError :
1398- if is_python_shutting_down ():
1399- return
1400- if self .status == "running" :
1401- logger .info ("Client report stream closed to scheduler" )
1402- logger .info ("Reconnecting..." )
1403- self .status = "connecting"
1404- await self ._reconnect ()
1405- continue
1406- else :
1407- break
1408- if not isinstance (msgs , (list , tuple )):
1409- msgs = (msgs ,)
1410-
1411- breakout = False
1412- for msg in msgs :
1413- logger .debug ("Client receives message %s" , msg )
1408+ if not isinstance (msgs , (list , tuple )):
1409+ msgs = (msgs ,)
14141410
1415- if "status" in msg and "error" in msg [ "status" ]:
1416- typ , exc , tb = clean_exception ( ** msg )
1417- raise exc . with_traceback ( tb )
1411+ breakout = False
1412+ for msg in msgs :
1413+ logger . debug ( "Client receives message %s" , msg )
14181414
1419- op = msg .pop ("op" )
1415+ if "status" in msg and "error" in msg ["status" ]:
1416+ typ , exc , tb = clean_exception (** msg )
1417+ raise exc .with_traceback (tb )
14201418
1421- if op == "close" or op == "stream-closed" :
1422- breakout = True
1423- break
1419+ op = msg .pop ("op" )
14241420
1425- try :
1426- handler = self ._stream_handlers [op ]
1427- result = handler (** msg )
1428- if inspect .isawaitable (result ):
1429- await result
1430- except Exception as e :
1431- logger .exception (e )
1432- if breakout :
1421+ if op == "close" or op == "stream-closed" :
1422+ breakout = True
14331423 break
1434- except CancelledError :
1435- pass
1424+
1425+ try :
1426+ handler = self ._stream_handlers [op ]
1427+ result = handler (** msg )
1428+ if inspect .isawaitable (result ):
1429+ await result
1430+ except Exception as e :
1431+ logger .exception (e )
1432+ if breakout :
1433+ break
1434+ except CancelledError :
1435+ pass
14361436
14371437 def _handle_key_in_memory (self , key = None , type = None , workers = None ):
14381438 state = self .futures .get (key )
@@ -2444,37 +2444,37 @@ def retry(self, futures, asynchronous=None):
24442444 """
24452445 return self .sync (self ._retry , futures , asynchronous = asynchronous )
24462446
2447+ @log_errors
24472448 async def _publish_dataset (self , * args , name = None , override = False , ** kwargs ):
2448- with log_errors ():
2449- coroutines = []
2450-
2451- def add_coro (name , data ):
2452- keys = [stringify (f .key ) for f in futures_of (data )]
2453- coroutines .append (
2454- self .scheduler .publish_put (
2455- keys = keys ,
2456- name = name ,
2457- data = to_serialize (data ),
2458- override = override ,
2459- client = self .id ,
2460- )
2449+ coroutines = []
2450+
2451+ def add_coro (name , data ):
2452+ keys = [stringify (f .key ) for f in futures_of (data )]
2453+ coroutines .append (
2454+ self .scheduler .publish_put (
2455+ keys = keys ,
2456+ name = name ,
2457+ data = to_serialize (data ),
2458+ override = override ,
2459+ client = self .id ,
24612460 )
2461+ )
24622462
2463- if name :
2464- if len (args ) == 0 :
2465- raise ValueError (
2466- "If name is provided, expecting call signature like"
2467- " publish_dataset(df, name='ds')"
2468- )
2469- # in case this is a singleton, collapse it
2470- elif len (args ) == 1 :
2471- args = args [0 ]
2472- add_coro (name , args )
2463+ if name :
2464+ if len (args ) == 0 :
2465+ raise ValueError (
2466+ "If name is provided, expecting call signature like"
2467+ " publish_dataset(df, name='ds')"
2468+ )
2469+ # in case this is a singleton, collapse it
2470+ elif len (args ) == 1 :
2471+ args = args [0 ]
2472+ add_coro (name , args )
24732473
2474- for name , data in kwargs .items ():
2475- add_coro (name , data )
2474+ for name , data in kwargs .items ():
2475+ add_coro (name , data )
24762476
2477- await asyncio .gather (* coroutines )
2477+ await asyncio .gather (* coroutines )
24782478
24792479 def publish_dataset (self , * args , ** kwargs ):
24802480 """
@@ -5173,7 +5173,7 @@ def __enter__(self):
51735173 self .start = time ()
51745174 return self
51755175
5176- def __exit__ (self , typ , value , traceback ):
5176+ def __exit__ (self , exc_type , exc_value , traceback ):
51775177 L = self .client .get_task_stream (
51785178 start = self .start , plot = self ._plot , filename = self ._filename
51795179 )
@@ -5184,7 +5184,7 @@ def __exit__(self, typ, value, traceback):
51845184 async def __aenter__ (self ):
51855185 return self
51865186
5187- async def __aexit__ (self , typ , value , traceback ):
5187+ async def __aexit__ (self , exc_type , exc_value , traceback ):
51885188 L = await self .client .get_task_stream (
51895189 start = self .start , plot = self ._plot , filename = self ._filename
51905190 )
@@ -5237,7 +5237,7 @@ async def __aenter__(self):
52375237 )
52385238 await get_client ().get_task_stream (start = 0 , stop = 0 ) # ensure plugin
52395239
5240- async def __aexit__ (self , typ , value , traceback , code = None ):
5240+ async def __aexit__ (self , exc_type , exc_value , traceback , code = None ):
52415241 client = get_client ()
52425242 if code is None :
52435243 code = client ._get_computation_code (self ._stacklevel + 1 )
@@ -5250,10 +5250,10 @@ async def __aexit__(self, typ, value, traceback, code=None):
52505250 def __enter__ (self ):
52515251 get_client ().sync (self .__aenter__ )
52525252
5253- def __exit__ (self , typ , value , traceback ):
5253+ def __exit__ (self , exc_type , exc_value , traceback ):
52545254 client = get_client ()
52555255 code = client ._get_computation_code (self ._stacklevel + 1 )
5256- client .sync (self .__aexit__ , type , value , traceback , code = code )
5256+ client .sync (self .__aexit__ , exc_type , exc_value , traceback , code = code )
52575257
52585258
52595259class get_task_metadata :
@@ -5283,16 +5283,16 @@ async def __aenter__(self):
52835283 await get_client ().scheduler .start_task_metadata (name = self .name )
52845284 return self
52855285
5286- async def __aexit__ (self , typ , value , traceback ):
5286+ async def __aexit__ (self , exc_type , exc_value , traceback ):
52875287 response = await get_client ().scheduler .stop_task_metadata (name = self .name )
52885288 self .metadata = response ["metadata" ]
52895289 self .state = response ["state" ]
52905290
52915291 def __enter__ (self ):
52925292 return get_client ().sync (self .__aenter__ )
52935293
5294- def __exit__ (self , typ , value , traceback ):
5295- return get_client ().sync (self .__aexit__ , type , value , traceback )
5294+ def __exit__ (self , exc_type , exc_value , traceback ):
5295+ return get_client ().sync (self .__aexit__ , exc_type , exc_value , traceback )
52965296
52975297
52985298@contextmanager
0 commit comments