Skip to content

Commit

Permalink
Timeout close handshake WebSocket connections
Browse files Browse the repository at this point in the history
This ensures that WebSocket connections will close even if the client
doesn't respond with a close frame (after a duration equal to the keep
alive timeout).
  • Loading branch information
pgjones committed Jul 8, 2019
1 parent 13bbb7f commit ccc233c
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 5 deletions.
2 changes: 1 addition & 1 deletion hypercorn/protocol/h11.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(

@property
def idle(self) -> bool:
return self.stream is None
return self.stream is None or self.stream.idle

async def initiate(self) -> None:
pass
Expand Down
2 changes: 1 addition & 1 deletion hypercorn/protocol/h2.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(

@property
def idle(self) -> bool:
return len(self.streams) == 0
return len(self.streams) == 0 or all(stream.idle for stream in self.streams.values())

async def initiate(
self, headers: Optional[List[Tuple[bytes, bytes]]] = None, settings: Optional[str] = None
Expand Down
4 changes: 4 additions & 0 deletions hypercorn/protocol/http_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def __init__(
self.state = ASGIHTTPState.REQUEST
self.stream_id = stream_id

@property
def idle(self) -> bool:
return False

async def handle(self, event: Event) -> None:
if isinstance(event, Request):
path, _, query_string = event.raw_path.partition(b"?")
Expand Down
10 changes: 7 additions & 3 deletions hypercorn/protocol/ws_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ def __init__(
self.connection: Connection
self.handshake: Handshake

@property
def idle(self) -> bool:
return self.state in {ASGIWebsocketState.CLOSED, ASGIWebsocketState.HTTPCLOSED}

async def handle(self, event: Event) -> None:
if isinstance(event, Request):
self.handshake = Handshake(event.headers, event.http_version)
Expand Down Expand Up @@ -247,14 +251,14 @@ async def app_send(self, message: Optional[dict]) -> None:
elif (
message["type"] == "websocket.close" and self.state == ASGIWebsocketState.HANDSHAKE
):
await self._send_error_response(403)
self.state = ASGIWebsocketState.HTTPCLOSED
await self._send_error_response(403)
elif message["type"] == "websocket.close":
self.state = ASGIWebsocketState.CLOSED
await self._send_wsproto_event(
CloseConnection(code=int(message.get("code", CloseReason.NORMAL_CLOSURE)))
)
await self.send(EndData(stream_id=self.stream_id))
self.state = ASGIWebsocketState.CLOSED
else:
raise UnexpectedMessage(self.state, message["type"])

Expand Down Expand Up @@ -311,6 +315,6 @@ async def _send_rejection(self, message: dict) -> None:
if not body_suppressed:
await self.send(Body(stream_id=self.stream_id, data=bytes(message.get("body", b""))))
if not message.get("more_body", False):
await self.send(EndBody(stream_id=self.stream_id))
self.state = ASGIWebsocketState.HTTPCLOSED
await self.send(EndBody(stream_id=self.stream_id))
await self.config.log.access(self.scope, self.response, time() - self.start_time)
4 changes: 4 additions & 0 deletions tests/protocol/test_http_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,7 @@ async def test_send_invalid_message(
with pytest.raises((TypeError, ValueError)):
await stream.app_send({"type": "http.response.start", "headers": headers, "status": status})
await stream.app_send({"type": "http.response.body", "body": body})


def test_stream_idle(stream: HTTPStream) -> None:
assert stream.idle is False
14 changes: 14 additions & 0 deletions tests/protocol/test_ws_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,17 @@ async def test_send_invalid_http_message(
{"type": "websocket.http.response.start", "headers": headers, "status": status}
)
await stream.app_send({"type": "websocket.http.response.body", "body": body})


@pytest.mark.parametrize(
"state, idle",
[
(state, False)
for state in ASGIWebsocketState
if state not in {ASGIWebsocketState.CLOSED, ASGIWebsocketState.HTTPCLOSED}
]
+ [(ASGIWebsocketState.CLOSED, True), (ASGIWebsocketState.HTTPCLOSED, True)],
)
def test_stream_idle(stream: WSStream, state: ASGIWebsocketState, idle: bool) -> None:
stream.state = state
assert stream.idle is idle

0 comments on commit ccc233c

Please sign in to comment.