11import asyncio
22import json
33import logging
4- from asyncio import Queue
4+ from asyncio import Queue , Lock
55from asyncio .futures import Future
66from logging import Logger
77from typing import Dict , Union , Any , Optional , List , Callable , Awaitable
@@ -23,6 +23,8 @@ class AsyncBaseSocketModeClient:
2323 wss_uri : str
2424 auto_reconnect_enabled : bool
2525 closed : bool
26+ connect_operation_lock : Lock
27+
2628 message_queue : Queue
2729 message_listeners : List [
2830 Union [
@@ -58,15 +60,24 @@ async def issue_new_wss_url(self) -> str:
5860 self .logger .error (f"Failed to retrieve WSS URL: { e } " )
5961 raise e
6062
63+ async def is_connected (self ) -> bool :
64+ return False
65+
6166 async def connect (self ):
6267 raise NotImplementedError ()
6368
6469 async def disconnect (self ):
6570 raise NotImplementedError ()
6671
67- async def connect_to_new_endpoint (self ):
68- self .wss_uri = await self .issue_new_wss_url ()
69- await self .connect ()
72+ async def connect_to_new_endpoint (self , force : bool = False ):
73+ try :
74+ await self .connect_operation_lock .acquire ()
75+ if force or not await self .is_connected ():
76+ self .wss_uri = await self .issue_new_wss_url ()
77+ await self .connect ()
78+ finally :
79+ if self .connect_operation_lock .locked () is True :
80+ self .connect_operation_lock .release ()
7081
7182 async def close (self ):
7283 self .closed = True
@@ -116,7 +127,7 @@ async def run_message_listeners(self, message: dict, raw_message: str) -> None:
116127 )
117128 try :
118129 if message .get ("type" ) == "disconnect" :
119- await self .connect_to_new_endpoint ()
130+ await self .connect_to_new_endpoint (force = True )
120131 return
121132
122133 for listener in self .message_listeners :
0 commit comments