Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket subprotocol #1887

Merged
merged 11 commits into from
Jul 29, 2020
4 changes: 3 additions & 1 deletion sanic/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def get_websocket_connection(self) -> WebSocketConnection:
def create_websocket_connection(
self, send: ASGISend, receive: ASGIReceive
) -> WebSocketConnection:
self._websocket_connection = WebSocketConnection(send, receive)
self._websocket_connection = WebSocketConnection(
send, receive, self.scope.get("subprotocols", [])
)
return self._websocket_connection

def add_task(self) -> None:
Expand Down
12 changes: 11 additions & 1 deletion sanic/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Awaitable,
Callable,
Dict,
List,
MutableMapping,
Optional,
Union,
Expand Down Expand Up @@ -137,9 +138,11 @@ def __init__(
self,
send: Callable[[ASIMessage], Awaitable[None]],
receive: Callable[[], Awaitable[ASIMessage]],
subprotocols: Optional[List[str]] = None,
) -> None:
self._send = send
self._receive = receive
self.subprotocols = subprotocols or []

async def send(self, data: Union[str, bytes], *args, **kwargs) -> None:
message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"}
Expand All @@ -164,7 +167,14 @@ async def recv(self, *args, **kwargs) -> Optional[str]:
receive = recv

async def accept(self) -> None:
await self._send({"type": "websocket.accept", "subprotocol": ""})
await self._send(
{
"type": "websocket.accept",
"subprotocol": ",".join(
[subprotocol for subprotocol in self.subprotocols]
),
}
)

async def close(self) -> None:
pass
47 changes: 47 additions & 0 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,53 @@ async def test_websocket_receive(send, receive, message_stack):
assert text == msg["text"]


@pytest.mark.asyncio
async def test_websocket_accept_with_no_subprotocols(
send, receive, message_stack
):
ws = WebSocketConnection(send, receive)
await ws.accept()

assert len(message_stack) == 1

message = message_stack.popleft()
assert message["type"] == "websocket.accept"
assert message["subprotocol"] == ""
assert "bytes" not in message


@pytest.mark.asyncio
async def test_websocket_accept_with_subprotocol(send, receive, message_stack):
subprotocols = ["graphql-ws"]

ws = WebSocketConnection(send, receive, subprotocols)
await ws.accept()

assert len(message_stack) == 1

message = message_stack.popleft()
assert message["type"] == "websocket.accept"
assert message["subprotocol"] == "graphql-ws"
assert "bytes" not in message


@pytest.mark.asyncio
async def test_websocket_accept_with_multiple_subprotocols(
send, receive, message_stack
):
subprotocols = ["graphql-ws", "hello", "world"]

ws = WebSocketConnection(send, receive, subprotocols)
await ws.accept()

assert len(message_stack) == 1

message = message_stack.popleft()
assert message["type"] == "websocket.accept"
assert message["subprotocol"] == "graphql-ws,hello,world"
assert "bytes" not in message


def test_improper_websocket_connection(transport, send, receive):
with pytest.raises(InvalidUsage):
transport.get_websocket_connection()
Expand Down