Skip to content

Commit a080bf5

Browse files
committed
Additional headers for WS accept message.
1 parent c8b9581 commit a080bf5

File tree

4 files changed

+25
-4
lines changed

4 files changed

+25
-4
lines changed

docs/websockets.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
Starlette includes a `WebSocket` class that fulfils a similar role
32
to the HTTP request, but that allows sending and receiving data on a websocket.
43

@@ -51,7 +50,7 @@ For example: `websocket.path_params['username']`
5150

5251
### Accepting the connection
5352

54-
* `await websocket.accept(subprotocol=None)`
53+
* `await websocket.accept(subprotocol=None, headers=None)`
5554

5655
### Sending data
5756

starlette/testclient.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ def __init__(
296296
self.app = app
297297
self.scope = scope
298298
self.accepted_subprotocol = None
299+
self.additional_headers = None
299300
self.portal_factory = portal_factory
300301
self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
301302
self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
@@ -313,6 +314,7 @@ def __enter__(self) -> "WebSocketTestSession":
313314
self.exit_stack.close()
314315
raise
315316
self.accepted_subprotocol = message.get("subprotocol", None)
317+
self.additional_headers = message.get("headers", None)
316318
return self
317319

318320
def __exit__(self, *args: typing.Any) -> None:

starlette/websockets.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,17 @@ async def send(self, message: Message) -> None:
6969
else:
7070
raise RuntimeError('Cannot call "send" once a close message has been sent.')
7171

72-
async def accept(self, subprotocol: str = None) -> None:
72+
async def accept(
73+
self,
74+
subprotocol: str = None,
75+
headers: typing.Iterable[typing.Tuple[bytes, bytes]] = None,
76+
) -> None:
7377
if self.client_state == WebSocketState.CONNECTING:
7478
# If we haven't yet seen the 'connect' message, then wait for it first.
7579
await self.receive()
76-
await self.send({"type": "websocket.accept", "subprotocol": subprotocol})
80+
await self.send(
81+
{"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
82+
)
7783

7884
def _raise_on_disconnect(self, message: Message) -> None:
7985
if message["type"] == "websocket.disconnect":

tests/test_websockets.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,20 @@ async def asgi(receive, send):
301301
assert websocket.accepted_subprotocol == "wamp"
302302

303303

304+
def test_additional_headers(test_client_factory):
305+
def app(scope):
306+
async def asgi(receive, send):
307+
websocket = WebSocket(scope, receive=receive, send=send)
308+
await websocket.accept(headers=[(b"additional", b"header")])
309+
await websocket.close()
310+
311+
return asgi
312+
313+
client = test_client_factory(app)
314+
with client.websocket_connect("/") as websocket:
315+
websocket.additional_headers = [(b"additional", b"header")]
316+
317+
304318
def test_websocket_exception(test_client_factory):
305319
def app(scope):
306320
async def asgi(receive, send):

0 commit comments

Comments
 (0)