Skip to content

Commit

Permalink
This commit adds handlers for the asyncio/uvloop protocol callbacks f…
Browse files Browse the repository at this point in the history
…or pause_writing and resume_writing.

These are needed for the correct functioning of built-in tcp flow-control provided by uvloop and asyncio.
This is somewhat of a breaking change, because the `write` function in user streaming callbacks now must be `await`ed.
This is necessary because it is possible now that the http protocol may be paused, and any calls to write may need to wait on an async event to be called to become unpaused.

Updated examples and tests to reflect this change.

This change does not apply to websocket connections. A change to websocket connections may be required to match this change.
  • Loading branch information
ashleysommer committed Mar 29, 2018
1 parent 8a07463 commit 75dc05f
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 25 deletions.
2 changes: 1 addition & 1 deletion examples/request_stream/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def streaming(response):
if body is None:
break
body = body.decode('utf-8').replace('1', 'A')
response.write(body)
await response.write(body)
return stream(streaming)


Expand Down
19 changes: 11 additions & 8 deletions sanic/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def cookies(self):

class StreamingHTTPResponse(BaseHTTPResponse):
__slots__ = (
'transport', 'streaming_fn', 'status',
'protocol', 'streaming_fn', 'status',
'content_type', 'headers', '_cookies'
)

Expand All @@ -119,16 +119,17 @@ def __init__(self, streaming_fn, status=200, headers=None,
self.headers = headers or {}
self._cookies = None

def write(self, data):
async def write(self, data):
"""Writes a chunk of data to the streaming response.
:param data: bytes-ish data to be written.
"""
if type(data) != bytes:
data = self._encode_body(data)

self.transport.write(
self.protocol.push_data(
b"%x\r\n%b\r\n" % (len(data), data))
await self.protocol.drain()

async def stream(
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
Expand All @@ -138,10 +139,12 @@ async def stream(
headers = self.get_headers(
version, keep_alive=keep_alive,
keep_alive_timeout=keep_alive_timeout)
self.transport.write(headers)

self.protocol.push_data(headers)
await self.protocol.drain()
await self.streaming_fn(self)
self.transport.write(b'0\r\n\r\n')
self.protocol.push_data(b'0\r\n\r\n')
# no need to await drain here after this write, because it is the
# very last thing we write and nothing needs to wait for it.

def get_headers(
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
Expand Down Expand Up @@ -358,13 +361,13 @@ async def _streaming_fn(response):
if len(content) < 1:
break
to_send -= len(content)
response.write(content)
await response.write(content)
else:
while True:
content = await _file.read(chunk_size)
if len(content) < 1:
break
response.write(content)
await response.write(content)
finally:
await _file.close()
return # Returning from this fn closes the stream
Expand Down
22 changes: 20 additions & 2 deletions sanic/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ class HttpProtocol(asyncio.Protocol):
# connection management
'_total_request_size', '_request_timeout_handler',
'_response_timeout_handler', '_keep_alive_timeout_handler',
'_last_request_time', '_last_response_time', '_is_stream_handler')
'_last_request_time', '_last_response_time', '_is_stream_handler',
'_not_paused')


def __init__(self, *, loop, request_handler, error_handler,
signal=Signal(), connections=set(), request_timeout=60,
Expand All @@ -100,6 +102,7 @@ def __init__(self, *, loop, request_handler, error_handler,
self.request_class = request_class or Request
self.is_request_stream = is_request_stream
self._is_stream_handler = False
self._not_paused = asyncio.Event(loop=loop)
self._total_request_size = 0
self._request_timeout_handler = None
self._response_timeout_handler = None
Expand All @@ -114,6 +117,7 @@ def __init__(self, *, loop, request_handler, error_handler,
if 'requests_count' not in self.state:
self.state['requests_count'] = 0
self._debug = debug
self._not_paused.set()

@property
def keep_alive(self):
Expand Down Expand Up @@ -142,6 +146,12 @@ def connection_lost(self, exc):
if self._keep_alive_timeout_handler:
self._keep_alive_timeout_handler.cancel()

def pause_writing(self):
self._not_paused.clear()

def resume_writing(self):
self._not_paused.set()

def request_timeout_callback(self):
# See the docstring in the RequestTimeout exception, to see
# exactly what this timeout is checking for.
Expand Down Expand Up @@ -369,6 +379,12 @@ def write_response(self, response):
self._last_response_time = current_time
self.cleanup()

async def drain(self):
await self._not_paused.wait()

def push_data(self, data):
self.transport.write(data)

async def stream_response(self, response):
"""
Streams a response to the client asynchronously. Attaches
Expand All @@ -378,9 +394,11 @@ async def stream_response(self, response):
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
self._response_timeout_handler = None


try:
keep_alive = self.keep_alive
response.transport = self.transport
response.protocol = self
await response.stream(
self.request.version, keep_alive, self.keep_alive_timeout)
self.log_response(response)
Expand Down
16 changes: 8 additions & 8 deletions tests/test_request_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

@app.put('/_put')
Expand All @@ -100,7 +100,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

@app.patch('/_patch')
Expand All @@ -117,7 +117,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

assert app.is_request_stream is True
Expand Down Expand Up @@ -177,7 +177,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

# 404
Expand Down Expand Up @@ -231,7 +231,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

@bp.put('/_put')
Expand All @@ -248,7 +248,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

@bp.patch('/_patch')
Expand All @@ -265,7 +265,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

app.blueprint(bp)
Expand Down Expand Up @@ -380,7 +380,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

@app.get('/get')
Expand Down
33 changes: 27 additions & 6 deletions tests/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from sanic import Sanic
from sanic.response import HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json
from sanic.server import HttpProtocol
from sanic.testing import HOST, PORT
from unittest.mock import MagicMock

Expand All @@ -30,9 +31,10 @@ async def hello_route(request):


async def sample_streaming_fn(response):
response.write('foo,')
await response.write('foo,')
await asyncio.sleep(.001)
response.write('bar')
await response.write('bar')



def test_method_not_allowed():
Expand Down Expand Up @@ -168,20 +170,39 @@ def test_stream_response_includes_chunked_header():

def test_stream_response_writes_correct_content_to_transport(streaming_app):
response = StreamingHTTPResponse(sample_streaming_fn)
response.transport = MagicMock(asyncio.Transport)
response.protocol = MagicMock(HttpProtocol)
response.protocol.transport = MagicMock(asyncio.Transport)

async def mock_drain():
pass

def mock_push_data(data):
response.protocol.transport.write(data)

response.protocol.push_data = mock_push_data
response.protocol.drain = mock_drain

@streaming_app.listener('after_server_start')
async def run_stream(app, loop):
await response.stream()
assert response.transport.write.call_args_list[1][0][0] == (
# assert response.protocol.push_data.call_args_list[1][0][0] == (
# b'4\r\nfoo,\r\n'
# )
assert response.protocol.transport.write.call_args_list[1][0][0] == (
b'4\r\nfoo,\r\n'
)

assert response.transport.write.call_args_list[2][0][0] == (
# assert response.protocol.push_data.call_args_list[2][0][0] == (
# b'3\r\nbar\r\n'
# )
assert response.protocol.transport.write.call_args_list[2][0][0] == (
b'3\r\nbar\r\n'
)

assert response.transport.write.call_args_list[3][0][0] == (
# assert response.protocol.push_data.call_args_list[3][0][0] == (
# b'0\r\n\r\n'
# )
assert response.protocol.transport.write.call_args_list[3][0][0] == (
b'0\r\n\r\n'
)

Expand Down

0 comments on commit 75dc05f

Please sign in to comment.