Skip to content

Commit

Permalink
Send code 1012 on shutdown for websockets (encode#1816)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Jan 6, 2023
1 parent 4831d79 commit 23b9f05
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 19 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ plugins =

[coverage:report]
precision = 2
fail_under = 98.50
fail_under = 98.80
show_missing = true
skip_covered = true
exclude_lines =
Expand Down
34 changes: 28 additions & 6 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import typing

import httpx
import pytest
Expand Down Expand Up @@ -713,11 +714,22 @@ async def app(scope, receive, send):
message = await receive()
if message["type"] == "websocket.connect":
await send_accept_task.wait()
await send({"type": "websocket.accept"})
disconnect_message = await receive()

response: typing.Optional[httpx.Response] = None

async def websocket_session(uri):
await websockets.client.connect(uri)
nonlocal response
async with httpx.AsyncClient() as client:
response = await client.get(
f"http://127.0.0.1:{unused_tcp_port}",
headers={
"upgrade": "websocket",
"connection": "upgrade",
"sec-websocket-version": "13",
"sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==",
},
)

config = Config(
app=app,
Expand All @@ -731,9 +743,12 @@ async def websocket_session(uri):
websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")
)
await asyncio.sleep(0.1)
task.cancel()
send_accept_task.set()

task.cancel()
assert response is not None
assert response.status_code == 500, response.text
assert response.text == "Internal Server Error"
assert disconnect_message == {"type": "websocket.disconnect", "code": 1006}


Expand All @@ -744,6 +759,7 @@ async def test_send_close_on_server_shutdown(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
):
disconnect_message = {}
server_shutdown_event = asyncio.Event()

async def app(scope, receive, send):
nonlocal disconnect_message
Expand All @@ -755,10 +771,13 @@ async def app(scope, receive, send):
disconnect_message = message
break

websocket: typing.Optional[websockets.client.WebSocketClientProtocol] = None

async def websocket_session(uri):
async with websockets.client.connect(uri):
while True:
await asyncio.sleep(0.1)
nonlocal websocket
async with websockets.client.connect(uri) as ws_connection:
websocket = ws_connection
await server_shutdown_event.wait()

config = Config(
app=app,
Expand All @@ -773,7 +792,10 @@ async def websocket_session(uri):
)
await asyncio.sleep(0.1)
disconnect_message_before_shutdown = disconnect_message
server_shutdown_event.set()

assert websocket is not None
assert websocket.close_code == 1012
assert disconnect_message_before_shutdown == {}
assert disconnect_message == {"type": "websocket.disconnect", "code": 1012}
task.cancel()
Expand Down
5 changes: 4 additions & 1 deletion uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def __init__(
self.connect_sent = False
self.lost_connection_before_handshake = False
self.accepted_subprotocol: Optional[Subprotocol] = None
self.transfer_data_task: asyncio.Task = None # type: ignore[assignment]

self.ws_server: Server = Server() # type: ignore[assignment]

Expand Down Expand Up @@ -145,6 +144,10 @@ def connection_lost(self, exc: Optional[Exception]) -> None:

def shutdown(self) -> None:
self.ws_server.closing = True
if self.handshake_completed_event.is_set():
self.fail_connection(1012)
else:
self.send_500_response()
self.transport.close()

def on_task_complete(self, task: asyncio.Task) -> None:
Expand Down
18 changes: 7 additions & 11 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import typing
from urllib.parse import unquote

import h11
import wsproto
from wsproto import ConnectionType, events
from wsproto.connection import ConnectionState
Expand Down Expand Up @@ -232,17 +231,14 @@ def send_500_response(self) -> None:
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
]
if self.conn.connection is None:
output = self.conn.send(wsproto.events.RejectConnection(status_code=500))
else:
msg = h11.Response(
status_code=500, headers=headers, reason="Internal Server Error"
output = self.conn.send(
wsproto.events.RejectConnection(
status_code=500, headers=headers, has_body=True
)
output = self.conn.send(msg)
msg = h11.Data(data=b"Internal Server Error")
output += self.conn.send(msg)
msg = h11.EndOfMessage()
output += self.conn.send(msg)
)
output += self.conn.send(
wsproto.events.RejectData(data=b"Internal Server Error")
)
self.transport.write(output)

async def run_asgi(self) -> None:
Expand Down

0 comments on commit 23b9f05

Please sign in to comment.