Skip to content

Commit 5aa49ce

Browse files
committed
fix(api): address second-round CodeRabbit findings
- Use None-check instead of falsy-or for auth.exclude_paths (preserves explicitly empty tuples) - Use math.ceil for expires_in to prevent floor-to-zero on sub-second TTLs - Log ticket limit exceeded before raising ConflictError - Guard against non-dict payload in WS event filter matching - Split _handle_message into _parse_ws_message + _validate_ws_fields + dispatch (all functions now <50 lines) - Add @pytest.mark.timeout(30) to all new test classes - Add concurrent consumer test for single-use ticket guarantee
1 parent 3ccdad3 commit 5aa49ce

6 files changed

Lines changed: 80 additions & 22 deletions

File tree

src/synthorg/api/app.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -649,12 +649,16 @@ def _build_middleware(api_config: ApiConfig) -> list[Middleware]:
649649
auth = api_config.auth
650650
prefix = api_config.api_prefix
651651
ws_path = f"^{prefix}/ws$"
652-
exclude_paths = auth.exclude_paths or (
653-
f"^{prefix}/health$",
654-
"^/docs",
655-
"^/api$",
656-
f"^{prefix}/auth/setup$",
657-
f"^{prefix}/auth/login$",
652+
exclude_paths = (
653+
auth.exclude_paths
654+
if auth.exclude_paths is not None
655+
else (
656+
f"^{prefix}/health$",
657+
"^/docs",
658+
"^/api$",
659+
f"^{prefix}/auth/setup$",
660+
f"^{prefix}/auth/login$",
661+
)
658662
)
659663
# Always ensure the WS upgrade path is excluded — the WS handler
660664
# performs its own ticket-based auth, so the JWT middleware must

src/synthorg/api/auth/controller.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Authentication controller — setup, login, password change, me, ws-ticket."""
22

3+
import math
34
import uuid
45
from datetime import UTC, datetime
56
from typing import Any, Self
@@ -459,14 +460,19 @@ async def ws_ticket(
459460
try:
460461
ticket = app_state.ticket_store.create(ws_user)
461462
except TicketLimitExceededError:
463+
logger.warning(
464+
API_AUTH_FAILED,
465+
reason="ws_ticket_limit_exceeded",
466+
user_id=auth_user.user_id,
467+
)
462468
msg = "Too many pending tickets — wait for existing tickets to expire"
463469
raise ConflictError(msg) # noqa: B904
464470

465471
return Response(
466472
content=ApiResponse(
467473
data=WsTicketResponse(
468474
ticket=ticket,
469-
expires_in=int(app_state.ticket_store.ttl_seconds),
475+
expires_in=max(1, math.ceil(app_state.ticket_store.ttl_seconds)),
470476
),
471477
),
472478
)

src/synthorg/api/controllers/ws.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ async def _on_event(
153153
channel_filters = filters.get(channel)
154154
if channel_filters:
155155
payload = event.get("payload", {})
156+
if not isinstance(payload, dict):
157+
return
156158
if not all(payload.get(k) == v for k, v in channel_filters.items()):
157159
return
158160

@@ -236,21 +238,10 @@ async def _receive_loop(
236238
raise
237239

238240

239-
def _handle_message( # noqa: PLR0911
241+
def _parse_ws_message(
240242
data: str,
241-
subscribed: set[str],
242-
filters: dict[str, dict[str, str]],
243-
) -> str:
244-
"""Parse and dispatch a single client message.
245-
246-
Args:
247-
data: Raw JSON string from the client.
248-
subscribed: Mutable set of subscribed channel names.
249-
filters: Mutable per-channel payload filters.
250-
251-
Returns:
252-
JSON acknowledgement or error string.
253-
"""
243+
) -> dict[str, Any] | str:
244+
"""Parse raw JSON from the client, returning a dict or an error string."""
254245
if len(data.encode()) > _MAX_WS_MESSAGE_BYTES:
255246
return json.dumps({"error": "Message too large"})
256247

@@ -271,7 +262,18 @@ def _handle_message( # noqa: PLR0911
271262
if not isinstance(msg, dict):
272263
return json.dumps({"error": "Expected JSON object"})
273264

274-
action = msg.get("action")
265+
return msg
266+
267+
268+
def _validate_ws_fields(
269+
msg: dict[str, Any],
270+
) -> tuple[str, list[str], dict[str, Any]] | str:
271+
"""Extract and validate action, channels, and filters from a parsed message.
272+
273+
Returns ``(action, channels, client_filters)`` on success, or a
274+
JSON error string on validation failure.
275+
"""
276+
action = str(msg.get("action", ""))
275277
channels = msg.get("channels", [])
276278
client_filters = msg.get("filters", {})
277279

@@ -280,6 +282,25 @@ def _handle_message( # noqa: PLR0911
280282
if not isinstance(client_filters, dict):
281283
return json.dumps({"error": "filters must be an object"})
282284

285+
return (action, channels, client_filters)
286+
287+
288+
def _handle_message(
289+
data: str,
290+
subscribed: set[str],
291+
filters: dict[str, dict[str, str]],
292+
) -> str:
293+
"""Parse, validate, and dispatch a single client message."""
294+
parsed = _parse_ws_message(data)
295+
if isinstance(parsed, str):
296+
return parsed
297+
298+
fields = _validate_ws_fields(parsed)
299+
if isinstance(fields, str):
300+
return fields
301+
302+
action, channels, client_filters = fields
303+
283304
if action == "subscribe":
284305
return _handle_subscribe(channels, client_filters, subscribed, filters)
285306

tests/unit/api/auth/test_controller.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def test_rejects_unknown_user_type(self) -> None:
327327
require_password_changed(connection, None)
328328

329329

330+
@pytest.mark.timeout(30)
330331
@pytest.mark.unit
331332
class TestWsTicket:
332333
def test_ws_ticket_returns_ticket_and_expires_in(

tests/unit/api/auth/test_ticket_store.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def _make_user(
2525
)
2626

2727

28+
@pytest.mark.timeout(30)
2829
@pytest.mark.unit
2930
class TestWsTicketStoreCreate:
3031
"""Tests for ticket creation."""
@@ -90,6 +91,7 @@ def test_per_user_ticket_cap_different_users(self) -> None:
9091
assert isinstance(ticket, str)
9192

9293

94+
@pytest.mark.timeout(30)
9395
@pytest.mark.unit
9496
class TestWsTicketStoreValidateAndConsume:
9597
"""Tests for ticket validation and consumption."""
@@ -118,6 +120,28 @@ def test_validate_and_consume_single_use(self) -> None:
118120
assert first is not None
119121
assert second is None
120122

123+
def test_validate_and_consume_single_use_concurrent(self) -> None:
124+
"""Exactly one concurrent consumer wins the ticket."""
125+
import threading
126+
from concurrent.futures import ThreadPoolExecutor
127+
128+
store = WsTicketStore()
129+
user = _make_user()
130+
ticket = store.create(user)
131+
132+
barrier = threading.Barrier(10)
133+
134+
def consume() -> AuthenticatedUser | None:
135+
barrier.wait()
136+
return store.validate_and_consume(ticket)
137+
138+
with ThreadPoolExecutor(max_workers=10) as pool:
139+
results = list(pool.map(lambda _: consume(), range(10)))
140+
141+
winners = [r for r in results if r is not None]
142+
assert len(winners) == 1
143+
assert winners[0].user_id == user.user_id
144+
121145
def test_validate_and_consume_expired(self) -> None:
122146
store = WsTicketStore(ttl_seconds=10.0)
123147
user = _make_user()
@@ -203,6 +227,7 @@ def test_custom_ttl_expired(self) -> None:
203227
assert result is None
204228

205229

230+
@pytest.mark.timeout(30)
206231
@pytest.mark.unit
207232
class TestWsTicketStoreCleanup:
208233
"""Tests for expired ticket cleanup."""

tests/unit/api/controllers/test_ws.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def test_non_dict_json_returns_error(self, value: object) -> None:
189189
assert data["error"] == "Expected JSON object"
190190

191191

192+
@pytest.mark.timeout(30)
192193
@pytest.mark.unit
193194
class TestWsTicketAuth:
194195
"""Tests for ticket-based WebSocket authentication logic.

0 commit comments

Comments
 (0)