Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor sendfile #3383

Merged
merged 18 commits into from
Nov 8, 2018
1 change: 1 addition & 0 deletions CHANGES/3383.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix task cancellation when ``sendfile()`` syscall is used by static file handling.
66 changes: 39 additions & 27 deletions aiohttp/web_fileresponse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import mimetypes
import os
import pathlib
from functools import partial
from typing import (IO, TYPE_CHECKING, Any, Awaitable, Callable, List, # noqa
Optional, Union, cast)

Expand Down Expand Up @@ -35,9 +36,15 @@ class SendfileStreamWriter(StreamWriter):
def __init__(self,
protocol: BaseProtocol,
loop: asyncio.AbstractEventLoop,
fobj: IO[Any],
count: int,
asvetlov marked this conversation as resolved.
Show resolved Hide resolved
on_chunk_sent: _T_OnChunkSent=None) -> None:
super().__init__(protocol, loop, on_chunk_sent)
self._sendfile_buffer = [] # type: List[bytes]
self._fobj = fobj
self._count = count
self._offset = fobj.tell()
self._in_fd = fobj.fileno()

def _write(self, chunk: bytes) -> None:
# we overwrite StreamWriter._write, so nothing can be appended to
Expand All @@ -46,54 +53,57 @@ def _write(self, chunk: bytes) -> None:
self.output_size += len(chunk)
self._sendfile_buffer.append(chunk)

def _sendfile_cb(self, fut: 'asyncio.Future[None]',
out_fd: int, in_fd: int,
offset: int, count: int,
loop: asyncio.AbstractEventLoop,
registered: bool) -> None:
if registered:
loop.remove_writer(out_fd)
def _sendfile_cb(self, fut: 'asyncio.Future[None]', out_fd: int) -> None:
if fut.cancelled():
return
try:
if self._do_sendfile(out_fd):
set_result(fut, None)
except Exception as exc:
set_exception(fut, exc)

def _do_sendfile(self, out_fd: int) -> bool:
try:
n = os.sendfile(out_fd, in_fd, offset, count)
if n == 0: # EOF reached
n = count
n = os.sendfile(out_fd,
self._in_fd,
self._offset,
self._count)
if n == 0: # in_fd EOF reached
n = self._count
except (BlockingIOError, InterruptedError):
asvetlov marked this conversation as resolved.
Show resolved Hide resolved
n = 0
except Exception as exc:
set_exception(fut, exc)
return
self.output_size += n
self._offset += n
self._count -= n
assert self._count >= 0
return self._count == 0

if n < count:
loop.add_writer(out_fd, self._sendfile_cb, fut, out_fd, in_fd,
offset + n, count - n, loop, True)
else:
set_result(fut, None)
def _done_fut(self, out_fd: int, fut: 'asyncio.Future[None]') -> None:
self.loop.remove_writer(out_fd)

async def sendfile(self, fobj: IO[Any], count: int) -> None:
async def sendfile(self) -> None:
assert self.transport is not None
out_socket = self.transport.get_extra_info('socket').dup()
out_socket.setblocking(False)
out_fd = out_socket.fileno()
in_fd = fobj.fileno()
offset = fobj.tell()

loop = self.loop
data = b''.join(self._sendfile_buffer)
try:
await loop.sock_sendall(out_socket, data)
fut = loop.create_future()
self._sendfile_cb(fut, out_fd, in_fd, offset, count, loop, False)
await fut
if not self._do_sendfile(out_fd):
fut = loop.create_future()
fut.add_done_callback(partial(self._done_fut, out_fd))
loop.add_writer(out_fd, self._sendfile_cb, fut, out_fd)
await fut
except asyncio.CancelledError:
raise
except Exception:
asvetlov marked this conversation as resolved.
Show resolved Hide resolved
server_logger.debug('Socket error')
self.transport.close()
finally:
out_socket.close()

self.output_size += count
await super().write_eof()

async def write_eof(self, chunk: bytes=b'') -> None:
Expand Down Expand Up @@ -139,12 +149,14 @@ async def _sendfile_system(self, request: 'BaseRequest',
else:
writer = SendfileStreamWriter(
request.protocol,
request._loop
request._loop,
fobj,
count
)
request._payload_writer = writer

await super().prepare(request)
await writer.sendfile(fobj, count)
await writer.sendfile()

return writer

Expand Down
68 changes: 1 addition & 67 deletions tests/test_web_sendfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,73 +2,7 @@

from aiohttp import hdrs
from aiohttp.test_utils import make_mocked_coro, make_mocked_request
from aiohttp.web_fileresponse import FileResponse, SendfileStreamWriter


def test_static_handle_eof(loop) -> None:
fake_loop = mock.Mock()
with mock.patch('aiohttp.web_fileresponse.os') as m_os:
out_fd = 30
in_fd = 31
fut = loop.create_future()
m_os.sendfile.return_value = 0
writer = SendfileStreamWriter(mock.Mock(), mock.Mock(), fake_loop)
writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False)
m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100)
assert fut.done()
assert fut.result() is None
assert not fake_loop.add_writer.called
assert not fake_loop.remove_writer.called


def test_static_handle_again(loop) -> None:
fake_loop = mock.Mock()
with mock.patch('aiohttp.web_fileresponse.os') as m_os:
out_fd = 30
in_fd = 31
fut = loop.create_future()
m_os.sendfile.side_effect = BlockingIOError()
writer = SendfileStreamWriter(mock.Mock(), mock.Mock(), fake_loop)
writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False)
m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100)
assert not fut.done()
fake_loop.add_writer.assert_called_with(out_fd,
writer._sendfile_cb,
fut, out_fd, in_fd, 0, 100,
fake_loop, True)
assert not fake_loop.remove_writer.called


def test_static_handle_exception(loop) -> None:
fake_loop = mock.Mock()
with mock.patch('aiohttp.web_fileresponse.os') as m_os:
out_fd = 30
in_fd = 31
fut = loop.create_future()
exc = OSError()
m_os.sendfile.side_effect = exc
writer = SendfileStreamWriter(mock.Mock(), mock.Mock(), fake_loop)
writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False)
m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100)
assert fut.done()
assert exc is fut.exception()
assert not fake_loop.add_writer.called
assert not fake_loop.remove_writer.called


def test__sendfile_cb_return_on_cancelling(loop) -> None:
fake_loop = mock.Mock()
with mock.patch('aiohttp.web_fileresponse.os') as m_os:
out_fd = 30
in_fd = 31
fut = loop.create_future()
fut.cancel()
writer = SendfileStreamWriter(mock.Mock(), mock.Mock(), fake_loop)
writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False)
assert fut.done()
assert not fake_loop.add_writer.called
assert not fake_loop.remove_writer.called
assert not m_os.sendfile.called
from aiohttp.web_fileresponse import FileResponse


def test_using_gzip_if_header_present_and_file_available(loop) -> None:
Expand Down
68 changes: 67 additions & 1 deletion tests/test_web_sendfile_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import os
import pathlib
import socket
import zlib

import pytest
Expand Down Expand Up @@ -324,7 +325,7 @@ def test_static_route_path_existence_check() -> None:
async def test_static_file_huge(aiohttp_client, tmpdir) -> None:
filename = 'huge_data.unknown_mime_type'

# fill 100MB file
# fill 20MB file
with tmpdir.join(filename).open('w') as f:
for i in range(1024*20):
f.write(chr(i % 64 + 0x20) * 1024)
Expand Down Expand Up @@ -751,3 +752,68 @@ async def handler(request):
assert 'application/octet-stream' == resp.headers['Content-Type']
assert resp.headers.get('Content-Encoding') == 'deflate'
await resp.release()


async def test_static_file_huge_cancel(aiohttp_client, tmpdir) -> None:
filename = 'huge_data.unknown_mime_type'

# fill 100MB file
asvetlov marked this conversation as resolved.
Show resolved Hide resolved
with tmpdir.join(filename).open('w') as f:
for i in range(1024*20):
f.write(chr(i % 64 + 0x20) * 1024)

task = None

async def handler(request):
nonlocal task
task = request.task
# reduce send buffer size
tr = request.transport
sock = tr.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
ret = web.FileResponse(pathlib.Path(str(tmpdir.join(filename))))
return ret

app = web.Application()

app.router.add_get('/', handler)
client = await aiohttp_client(app)

resp = await client.get('/')
assert resp.status == 200
task.cancel()
await asyncio.sleep(0)
data = b''
while True:
try:
data += await resp.content.read(1024)
except aiohttp.ClientPayloadError:
break
assert len(data) < 1024 * 1024 * 20


async def test_static_file_huge_error(aiohttp_client, tmpdir) -> None:
filename = 'huge_data.unknown_mime_type'

# fill 20MB file
with tmpdir.join(filename).open('wb') as f:
f.seek(20*1024*1024)
f.write(b'1')

async def handler(request):
# reduce send buffer size
tr = request.transport
sock = tr.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
ret = web.FileResponse(pathlib.Path(str(tmpdir.join(filename))))
return ret

app = web.Application()

app.router.add_get('/', handler)
client = await aiohttp_client(app)

resp = await client.get('/')
assert resp.status == 200
# raise an exception on server side
resp.close()