Skip to content

Commit 5577644

Browse files
Aurelioloclaude
andcommitted
fix(api): defer socket.accept until after plugin resolution, parametrize tests
- Move socket.accept() after ChannelsPlugin resolution and subscribe() so infrastructure failures never accept-then-immediately-close - Merge test_ws_rejects_missing_ticket and test_ws_rejects_invalid_ticket into a single parametrized test_ws_rejects_bad_ticket Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 47d5345 commit 5577644

2 files changed

Lines changed: 31 additions & 29 deletions

File tree

src/synthorg/api/controllers/ws.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -215,17 +215,6 @@ async def ws_handler(
215215
if not await _check_ws_role(socket, user):
216216
return
217217

218-
socket.scope["user"] = user
219-
await socket.accept()
220-
logger.info(
221-
API_WS_CONNECTED,
222-
client=str(socket.client),
223-
user_id=user.user_id,
224-
)
225-
226-
subscribed: set[str] = set()
227-
filters: dict[str, dict[str, str]] = {}
228-
229218
# Resolve ChannelsPlugin by iterating app.plugins -- the same
230219
# pattern used by get_channels_plugin() in channels.py.
231220
# Litestar's DI does not reliably inject plugin instances into
@@ -244,8 +233,23 @@ async def ws_handler(
244233
)
245234
await socket.close(code=1011, reason="Internal error")
246235
return
236+
247237
subscriber = await channels_plugin.subscribe(list(ALL_CHANNELS))
248238

239+
# Accept only after auth, role check, plugin resolution, and
240+
# subscription all succeed -- avoid accepting then immediately
241+
# closing on infrastructure failures.
242+
socket.scope["user"] = user
243+
await socket.accept()
244+
logger.info(
245+
API_WS_CONNECTED,
246+
client=str(socket.client),
247+
user_id=user.user_id,
248+
)
249+
250+
subscribed: set[str] = set()
251+
filters: dict[str, dict[str, str]] = {}
252+
249253
async def _event_callback(event_data: bytes) -> None:
250254
await _on_event(event_data, subscribed, filters, socket)
251255

tests/unit/api/controllers/test_ws.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,20 @@ def test_ws_close_codes_in_application_range(self) -> None:
281281
assert 4000 <= _WS_CLOSE_AUTH_FAILED <= 4999
282282
assert 4000 <= _WS_CLOSE_FORBIDDEN <= 4999
283283

284-
def test_ws_rejects_missing_ticket(
284+
@pytest.mark.parametrize(
285+
("url", "scenario"),
286+
[
287+
("/api/v1/ws", "missing_ticket"),
288+
("/api/v1/ws?ticket=bogus-ticket", "invalid_ticket"),
289+
],
290+
)
291+
def test_ws_rejects_bad_ticket(
285292
self,
286293
test_client: TestClient[Any],
294+
url: str,
295+
scenario: str,
287296
) -> None:
288-
"""WS connection without ?ticket= is rejected with code 4001.
297+
"""WS connection with missing or invalid ticket is rejected.
289298
290299
Verifying the close code (not just WebSocketDisconnect) ensures
291300
the rejection comes from the handler's ticket validation -- not
@@ -296,24 +305,13 @@ def test_ws_rejects_missing_ticket(
296305

297306
with (
298307
pytest.raises(WebSocketDisconnect) as exc_info,
299-
test_client.websocket_connect("/api/v1/ws"),
308+
test_client.websocket_connect(url),
300309
):
301310
pass
302-
assert exc_info.value.code == _WS_CLOSE_AUTH_FAILED
303-
304-
def test_ws_rejects_invalid_ticket(
305-
self,
306-
test_client: TestClient[Any],
307-
) -> None:
308-
"""WS connection with a bogus ticket is rejected with code 4001."""
309-
from litestar.exceptions import WebSocketDisconnect
310-
311-
with (
312-
pytest.raises(WebSocketDisconnect) as exc_info,
313-
test_client.websocket_connect("/api/v1/ws?ticket=bogus-ticket"),
314-
):
315-
pass
316-
assert exc_info.value.code == _WS_CLOSE_AUTH_FAILED
311+
assert exc_info.value.code == _WS_CLOSE_AUTH_FAILED, (
312+
f"Expected close code {_WS_CLOSE_AUTH_FAILED} for "
313+
f"{scenario}, got {exc_info.value.code}"
314+
)
317315

318316
def test_ws_accepts_valid_ticket(
319317
self,

0 commit comments

Comments
 (0)