Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pausable response streams #1179

Merged
merged 4 commits into from
Aug 19, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/request_stream/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def streaming(response):
if body is None:
break
body = body.decode('utf-8').replace('1', 'A')
response.write(body)
await response.write(body)
return stream(streaming)


Expand Down
19 changes: 11 additions & 8 deletions sanic/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def cookies(self):

class StreamingHTTPResponse(BaseHTTPResponse):
__slots__ = (
'transport', 'streaming_fn', 'status',
'protocol', 'streaming_fn', 'status',
'content_type', 'headers', '_cookies'
)

Expand All @@ -119,16 +119,17 @@ def __init__(self, streaming_fn, status=200, headers=None,
self.headers = headers or {}
self._cookies = None

def write(self, data):
async def write(self, data):
"""Writes a chunk of data to the streaming response.

:param data: bytes-ish data to be written.
"""
if type(data) != bytes:
data = self._encode_body(data)

self.transport.write(
self.protocol.push_data(
b"%x\r\n%b\r\n" % (len(data), data))
await self.protocol.drain()

async def stream(
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
Expand All @@ -138,10 +139,12 @@ async def stream(
headers = self.get_headers(
version, keep_alive=keep_alive,
keep_alive_timeout=keep_alive_timeout)
self.transport.write(headers)

self.protocol.push_data(headers)
await self.protocol.drain()
await self.streaming_fn(self)
self.transport.write(b'0\r\n\r\n')
self.protocol.push_data(b'0\r\n\r\n')
# no need to await drain here after this write, because it is the
# very last thing we write and nothing needs to wait for it.

def get_headers(
self, version="1.1", keep_alive=False, keep_alive_timeout=None):
Expand Down Expand Up @@ -358,13 +361,13 @@ async def _streaming_fn(response):
if len(content) < 1:
break
to_send -= len(content)
response.write(content)
await response.write(content)
else:
while True:
content = await _file.read(chunk_size)
if len(content) < 1:
break
response.write(content)
await response.write(content)
finally:
await _file.close()
return # Returning from this fn closes the stream
Expand Down
20 changes: 18 additions & 2 deletions sanic/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ class HttpProtocol(asyncio.Protocol):
# connection management
'_total_request_size', '_request_timeout_handler',
'_response_timeout_handler', '_keep_alive_timeout_handler',
'_last_request_time', '_last_response_time', '_is_stream_handler')
'_last_request_time', '_last_response_time', '_is_stream_handler',
'_not_paused')

def __init__(self, *, loop, request_handler, error_handler,
signal=Signal(), connections=set(), request_timeout=60,
Expand All @@ -100,6 +101,7 @@ def __init__(self, *, loop, request_handler, error_handler,
self.request_class = request_class or Request
self.is_request_stream = is_request_stream
self._is_stream_handler = False
self._not_paused = asyncio.Event(loop=loop)
self._total_request_size = 0
self._request_timeout_handler = None
self._response_timeout_handler = None
Expand All @@ -114,6 +116,7 @@ def __init__(self, *, loop, request_handler, error_handler,
if 'requests_count' not in self.state:
self.state['requests_count'] = 0
self._debug = debug
self._not_paused.set()

@property
def keep_alive(self):
Expand Down Expand Up @@ -142,6 +145,12 @@ def connection_lost(self, exc):
if self._keep_alive_timeout_handler:
self._keep_alive_timeout_handler.cancel()

def pause_writing(self):
self._not_paused.clear()

def resume_writing(self):
self._not_paused.set()

def request_timeout_callback(self):
# See the docstring in the RequestTimeout exception, to see
# exactly what this timeout is checking for.
Expand Down Expand Up @@ -369,6 +378,12 @@ def write_response(self, response):
self._last_response_time = current_time
self.cleanup()

async def drain(self):
await self._not_paused.wait()

def push_data(self, data):
self.transport.write(data)

async def stream_response(self, response):
"""
Streams a response to the client asynchronously. Attaches
Expand All @@ -378,9 +393,10 @@ async def stream_response(self, response):
if self._response_timeout_handler:
self._response_timeout_handler.cancel()
self._response_timeout_handler = None

try:
keep_alive = self.keep_alive
response.transport = self.transport
response.protocol = self
await response.stream(
self.request.version, keep_alive, self.keep_alive_timeout)
self.log_response(response)
Expand Down
16 changes: 8 additions & 8 deletions tests/test_request_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR looks good. Hoever, won't this require an update to the docs to use the await syntax for streaming?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. You're right.
I changed the response-streaming example, but I forgot to change the response-streaming docs.

return stream(streaming)

@app.put('/_put')
Expand All @@ -100,7 +100,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

@app.patch('/_patch')
Expand All @@ -117,7 +117,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

assert app.is_request_stream is True
Expand Down Expand Up @@ -177,7 +177,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

# 404
Expand Down Expand Up @@ -231,7 +231,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

@bp.put('/_put')
Expand All @@ -248,7 +248,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

@bp.patch('/_patch')
Expand All @@ -265,7 +265,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

app.blueprint(bp)
Expand Down Expand Up @@ -380,7 +380,7 @@ async def streaming(response):
body = await request.stream.get()
if body is None:
break
response.write(body.decode('utf-8'))
await response.write(body.decode('utf-8'))
return stream(streaming)

@app.get('/get')
Expand Down
33 changes: 27 additions & 6 deletions tests/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from sanic import Sanic
from sanic.response import HTTPResponse, stream, StreamingHTTPResponse, file, file_stream, json
from sanic.server import HttpProtocol
from sanic.testing import HOST, PORT
from unittest.mock import MagicMock

Expand All @@ -30,9 +31,10 @@ async def hello_route(request):


async def sample_streaming_fn(response):
response.write('foo,')
await response.write('foo,')
await asyncio.sleep(.001)
response.write('bar')
await response.write('bar')



def test_method_not_allowed():
Expand Down Expand Up @@ -168,20 +170,39 @@ def test_stream_response_includes_chunked_header():

def test_stream_response_writes_correct_content_to_transport(streaming_app):
response = StreamingHTTPResponse(sample_streaming_fn)
response.transport = MagicMock(asyncio.Transport)
response.protocol = MagicMock(HttpProtocol)
response.protocol.transport = MagicMock(asyncio.Transport)

async def mock_drain():
pass

def mock_push_data(data):
response.protocol.transport.write(data)

response.protocol.push_data = mock_push_data
response.protocol.drain = mock_drain

@streaming_app.listener('after_server_start')
async def run_stream(app, loop):
await response.stream()
assert response.transport.write.call_args_list[1][0][0] == (
# assert response.protocol.push_data.call_args_list[1][0][0] == (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these comments seem accidental -- did you mean to leave them here? They seem to be the exact same as the tests themselves

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBH, I don't remember.
They look accidental. I will remove them.

# b'4\r\nfoo,\r\n'
# )
assert response.protocol.transport.write.call_args_list[1][0][0] == (
b'4\r\nfoo,\r\n'
)

assert response.transport.write.call_args_list[2][0][0] == (
# assert response.protocol.push_data.call_args_list[2][0][0] == (
# b'3\r\nbar\r\n'
# )
assert response.protocol.transport.write.call_args_list[2][0][0] == (
b'3\r\nbar\r\n'
)

assert response.transport.write.call_args_list[3][0][0] == (
# assert response.protocol.push_data.call_args_list[3][0][0] == (
# b'0\r\n\r\n'
# )
assert response.protocol.transport.write.call_args_list[3][0][0] == (
b'0\r\n\r\n'
)

Expand Down