From 3f7c9ea3f533af529aae2b4ef6d78432d8776fc2 Mon Sep 17 00:00:00 2001 From: Andrew Scott Date: Thu, 27 Aug 2020 00:22:02 -0700 Subject: [PATCH] feat: fixes exception due to unread bytes in stream (#1897) * feat: fixes exception due to unread bytes in stream * feat: additonal unit tests to cover changes * fix: automated changes by `make fix-import` * fix: additonal changes by `make fix-import` Co-authored-by: Adam Hopkins --- sanic/response.py | 7 +++++-- sanic/router.py | 2 +- sanic/server.py | 13 +++++++------ sanic/testing.py | 4 +++- sanic/worker.py | 2 +- tests/test_keep_alive_timeout.py | 4 ++-- tests/test_request_data.py | 7 +++---- tests/test_request_stream.py | 18 ++++++++++++++++++ 8 files changed, 40 insertions(+), 17 deletions(-) diff --git a/sanic/response.py b/sanic/response.py index 2ebf0046a0..24033336c2 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -42,7 +42,7 @@ def get_headers( body=b"", ): """.. deprecated:: 20.3: - This function is not public API and will be removed.""" + This function is not public API and will be removed.""" # self.headers get priority over content_type if self.content_type and "Content-Type" not in self.headers: @@ -249,7 +249,10 @@ def raw( :param content_type: the content type (string) of the response. """ return HTTPResponse( - body=body, status=status, headers=headers, content_type=content_type, + body=body, + status=status, + headers=headers, + content_type=content_type, ) diff --git a/sanic/router.py b/sanic/router.py index ab6e3cefd7..a608f1a24d 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -452,7 +452,7 @@ def _get(self, url, method, host): return route_handler, [], kwargs, route.uri, route.name def is_stream_handler(self, request): - """ Handler for request is stream or not. + """Handler for request is stream or not. :param request: Request object :return: bool """ diff --git a/sanic/server.py b/sanic/server.py index c4e08e76e8..2e27d80b5a 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -418,12 +418,13 @@ async def body_append(self, body): async def stream_append(self): while self._body_chunks: body = self._body_chunks.popleft() - if self.request.stream.is_full(): - self.transport.pause_reading() - await self.request.stream.put(body) - self.transport.resume_reading() - else: - await self.request.stream.put(body) + if self.request: + if self.request.stream.is_full(): + self.transport.pause_reading() + await self.request.stream.put(body) + self.transport.resume_reading() + else: + await self.request.stream.put(body) def on_message_complete(self): # Entire request (headers and whole body) is received. diff --git a/sanic/testing.py b/sanic/testing.py index faabdfd1ca..020b3c11ae 100644 --- a/sanic/testing.py +++ b/sanic/testing.py @@ -103,7 +103,9 @@ async def error_handler(request, exception): if self.port: server_kwargs = dict( - host=host or self.host, port=self.port, **server_kwargs, + host=host or self.host, + port=self.port, + **server_kwargs, ) host, port = host or self.host, self.port else: diff --git a/sanic/worker.py b/sanic/worker.py index b217b992cb..765f26f7b3 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -174,7 +174,7 @@ async def _check_alive(self): @staticmethod def _create_ssl_context(cfg): - """ Creates SSLContext instance for usage in asyncio.create_server. + """Creates SSLContext instance for usage in asyncio.create_server. See ssl.SSLSocket.__init__ for more details. """ ctx = ssl.SSLContext(cfg.ssl_version) diff --git a/tests/test_keep_alive_timeout.py b/tests/test_keep_alive_timeout.py index 58385becaa..4cc24f53cf 100644 --- a/tests/test_keep_alive_timeout.py +++ b/tests/test_keep_alive_timeout.py @@ -244,8 +244,8 @@ async def handler3(request): def test_keep_alive_timeout_reuse(): """If the server keep-alive timeout and client keep-alive timeout are - both longer than the delay, the client _and_ server will successfully - reuse the existing connection.""" + both longer than the delay, the client _and_ server will successfully + reuse the existing connection.""" try: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) diff --git a/tests/test_request_data.py b/tests/test_request_data.py index d9a6351fdb..f5bfabda83 100644 --- a/tests/test_request_data.py +++ b/tests/test_request_data.py @@ -46,8 +46,8 @@ def modify(request, response): invalid = str(e) j = loads(response.body) - j['response_mw_valid'] = user - j['response_mw_invalid'] = invalid + j["response_mw_valid"] = user + j["response_mw_invalid"] = invalid return json(j) request, response = app.test_client.get("/") @@ -59,8 +59,7 @@ def modify(request, response): "has_missing": False, "invalid": "'types.SimpleNamespace' object has no attribute 'missing'", "response_mw_valid": "sanic", - "response_mw_invalid": - "'types.SimpleNamespace' object has no attribute 'missing'" + "response_mw_invalid": "'types.SimpleNamespace' object has no attribute 'missing'", } diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index 972b2e1a61..ff298868a8 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -1,4 +1,5 @@ import pytest +import asyncio from sanic.blueprints import Blueprint from sanic.exceptions import HeaderExpectationFailed @@ -6,6 +7,7 @@ from sanic.response import json, stream, text from sanic.views import CompositionView, HTTPMethodView from sanic.views import stream as stream_decorator +from sanic.server import HttpProtocol data = "abc" * 1_000_000 @@ -337,6 +339,22 @@ async def post(request, id): assert "Method GET not allowed for URL /post/random_id" in response.text +@pytest.mark.asyncio +async def test_request_stream_unread(app): + """ensure no error is raised when leaving unread bytes in byte-buffer""" + + err = None + protocol = HttpProtocol(loop=asyncio.get_event_loop(), app=app) + try: + protocol.request = None + protocol._body_chunks.append("this is a test") + await protocol.stream_append() + except AttributeError as e: + err = e + + assert err is None and not protocol._body_chunks + + def test_request_stream_blueprint(app): """for self.is_request_stream = True""" bp = Blueprint("test_blueprint_request_stream_blueprint")