@@ -100,6 +100,19 @@ async def _validate_ticket(
100100 return user
101101
102102
103+ async def _reject_auth (
104+ socket : WebSocket [Any , Any , Any ],
105+ log_reason : str ,
106+ close_reason : str ,
107+ * ,
108+ code : int = _WS_CLOSE_AUTH_FAILED ,
109+ ** extra_kwargs : str ,
110+ ) -> None :
111+ """Log a warning and close the socket for an auth rejection."""
112+ logger .warning (API_WS_TICKET_INVALID , reason = log_reason , ** extra_kwargs )
113+ await socket .close (code = code , reason = close_reason )
114+
115+
103116async def _read_auth_message ( # noqa: PLR0911
104117 socket : WebSocket [Any , Any , Any ],
105118) -> str | None :
@@ -113,43 +126,36 @@ async def _read_auth_message( # noqa: PLR0911
113126 timeout = _WS_AUTH_TIMEOUT_SECONDS ,
114127 )
115128 except TimeoutError :
116- logger .warning (API_WS_TICKET_INVALID , reason = "auth_timeout" )
117- await socket .close (code = _WS_CLOSE_AUTH_FAILED , reason = "Auth timeout" )
129+ await _reject_auth (socket , "auth_timeout" , "Auth timeout" )
118130 return None
119131 except WebSocketDisconnect :
120132 logger .debug (API_WS_DISCONNECTED , reason = "disconnect_during_auth" )
121133 return None
122134
123135 if len (data .encode ()) > _MAX_WS_MESSAGE_BYTES :
124- logger .warning (API_WS_TICKET_INVALID , reason = "auth_too_large" )
125- await socket .close (
126- code = _WS_CLOSE_AUTH_FAILED ,
127- reason = "Auth message too large" ,
128- )
136+ await _reject_auth (socket , "auth_too_large" , "Auth message too large" )
129137 return None
130138
131139 try :
132140 msg = json .loads (data )
133141 except json .JSONDecodeError :
134- logger .warning (API_WS_TICKET_INVALID , reason = "invalid_auth_json" )
135- await socket .close (code = _WS_CLOSE_AUTH_FAILED , reason = "Invalid auth message" )
142+ await _reject_auth (socket , "invalid_auth_json" , "Invalid auth message" )
136143 return None
137144
138145 if not isinstance (msg , dict ) or msg .get ("action" ) != "auth" :
139146 action = msg .get ("action" , "" ) if isinstance (msg , dict ) else ""
140- logger .warning (
141- API_WS_TICKET_INVALID ,
142- reason = "expected_auth_action" ,
147+ await _reject_auth (
148+ socket ,
149+ "expected_auth_action" ,
150+ "Expected auth action" ,
143151 action = str (action )[:64 ],
144152 )
145- await socket .close (code = _WS_CLOSE_AUTH_FAILED , reason = "Expected auth action" )
146153 return None
147154
148155 raw_ticket = msg .get ("ticket" )
149156 ticket : str | None = raw_ticket if isinstance (raw_ticket , str ) else None
150157 if not ticket :
151- logger .warning (API_WS_TICKET_INVALID , reason = "missing_ticket_in_auth" )
152- await socket .close (code = _WS_CLOSE_AUTH_FAILED , reason = "Missing ticket" )
158+ await _reject_auth (socket , "missing_ticket_in_auth" , "Missing ticket" )
153159 return None
154160
155161 return ticket
@@ -237,6 +243,24 @@ async def _check_ws_role(
237243 return True
238244
239245
246+ def _matches_filters (
247+ event : dict [str , Any ],
248+ channel : str ,
249+ channel_filters : dict [str , str ],
250+ ) -> bool :
251+ """Check whether the event payload matches the active channel filters."""
252+ payload = event .get ("payload" , {})
253+ if not isinstance (payload , dict ):
254+ logger .warning (
255+ API_WS_INVALID_MESSAGE ,
256+ channel = channel ,
257+ reason = "payload_not_dict" ,
258+ payload_type = type (payload ).__name__ ,
259+ )
260+ return False
261+ return all (payload .get (k ) == v for k , v in channel_filters .items ())
262+
263+
240264async def _on_event (
241265 event_data : bytes ,
242266 subscribed : set [str ],
@@ -276,25 +300,15 @@ async def _on_event(
276300 return
277301
278302 channel_filters = filters .get (channel )
279- if channel_filters :
280- payload = event .get ("payload" , {})
281- if not isinstance (payload , dict ):
282- logger .warning (
283- API_WS_INVALID_MESSAGE ,
284- channel = channel ,
285- reason = "payload_not_dict" ,
286- payload_type = type (payload ).__name__ ,
287- )
288- return
289- if not all (payload .get (k ) == v for k , v in channel_filters .items ()):
290- return
303+ if channel_filters and not _matches_filters (event , channel , channel_filters ):
304+ return
291305
292306 try :
293307 await socket .send_text (event_data .decode ("utf-8" ))
294308 except WebSocketDisconnect :
295309 logger .debug (API_WS_SEND_FAILED , reason = "client_disconnected" )
296310 except Exception :
297- logger .warning (API_WS_SEND_FAILED , exc_info = True )
311+ logger .error (API_WS_SEND_FAILED , exc_info = True )
298312 await socket .close (code = 1011 , reason = "Internal error" )
299313
300314
@@ -308,7 +322,7 @@ async def _authenticate_ws(
308322 """
309323 ticket_param = socket .query_params .get ("ticket" )
310324
311- if ticket_param :
325+ if ticket_param is not None :
312326 user = await _validate_ticket (socket )
313327 if user is None :
314328 return None
@@ -342,9 +356,9 @@ async def _setup_connection(
342356 socket : WebSocket [Any , Any , Any ],
343357 user : AuthenticatedUser ,
344358 * ,
345- accepted : bool ,
359+ already_accepted : bool ,
346360) -> tuple [ChannelsPlugin , Any ] | None :
347- """Resolve plugin, subscribe to channels , and accept the connection .
361+ """Resolve plugin, accept the connection , and subscribe to channels .
348362
349363 Returns ``(channels_plugin, subscriber)`` on success, or ``None``
350364 (socket already closed) on failure.
@@ -363,11 +377,15 @@ async def _setup_connection(
363377 await socket .close (code = 1011 , reason = "Internal error" )
364378 return None
365379
380+ socket .scope ["user" ] = user
381+ if not already_accepted :
382+ await socket .accept ()
383+
366384 try :
367385 subscriber = await channels_plugin .subscribe (list (ALL_CHANNELS ))
368386 except Exception :
369387 logger .error (
370- API_WS_SEND_FAILED ,
388+ API_WS_TRANSPORT_ERROR ,
371389 reason = "subscribe_failed" ,
372390 client = str (socket .client ),
373391 user_id = user .user_id ,
@@ -376,9 +394,6 @@ async def _setup_connection(
376394 await socket .close (code = 1011 , reason = "Internal error" )
377395 return None
378396
379- socket .scope ["user" ] = user
380- if not accepted :
381- await socket .accept ()
382397 logger .info (
383398 API_WS_CONNECTED ,
384399 client = str (socket .client ),
@@ -408,12 +423,12 @@ async def ws_handler(
408423 auth_result = await _authenticate_ws (socket )
409424 if auth_result is None :
410425 return
411- user , accepted = auth_result
426+ user , already_accepted = auth_result
412427
413428 if not await _check_ws_role (socket , user ):
414429 return
415430
416- setup = await _setup_connection (socket , user , accepted = accepted )
431+ setup = await _setup_connection (socket , user , already_accepted = already_accepted )
417432 if setup is None :
418433 return
419434 channels_plugin , subscriber = setup
0 commit comments