From 72b22df0d6a13fb08025ac627e779458b238cb4e Mon Sep 17 00:00:00 2001 From: David Lee Date: Wed, 1 Jul 2020 14:34:34 +0800 Subject: [PATCH 1/9] Added fix to include subprotocols from scope --- sanic/asgi.py | 2 +- sanic/websocket.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sanic/asgi.py b/sanic/asgi.py index 2ae6f36973..adefbb308a 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -98,7 +98,7 @@ def get_websocket_connection(self) -> WebSocketConnection: def create_websocket_connection( self, send: ASGISend, receive: ASGIReceive ) -> WebSocketConnection: - self._websocket_connection = WebSocketConnection(send, receive) + self._websocket_connection = WebSocketConnection(send, receive, self.scope.get('subprotocols', [])) return self._websocket_connection def add_task(self) -> None: diff --git a/sanic/websocket.py b/sanic/websocket.py index 4ae83c853e..43e3179177 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -5,7 +5,7 @@ Dict, MutableMapping, Optional, - Union, + Union, List, ) from httptools import HttpParserUpgrade # type: ignore @@ -137,9 +137,11 @@ def __init__( self, send: Callable[[ASIMessage], Awaitable[None]], receive: Callable[[], Awaitable[ASIMessage]], + subprotocols: List[str]=[], ) -> None: self._send = send self._receive = receive + self.subprotocols = subprotocols async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} @@ -164,7 +166,8 @@ async def recv(self, *args, **kwargs) -> Optional[str]: receive = recv async def accept(self) -> None: - await self._send({"type": "websocket.accept", "subprotocol": ""}) + await self._send({"type": "websocket.accept", + "subprotocol": ",".join([subprotocol for subprotocol in self.subprotocols])}) async def close(self) -> None: pass From 985c70f04644a9327f71b418f3c1f31d40ae0857 Mon Sep 17 00:00:00 2001 From: David Lee Date: Wed, 1 Jul 2020 14:34:47 +0800 Subject: [PATCH 2/9] Added unit test to validate fix --- tests/test_asgi.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 05b2e96d4c..a211b2ed7a 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -208,6 +208,21 @@ async def test_websocket_receive(send, receive, message_stack): assert text == msg["text"] +@pytest.mark.asyncio +async def test_websocket_connection_with_subprotocols_communication(send, receive, message_stack): + subprotocols = ['graphql-ws', 'test'] + + ws = WebSocketConnection(send, receive, subprotocols) + await ws.accept() + + assert len(message_stack) == 1 + + message = message_stack.popleft() + assert message["type"] == "websocket.accept" + assert message['subprotocol'] == "graphql-ws,test" + assert "bytes" not in message + + def test_improper_websocket_connection(transport, send, receive): with pytest.raises(InvalidUsage): transport.get_websocket_connection() From 7f531c2f741b8de2484960c4bf6692cd13d17ce4 Mon Sep 17 00:00:00 2001 From: David Lee Date: Wed, 1 Jul 2020 18:27:56 +0800 Subject: [PATCH 3/9] Changes by black --- tests/test_requests.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_requests.py b/tests/test_requests.py index 31883e3739..1ed943599a 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1955,10 +1955,7 @@ def handler(request): app.config.SERVER_NAME = "my-server" # This means default port assert app.url_for("handler", _external=True) == "http://my-server/foo" request, response = app.test_client.get("/foo") - assert ( - request.url_for("handler") - == f"http://my-server/foo" - ) + assert request.url_for("handler") == f"http://my-server/foo" app.config.SERVER_NAME = "https://my-server/path" request, response = app.test_client.get("/foo") From 493808961a11a440efe5ce52b172e7634b3bd8f3 Mon Sep 17 00:00:00 2001 From: David Lee Date: Wed, 1 Jul 2020 18:38:29 +0800 Subject: [PATCH 4/9] Made changes to WebsocketConnection protocol --- sanic/asgi.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sanic/asgi.py b/sanic/asgi.py index adefbb308a..5ec13cf426 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -98,7 +98,9 @@ def get_websocket_connection(self) -> WebSocketConnection: def create_websocket_connection( self, send: ASGISend, receive: ASGIReceive ) -> WebSocketConnection: - self._websocket_connection = WebSocketConnection(send, receive, self.scope.get('subprotocols', [])) + self._websocket_connection = WebSocketConnection( + send, receive, self.scope.get("subprotocols", []) + ) return self._websocket_connection def add_task(self) -> None: From c6879ac253226efafd5328e8fc773e6589760525 Mon Sep 17 00:00:00 2001 From: David Lee Date: Wed, 1 Jul 2020 18:38:44 +0800 Subject: [PATCH 5/9] Linter changes --- sanic/websocket.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/sanic/websocket.py b/sanic/websocket.py index 43e3179177..29608bd709 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -5,7 +5,8 @@ Dict, MutableMapping, Optional, - Union, List, + Union, + List, ) from httptools import HttpParserUpgrade # type: ignore @@ -137,11 +138,11 @@ def __init__( self, send: Callable[[ASIMessage], Awaitable[None]], receive: Callable[[], Awaitable[ASIMessage]], - subprotocols: List[str]=[], + subprotocols: Optional[List[str]] = None, ) -> None: self._send = send self._receive = receive - self.subprotocols = subprotocols + self.subprotocols = subprotocols or [] async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} @@ -166,8 +167,14 @@ async def recv(self, *args, **kwargs) -> Optional[str]: receive = recv async def accept(self) -> None: - await self._send({"type": "websocket.accept", - "subprotocol": ",".join([subprotocol for subprotocol in self.subprotocols])}) + await self._send( + { + "type": "websocket.accept", + "subprotocol": ",".join( + [subprotocol for subprotocol in self.subprotocols] + ), + } + ) async def close(self) -> None: pass From c50692d75e222e6e9cc876ffeeec246fd48d1648 Mon Sep 17 00:00:00 2001 From: David Lee Date: Wed, 1 Jul 2020 18:38:53 +0800 Subject: [PATCH 6/9] Added unit tests --- tests/test_asgi.py | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index a211b2ed7a..0c728493f9 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -209,8 +209,40 @@ async def test_websocket_receive(send, receive, message_stack): @pytest.mark.asyncio -async def test_websocket_connection_with_subprotocols_communication(send, receive, message_stack): - subprotocols = ['graphql-ws', 'test'] +async def test_websocket_accept_with_no_subprotocols( + send, receive, message_stack +): + ws = WebSocketConnection(send, receive) + await ws.accept() + + assert len(message_stack) == 1 + + message = message_stack.popleft() + assert message["type"] == "websocket.accept" + assert message["subprotocol"] == "" + assert "bytes" not in message + + +@pytest.mark.asyncio +async def test_websocket_accept_with_subprotocol(send, receive, message_stack): + subprotocols = ["graphql-ws"] + + ws = WebSocketConnection(send, receive, subprotocols) + await ws.accept() + + assert len(message_stack) == 1 + + message = message_stack.popleft() + assert message["type"] == "websocket.accept" + assert message["subprotocol"] == "graphql-ws" + assert "bytes" not in message + + +@pytest.mark.asyncio +async def test_websocket_accept_with_multiple_subprotocols( + send, receive, message_stack +): + subprotocols = ["graphql-ws", "hello", "world"] ws = WebSocketConnection(send, receive, subprotocols) await ws.accept() @@ -219,7 +251,7 @@ async def test_websocket_connection_with_subprotocols_communication(send, receiv message = message_stack.popleft() assert message["type"] == "websocket.accept" - assert message['subprotocol'] == "graphql-ws,test" + assert message["subprotocol"] == "graphql-ws,hello,world" assert "bytes" not in message From cb864e1fae28ee8bf4070eb26fddebebe359e757 Mon Sep 17 00:00:00 2001 From: David Lee Date: Wed, 15 Jul 2020 17:13:46 +0800 Subject: [PATCH 7/9] Fixing bugs in linting due to isort import checks --- sanic/compat.py | 3 ++- sanic/websocket.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sanic/compat.py b/sanic/compat.py index 28c91b97ee..11f824dd32 100644 --- a/sanic/compat.py +++ b/sanic/compat.py @@ -14,7 +14,8 @@ def get_all(self, key): use_trio = argv[0].endswith("hypercorn") and "trio" in argv if use_trio: - from trio import open_file as open_async, Path # type: ignore + from trio import Path + from trio import open_file as open_async # type: ignore def stat_async(path): return Path(path).stat() diff --git a/sanic/websocket.py b/sanic/websocket.py index 29608bd709..9443b70429 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -3,10 +3,10 @@ Awaitable, Callable, Dict, + List, MutableMapping, Optional, Union, - List, ) from httptools import HttpParserUpgrade # type: ignore From 6ae1cecb8769d40867f3dc878271c674602a911e Mon Sep 17 00:00:00 2001 From: David Lee Date: Wed, 15 Jul 2020 17:15:01 +0800 Subject: [PATCH 8/9] Reverting compat import changes --- sanic/compat.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sanic/compat.py b/sanic/compat.py index 11f824dd32..17589d1af6 100644 --- a/sanic/compat.py +++ b/sanic/compat.py @@ -14,8 +14,7 @@ def get_all(self, key): use_trio = argv[0].endswith("hypercorn") and "trio" in argv if use_trio: - from trio import Path - from trio import open_file as open_async # type: ignore + from trio import Path, open_file as open_async # type: ignore def stat_async(path): return Path(path).stat() From 5e81463d2b7ff82e3b5bb68eb9e5bca5b3c63a8b Mon Sep 17 00:00:00 2001 From: David Lee Date: Wed, 15 Jul 2020 17:16:03 +0800 Subject: [PATCH 9/9] Fixing linter errors in compat.py --- sanic/compat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sanic/compat.py b/sanic/compat.py index 17589d1af6..11f824dd32 100644 --- a/sanic/compat.py +++ b/sanic/compat.py @@ -14,7 +14,8 @@ def get_all(self, key): use_trio = argv[0].endswith("hypercorn") and "trio" in argv if use_trio: - from trio import Path, open_file as open_async # type: ignore + from trio import Path + from trio import open_file as open_async # type: ignore def stat_async(path): return Path(path).stat()