diff --git a/examples/request_stream/server.py b/examples/request_stream/server.py index e53a224ce8..d3d35aef2c 100644 --- a/examples/request_stream/server.py +++ b/examples/request_stream/server.py @@ -30,7 +30,7 @@ async def streaming(response): if body is None: break body = body.decode('utf-8').replace('1', 'A') - response.write(body) + await response.write(body) return stream(streaming) diff --git a/sanic/response.py b/sanic/response.py index 9349ce81f9..19c7c88963 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -107,7 +107,7 @@ def cookies(self): class StreamingHTTPResponse(BaseHTTPResponse): __slots__ = ( - 'transport', 'streaming_fn', 'status', + 'protocol', 'streaming_fn', 'status', 'content_type', 'headers', '_cookies' ) @@ -119,7 +119,7 @@ def __init__(self, streaming_fn, status=200, headers=None, self.headers = headers or {} self._cookies = None - def write(self, data): + async def write(self, data): """Writes a chunk of data to the streaming response. :param data: bytes-ish data to be written. @@ -127,8 +127,9 @@ def write(self, data): if type(data) != bytes: data = self._encode_body(data) - self.transport.write( + self.protocol.push_data( b"%x\r\n%b\r\n" % (len(data), data)) + await self.protocol.drain() async def stream( self, version="1.1", keep_alive=False, keep_alive_timeout=None): @@ -138,10 +139,12 @@ async def stream( headers = self.get_headers( version, keep_alive=keep_alive, keep_alive_timeout=keep_alive_timeout) - self.transport.write(headers) - + self.protocol.push_data(headers) + await self.protocol.drain() await self.streaming_fn(self) - self.transport.write(b'0\r\n\r\n') + self.protocol.push_data(b'0\r\n\r\n') + # no need to await drain here after this write, because it is the + # very last thing we write and nothing needs to wait for it. def get_headers( self, version="1.1", keep_alive=False, keep_alive_timeout=None): @@ -358,13 +361,13 @@ async def _streaming_fn(response): if len(content) < 1: break to_send -= len(content) - response.write(content) + await response.write(content) else: while True: content = await _file.read(chunk_size) if len(content) < 1: break - response.write(content) + await response.write(content) finally: await _file.close() return # Returning from this fn closes the stream diff --git a/sanic/server.py b/sanic/server.py index 15ae4708e0..ab84178c0b 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -73,7 +73,9 @@ class HttpProtocol(asyncio.Protocol): # connection management '_total_request_size', '_request_timeout_handler', '_response_timeout_handler', '_keep_alive_timeout_handler', - '_last_request_time', '_last_response_time', '_is_stream_handler') + '_last_request_time', '_last_response_time', '_is_stream_handler', + '_not_paused') + def __init__(self, *, loop, request_handler, error_handler, signal=Signal(), connections=set(), request_timeout=60, @@ -100,6 +102,7 @@ def __init__(self, *, loop, request_handler, error_handler, self.request_class = request_class or Request self.is_request_stream = is_request_stream self._is_stream_handler = False + self._not_paused = asyncio.Event(loop=loop) self._total_request_size = 0 self._request_timeout_handler = None self._response_timeout_handler = None @@ -114,6 +117,7 @@ def __init__(self, *, loop, request_handler, error_handler, if 'requests_count' not in self.state: self.state['requests_count'] = 0 self._debug = debug + self._not_paused.set() @property def keep_alive(self): @@ -142,6 +146,12 @@ def connection_lost(self, exc): if self._keep_alive_timeout_handler: self._keep_alive_timeout_handler.cancel() + def pause_writing(self): + self._not_paused.clear() + + def resume_writing(self): + self._not_paused.set() + def request_timeout_callback(self): # See the docstring in the RequestTimeout exception, to see # exactly what this timeout is checking for. @@ -369,6 +379,12 @@ def write_response(self, response): self._last_response_time = current_time self.cleanup() + async def drain(self): + await self._not_paused.wait() + + def push_data(self, data): + self.transport.write(data) + async def stream_response(self, response): """ Streams a response to the client asynchronously. Attaches @@ -378,9 +394,11 @@ async def stream_response(self, response): if self._response_timeout_handler: self._response_timeout_handler.cancel() self._response_timeout_handler = None + + try: keep_alive = self.keep_alive - response.transport = self.transport + response.protocol = self await response.stream( self.request.version, keep_alive, self.keep_alive_timeout) self.log_response(response) diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index 4ca4e44e97..b14aa5191f 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -83,7 +83,7 @@ async def streaming(response): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) @app.put('/_put') @@ -100,7 +100,7 @@ async def streaming(response): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) @app.patch('/_patch') @@ -117,7 +117,7 @@ async def streaming(response): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) assert app.is_request_stream is True @@ -177,7 +177,7 @@ async def streaming(response): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) # 404 @@ -231,7 +231,7 @@ async def streaming(response): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) @bp.put('/_put') @@ -248,7 +248,7 @@ async def streaming(response): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) @bp.patch('/_patch') @@ -265,7 +265,7 @@ async def streaming(response): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) app.blueprint(bp) @@ -380,7 +380,7 @@ async def streaming(response): body = await request.stream.get() if body is None: break - response.write(body.decode('utf-8')) + await response.write(body.decode('utf-8')) return stream(streaming) @app.get('/get') diff --git a/tests/test_response.py b/tests/test_response.py index 5b3f76de0b..7cfc1ebb7a 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -10,6 +10,7 @@ from sanic import Sanic from sanic.response import HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json +from sanic.server import HttpProtocol from sanic.testing import HOST, PORT from unittest.mock import MagicMock @@ -30,9 +31,10 @@ async def hello_route(request): async def sample_streaming_fn(response): - response.write('foo,') + await response.write('foo,') await asyncio.sleep(.001) - response.write('bar') + await response.write('bar') + def test_method_not_allowed(): @@ -168,20 +170,39 @@ def test_stream_response_includes_chunked_header(): def test_stream_response_writes_correct_content_to_transport(streaming_app): response = StreamingHTTPResponse(sample_streaming_fn) - response.transport = MagicMock(asyncio.Transport) + response.protocol = MagicMock(HttpProtocol) + response.protocol.transport = MagicMock(asyncio.Transport) + + async def mock_drain(): + pass + + def mock_push_data(data): + response.protocol.transport.write(data) + + response.protocol.push_data = mock_push_data + response.protocol.drain = mock_drain @streaming_app.listener('after_server_start') async def run_stream(app, loop): await response.stream() - assert response.transport.write.call_args_list[1][0][0] == ( + # assert response.protocol.push_data.call_args_list[1][0][0] == ( + # b'4\r\nfoo,\r\n' + # ) + assert response.protocol.transport.write.call_args_list[1][0][0] == ( b'4\r\nfoo,\r\n' ) - assert response.transport.write.call_args_list[2][0][0] == ( + # assert response.protocol.push_data.call_args_list[2][0][0] == ( + # b'3\r\nbar\r\n' + # ) + assert response.protocol.transport.write.call_args_list[2][0][0] == ( b'3\r\nbar\r\n' ) - assert response.transport.write.call_args_list[3][0][0] == ( + # assert response.protocol.push_data.call_args_list[3][0][0] == ( + # b'0\r\n\r\n' + # ) + assert response.protocol.transport.write.call_args_list[3][0][0] == ( b'0\r\n\r\n' )