Skip to content

Commit

Permalink
Raise WebSocketDisconnect when WebSocket.send() excepts IOError (
Browse files Browse the repository at this point in the history
…#2425)

* Raise `WebSocketDisconnect` when `WebSocket.send()` excepts `IOError`

* Restrict the IOError
  • Loading branch information
Kludex authored Jan 20, 2024
1 parent 3ae161e commit b5126b2
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
6 changes: 5 additions & 1 deletion starlette/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ async def send(self, message: Message) -> None:
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
await self._send(message)
try:
await self._send(message)
except IOError:
self.application_state = WebSocketState.DISCONNECTED
raise WebSocketDisconnect(code=1006)
else:
raise RuntimeError('Cannot call "send" once a close message has been sent.')

Expand Down
22 changes: 22 additions & 0 deletions tests/test_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,28 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
assert close_reason == "Going Away"


@pytest.mark.anyio
async def test_client_disconnect_on_send():
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.send_text("Hello, world!")

async def receive() -> Message:
return {"type": "websocket.connect"}

async def send(message: Message) -> None:
if message["type"] == "websocket.accept":
return
# Simulate the exception the server would send to the application when the
# client disconnects.
raise IOError

with pytest.raises(WebSocketDisconnect) as ctx:
await app({"type": "websocket", "path": "/"}, receive, send)
assert ctx.value.code == status.WS_1006_ABNORMAL_CLOSURE


def test_application_close(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
Expand Down

0 comments on commit b5126b2

Please sign in to comment.