From b5126b2063dbd6fa46d3e60b4a7117c99dcdbe6e Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 20 Jan 2024 16:08:22 +0100 Subject: [PATCH] Raise `WebSocketDisconnect` when `WebSocket.send()` excepts `IOError` (#2425) * Raise `WebSocketDisconnect` when `WebSocket.send()` excepts `IOError` * Restrict the IOError --- starlette/websockets.py | 6 +++++- tests/test_websockets.py | 22 ++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/starlette/websockets.py b/starlette/websockets.py index a34bc1339..084d93094 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -82,7 +82,11 @@ async def send(self, message: Message) -> None: ) if message_type == "websocket.close": self.application_state = WebSocketState.DISCONNECTED - await self._send(message) + try: + await self._send(message) + except IOError: + self.application_state = WebSocketState.DISCONNECTED + raise WebSocketDisconnect(code=1006) else: raise RuntimeError('Cannot call "send" once a close message has been sent.') diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 283dcfc78..247477404 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -255,6 +255,28 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert close_reason == "Going Away" +@pytest.mark.anyio +async def test_client_disconnect_on_send(): + async def app(scope: Scope, receive: Receive, send: Send) -> None: + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await websocket.send_text("Hello, world!") + + async def receive() -> Message: + return {"type": "websocket.connect"} + + async def send(message: Message) -> None: + if message["type"] == "websocket.accept": + return + # Simulate the exception the server would send to the application when the + # client disconnects. + raise IOError + + with pytest.raises(WebSocketDisconnect) as ctx: + await app({"type": "websocket", "path": "/"}, receive, send) + assert ctx.value.code == status.WS_1006_ABNORMAL_CLOSURE + + def test_application_close(test_client_factory: Callable[..., TestClient]): async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send)