From 75cef16cda679bfb731dd3da4b7611c482bfa0c4 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Thu, 4 Jan 2018 08:06:12 +0100 Subject: [PATCH 1/9] Get rid of legacy StreamWriter (#2623) Legacy StreamWriter as a pure proxy of the transport and the protocol is no longer needed. All of the functionalities that were behind this class has been moved to the PayloadWriter. Some changes that have to be considered that impacted during this change * TCP Operations have been isolated in a module rather than move them into the PayloadWriter * WebSocketWriter had a dependency with the StreamWriter, to get rid of that dependency the constructor has been modified to take the protocol and the transport. A next step changing the name PayLoadWriter for the StreamWriter to have consistency with the reader part, might be considered. --- aiohttp/client.py | 9 +- aiohttp/client_proto.py | 6 +- aiohttp/client_reqrep.py | 2 +- aiohttp/http.py | 3 +- aiohttp/http_websocket.py | 16 +- aiohttp/http_writer.py | 104 ++---------- aiohttp/tcp_helpers.py | 63 +++++++ aiohttp/web_fileresponse.py | 9 +- aiohttp/web_protocol.py | 29 +--- aiohttp/web_ws.py | 3 +- tests/test_client_request.py | 30 ++-- tests/test_client_ws_functional.py | 2 +- tests/test_http_stream_writer.py | 257 ----------------------------- tests/test_http_writer.py | 68 ++++---- tests/test_tcp_helpers.py | 134 +++++++++++++++ tests/test_web_protocol.py | 9 +- tests/test_web_sendfile.py | 8 +- tests/test_websocket_writer.py | 81 +++++---- 18 files changed, 351 insertions(+), 482 deletions(-) create mode 100644 aiohttp/tcp_helpers.py delete mode 100644 tests/test_http_stream_writer.py create mode 100644 tests/test_tcp_helpers.py diff --git a/aiohttp/client.py b/aiohttp/client.py index 369eb195454..73952344eca 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -30,6 +30,7 @@ from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse from .streams import FlowControlDataQueue from .tracing import Trace +from .tcp_helpers import tcp_nodelay, tcp_cork __all__ = (client_exceptions.__all__ + # noqa @@ -296,7 +297,8 @@ async def _request(self, method, url, *, 'Connection timeout ' 'to host {0}'.format(url)) from exc - conn.writer.set_tcp_nodelay(True) + tcp_nodelay(conn.transport, True) + tcp_cork(conn.transport, False) try: resp = req.send(conn) try: @@ -575,12 +577,13 @@ async def _ws_connect(self, url, *, notakeover = False proto = resp.connection.protocol + transport = resp.connection.transport reader = FlowControlDataQueue( proto, limit=2 ** 16, loop=self._loop) proto.set_parser(WebSocketReader(reader), reader) - resp.connection.writer.set_tcp_nodelay(True) + tcp_nodelay(transport, True) writer = WebSocketWriter( - resp.connection.writer, use_mask=True, + proto, transport, use_mask=True, compress=compress, notakeover=notakeover) except Exception: resp.close() diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index 5c51224fd9b..cf0d2f83306 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -4,7 +4,7 @@ from .client_exceptions import (ClientOSError, ClientPayloadError, ServerDisconnectedError) -from .http import HttpResponseParser, StreamWriter +from .http import HttpResponseParser from .streams import EMPTY_PAYLOAD, DataQueue @@ -17,7 +17,6 @@ def __init__(self, *, loop=None): self.paused = False self.transport = None - self.writer = None self._should_close = False self._message = None @@ -60,7 +59,6 @@ def is_connected(self): def connection_made(self, transport): self.transport = transport - self.writer = StreamWriter(self, transport, self._loop) def connection_lost(self, exc): if self._payload_parser is not None: @@ -82,7 +80,7 @@ def connection_lost(self, exc): exc = ServerDisconnectedError(uncompleted) DataQueue.set_exception(self, exc) - self.transport = self.writer = None + self.transport = None self._should_close = True self._parser = None self._message = None diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index de32f6622e1..2384ca3c0ca 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -469,7 +469,7 @@ def send(self, conn): if self.url.raw_query_string: path += '?' + self.url.raw_query_string - writer = PayloadWriter(conn.writer, self.loop) + writer = PayloadWriter(conn.protocol, conn.transport, self.loop) if self.compress: writer.enable_compression(self.compress) diff --git a/aiohttp/http.py b/aiohttp/http.py index 4dee43b631c..c372426754d 100644 --- a/aiohttp/http.py +++ b/aiohttp/http.py @@ -12,7 +12,7 @@ WSCloseCode, WSMessage, WSMsgType, ws_ext_gen, ws_ext_parse) from .http_writer import (HttpVersion, HttpVersion10, HttpVersion11, - PayloadWriter, StreamWriter) + PayloadWriter) __all__ = ( @@ -20,7 +20,6 @@ # .http_writer 'PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11', - 'StreamWriter', # .http_parser 'HttpParser', 'HttpRequestParser', 'HttpResponseParser', diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index 5bb51d69f9d..a5ca686f64e 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -513,11 +513,11 @@ def parse_frame(self, buf): class WebSocketWriter: - def __init__(self, stream, *, + def __init__(self, protocol, transport, *, use_mask=False, limit=DEFAULT_LIMIT, random=random.Random(), compress=0, notakeover=False): - self.stream = stream - self.writer = stream.transport + self.protocol = protocol + self.transport = transport self.use_mask = use_mask self.randrange = random.randrange self.compress = compress @@ -572,20 +572,20 @@ def _send_frame(self, message, opcode, compress=None): mask = mask.to_bytes(4, 'big') message = bytearray(message) _websocket_mask(mask, message) - self.writer.write(header + mask + message) + self.transport.write(header + mask + message) self._output_size += len(header) + len(mask) + len(message) else: if len(message) > MSG_SIZE: - self.writer.write(header) - self.writer.write(message) + self.transport.write(header) + self.transport.write(message) else: - self.writer.write(header + message) + self.transport.write(header + message) self._output_size += len(header) + len(message) if self._output_size > self._limit: self._output_size = 0 - return self.stream.drain() + return self.protocol._drain_helper() return noop() diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index 0f35a2d2868..9137c987d96 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -1,104 +1,24 @@ """Http related parsers and protocol.""" import collections -import socket import zlib -from contextlib import suppress from .abc import AbstractPayloadWriter from .helpers import noop -__all__ = ('PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11', - 'StreamWriter') +__all__ = ('PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11') HttpVersion = collections.namedtuple('HttpVersion', ['major', 'minor']) HttpVersion10 = HttpVersion(1, 0) HttpVersion11 = HttpVersion(1, 1) -if hasattr(socket, 'TCP_CORK'): # pragma: no cover - CORK = socket.TCP_CORK -elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover - CORK = socket.TCP_NOPUSH -else: # pragma: no cover - CORK = None - - -class StreamWriter: +class PayloadWriter(AbstractPayloadWriter): def __init__(self, protocol, transport, loop): self._protocol = protocol - self._loop = loop - self._tcp_nodelay = False - self._tcp_cork = False - self._socket = transport.get_extra_info('socket') - self._waiters = [] - self.transport = transport - - @property - def tcp_nodelay(self): - return self._tcp_nodelay - - def set_tcp_nodelay(self, value): - value = bool(value) - if self._tcp_nodelay == value: - return - if self._socket is None: - return - if self._socket.family not in (socket.AF_INET, socket.AF_INET6): - return - - # socket may be closed already, on windows OSError get raised - with suppress(OSError): - if self._tcp_cork: - if CORK is not None: # pragma: no branch - self._socket.setsockopt(socket.IPPROTO_TCP, CORK, False) - self._tcp_cork = False - - self._socket.setsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY, value) - self._tcp_nodelay = value - - @property - def tcp_cork(self): - return self._tcp_cork - - def set_tcp_cork(self, value): - value = bool(value) - if self._tcp_cork == value: - return - if self._socket is None: - return - if self._socket.family not in (socket.AF_INET, socket.AF_INET6): - return - - with suppress(OSError): - if self._tcp_nodelay: - self._socket.setsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY, False) - self._tcp_nodelay = False - if CORK is not None: # pragma: no branch - self._socket.setsockopt(socket.IPPROTO_TCP, CORK, value) - self._tcp_cork = value - - async def drain(self): - """Flush the write buffer. - - The intended use is to write - - await w.write(data) - await w.drain() - """ - if self._protocol.transport is not None: - await self._protocol._drain_helper() - - -class PayloadWriter(AbstractPayloadWriter): - - def __init__(self, stream, loop): - self._stream = stream - self._transport = None + self._transport = transport self.loop = loop self.length = None @@ -109,11 +29,15 @@ def __init__(self, stream, loop): self._eof = False self._compress = None self._drain_waiter = None - self._transport = self._stream.transport - async def get_transport(self): + @property + def transport(self): return self._transport + @property + def protocol(self): + return self._protocol + def enable_chunking(self): self.chunked = True @@ -200,4 +124,12 @@ async def write_eof(self, chunk=b''): self._transport = None async def drain(self): - await self._stream.drain() + """Flush the write buffer. + + The intended use is to write + + await w.write(data) + await w.drain() + """ + if self._protocol.transport is not None: + await self._protocol._drain_helper() diff --git a/aiohttp/tcp_helpers.py b/aiohttp/tcp_helpers.py new file mode 100644 index 00000000000..c1cd5031177 --- /dev/null +++ b/aiohttp/tcp_helpers.py @@ -0,0 +1,63 @@ +"""Helper methods to tune a TCP connection""" + +import socket +from contextlib import suppress + +__all__ = ('tcp_keepalive', 'tcp_nodelay', 'tcp_cork') + + +if hasattr(socket, 'TCP_CORK'): # pragma: no cover + CORK = socket.TCP_CORK +elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover + CORK = socket.TCP_NOPUSH +else: # pragma: no cover + CORK = None + + +if hasattr(socket, 'SO_KEEPALIVE'): + def tcp_keepalive(transport): + sock = transport.get_extra_info('socket') + if sock is not None: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) +else: + def tcp_keepalive(transport): # pragma: no cover + pass + + +def tcp_nodelay(transport, value): + sock = transport.get_extra_info('socket') + + if sock is None: + return None + + if sock.family not in (socket.AF_INET, socket.AF_INET6): + return None + + value = bool(value) + + # socket may be closed already, on windows OSError get raised + with suppress(OSError): + sock.setsockopt( + socket.IPPROTO_TCP, socket.TCP_NODELAY, value) + return value + + return None + + +def tcp_cork(transport, value): + sock = transport.get_extra_info('socket') + + if CORK is None: + return + + if sock is None: + return + + if sock.family not in (socket.AF_INET, socket.AF_INET6): + return + + value = bool(value) + + with suppress(OSError): + sock.setsockopt( + socket.IPPROTO_TCP, socket.TCP_NODELAY, False) diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index 9eebae3fbb7..57f4b21e3ca 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -54,9 +54,7 @@ def _sendfile_cb(self, fut, out_fd, in_fd, set_result(fut, None) async def sendfile(self, fobj, count): - transport = await self.get_transport() - - out_socket = transport.get_extra_info('socket').dup() + out_socket = self.transport.get_extra_info('socket').dup() out_socket.setblocking(False) out_fd = out_socket.fileno() in_fd = fobj.fileno() @@ -71,7 +69,7 @@ async def sendfile(self, fobj, count): await fut except Exception: server_logger.debug('Socket error') - transport.close() + self.transport.close() finally: out_socket.close() @@ -112,7 +110,8 @@ async def _sendfile_system(self, request, fobj, count): writer = await self._sendfile_fallback(request, fobj, count) else: writer = SendfilePayloadWriter( - request._protocol.writer, + request.protocol, + transport, request.loop ) request._payload_writer = writer diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 76c944150cb..c24254a9449 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -1,7 +1,6 @@ import asyncio import asyncio.streams import http.server -import socket import traceback import warnings from collections import deque @@ -10,13 +9,13 @@ from . import helpers, http from .helpers import CeilTimeout -from .http import (HttpProcessingError, HttpRequestParser, PayloadWriter, - StreamWriter) +from .http import HttpProcessingError, HttpRequestParser, PayloadWriter from .log import access_logger, server_logger from .streams import EMPTY_PAYLOAD from .web_exceptions import HTTPException from .web_request import BaseRequest from .web_response import Response +from .tcp_helpers import tcp_keepalive, tcp_cork, tcp_nodelay __all__ = ('RequestHandler', 'RequestPayloadError') @@ -25,15 +24,6 @@ 'UNKNOWN', '/', http.HttpVersion10, {}, {}, True, False, False, False, http.URL('/')) -if hasattr(socket, 'SO_KEEPALIVE'): - def tcp_keepalive(server, transport): - sock = transport.get_extra_info('socket') - if sock is not None: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) -else: - def tcp_keepalive(server, transport): # pragma: no cover - pass - class RequestPayloadError(Exception): """Payload parsing error.""" @@ -181,13 +171,12 @@ def connection_made(self, transport): super().connection_made(transport) self.transport = transport - self.writer = StreamWriter(self, transport, self._loop) if self._tcp_keepalive: - tcp_keepalive(self, transport) + tcp_keepalive(transport) - self.writer.set_tcp_cork(False) - self.writer.set_tcp_nodelay(True) + tcp_cork(transport, False) + tcp_nodelay(transport, True) self._manager.connection_made(self, transport) def connection_lost(self, exc): @@ -200,7 +189,7 @@ def connection_lost(self, exc): self._request_factory = None self._request_handler = None self._request_parser = None - self.transport = self.writer = None + self.transport = None if self._keepalive_handle is not None: self._keepalive_handle.cancel() @@ -241,14 +230,14 @@ def data_received(self, data): # something happened during parsing self._error_handler = self._loop.create_task( self.handle_parse_error( - PayloadWriter(self.writer, self._loop), + PayloadWriter(self, self.transport, self._loop), 400, exc, exc.message)) self.close() except Exception as exc: # 500: internal error self._error_handler = self._loop.create_task( self.handle_parse_error( - PayloadWriter(self.writer, self._loop), + PayloadWriter(self, self.transport, self._loop), 500, exc)) self.close() else: @@ -371,7 +360,7 @@ async def start(self): now = loop.time() manager.requests_count += 1 - writer = PayloadWriter(self.writer, loop) + writer = PayloadWriter(self, self.transport, loop) request = self._request_factory( message, payload, self, writer, handler) try: diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 63d74bfdda5..ceaf09b4c56 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -187,7 +187,8 @@ def _pre_start(self, request): self.headers.update(headers) self.force_close() self._compress = compress - writer = WebSocketWriter(request._protocol.writer, + writer = WebSocketWriter(request._protocol, + request._protocol.transport, compress=compress, notakeover=notakeover) diff --git a/tests/test_client_request.py b/tests/test_client_request.py index f0549db6c85..d2fe2a48711 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -38,6 +38,14 @@ def buf(): return bytearray() +@pytest.fixture +def protocol(loop): + protocol = mock.Mock() + protocol._drain_helper.return_value = loop.create_future() + protocol._drain_helper.return_value.set_result(None) + return protocol + + @pytest.yield_fixture def transport(buf): transport = mock.Mock() @@ -55,22 +63,11 @@ async def write_eof(): @pytest.fixture -def conn(stream): - return mock.Mock(writer=stream) - - -@pytest.fixture -def stream(buf, transport, loop): - stream = mock.Mock() - stream.transport = transport - - def acquire(writer): - writer.set_transport(transport) - - stream.acquire.side_effect = acquire - stream.drain.return_value = loop.create_future() - stream.drain.return_value.set_result(None) - return stream +def conn(transport, protocol): + return mock.Mock( + transport=transport, + protocol=protocol + ) def test_method1(make_request): @@ -844,7 +841,6 @@ def gen(writer): assert asyncio.isfuture(req._writer) await resp.wait_for_close() assert req._writer is None - assert buf.split(b'\r\n\r\n', 1)[1] == \ b'b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n' await req.close() diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index ca395b9efe7..214a7c085a3 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -440,7 +440,7 @@ async def handler(request): await ws.prepare(request) await ws.receive_str() - ws._writer.writer.write(b'01234' * 100) + ws._writer.transport.write(b'01234' * 100) await ws.close() return ws diff --git a/tests/test_http_stream_writer.py b/tests/test_http_stream_writer.py deleted file mode 100644 index b4fdb2288a5..00000000000 --- a/tests/test_http_stream_writer.py +++ /dev/null @@ -1,257 +0,0 @@ -import socket -from unittest import mock - -import pytest - -from aiohttp.http_writer import CORK, StreamWriter - - -has_ipv6 = socket.has_ipv6 -if has_ipv6: - # The socket.has_ipv6 flag may be True if Python was built with IPv6 - # support, but the target system still may not have it. - # So let's ensure that we really have IPv6 support. - try: - socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - except OSError: - has_ipv6 = False - - -# nodelay - -def test_nodelay_and_cork_default(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - assert not writer.tcp_nodelay - assert not writer.tcp_cork - assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -def test_set_nodelay_no_change(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(False) - assert not writer.tcp_nodelay - assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -def test_set_nodelay_exception(loop): - transport = mock.Mock() - s = mock.Mock() - s.setsockopt = mock.Mock() - s.family = socket.AF_INET - s.setsockopt.side_effect = OSError - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert not writer.tcp_nodelay - - -def test_set_nodelay_enable(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert writer.tcp_nodelay - assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -def test_set_nodelay_enable_and_disable(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - writer.set_tcp_nodelay(False) - assert not writer.tcp_nodelay - assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_nodelay_and_cork(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - writer.set_tcp_nodelay(True) - assert writer.tcp_nodelay - assert not writer.tcp_cork - assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available") -def test_set_nodelay_enable_ipv6(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert writer.tcp_nodelay - assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'), - reason="requires unix sockets") -def test_set_nodelay_enable_unix(loop): - # do not set nodelay for unix socket - transport = mock.Mock() - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert not writer.tcp_nodelay - - -def test_set_nodelay_enable_no_socket(loop): - transport = mock.Mock() - transport.get_extra_info.return_value = None - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert not writer.tcp_nodelay - assert writer._socket is None - - -# cork - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_cork_default(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - assert not writer.tcp_cork - assert not s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_no_change(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(False) - assert not writer.tcp_cork - assert not s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert writer.tcp_cork - assert s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable_and_disable(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - writer.set_tcp_cork(False) - assert not writer.tcp_cork - assert not s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available") -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable_ipv6(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert writer.tcp_cork - assert s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'), - reason="requires unix sockets") -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable_unix(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert not writer.tcp_cork - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable_no_socket(loop): - transport = mock.Mock() - transport.get_extra_info.return_value = None - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert not writer.tcp_cork - assert writer._socket is None - - -def test_set_cork_exception(loop): - transport = mock.Mock() - s = mock.Mock() - s.setsockopt = mock.Mock() - s.family = socket.AF_INET - s.setsockopt.side_effect = OSError - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert not writer.tcp_cork - - -# cork and nodelay interference - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_enabling_cork_disables_nodelay(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - writer.set_tcp_cork(True) - assert not writer.tcp_nodelay - assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - assert writer.tcp_cork - assert s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_enabling_nodelay_disables_cork(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - writer.set_tcp_nodelay(True) - assert writer.tcp_nodelay - assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - assert not writer.tcp_cork - assert not s.getsockopt(socket.IPPROTO_TCP, CORK) diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 2f1339dcdc4..704c31a9f71 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -24,21 +24,16 @@ def write(chunk): @pytest.fixture -def stream(transport, loop): - stream = mock.Mock(transport=transport) +def protocol(loop, transport): + protocol = mock.Mock(transport=transport) + protocol._drain_helper.return_value = loop.create_future() + protocol._drain_helper.return_value.set_result(None) + return protocol - def acquire(writer): - writer.set_transport(transport) - stream.acquire = acquire - stream.drain.return_value = loop.create_future() - stream.drain.return_value.set_result(None) - return stream - - -async def test_write_payload_eof(stream, loop): - write = stream.transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_eof(transport, protocol, loop): + write = transport.write = mock.Mock() + msg = http.PayloadWriter(protocol, transport, loop) msg.write(b'data1') msg.write(b'data2') @@ -48,8 +43,8 @@ async def test_write_payload_eof(stream, loop): assert b'data1data2' == content.split(b'\r\n\r\n', 1)[-1] -async def test_write_payload_chunked(buf, stream, loop): - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_chunked(buf, protocol, transport, loop): + msg = http.PayloadWriter(protocol, transport, loop) msg.enable_chunking() msg.write(b'data') await msg.write_eof() @@ -57,8 +52,8 @@ async def test_write_payload_chunked(buf, stream, loop): assert b'4\r\ndata\r\n0\r\n\r\n' == buf -async def test_write_payload_chunked_multiple(buf, stream, loop): - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_chunked_multiple(buf, protocol, transport, loop): + msg = http.PayloadWriter(protocol, transport, loop) msg.enable_chunking() msg.write(b'data1') msg.write(b'data2') @@ -67,10 +62,10 @@ async def test_write_payload_chunked_multiple(buf, stream, loop): assert b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n' == buf -async def test_write_payload_length(stream, loop): - write = stream.transport.write = mock.Mock() +async def test_write_payload_length(protocol, transport, loop): + write = transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) + msg = http.PayloadWriter(protocol, transport, loop) msg.length = 2 msg.write(b'd') msg.write(b'ata') @@ -80,10 +75,10 @@ async def test_write_payload_length(stream, loop): assert b'da' == content.split(b'\r\n\r\n', 1)[-1] -async def test_write_payload_chunked_filter(stream, loop): - write = stream.transport.write = mock.Mock() +async def test_write_payload_chunked_filter(protocol, transport, loop): + write = transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) + msg = http.PayloadWriter(protocol, transport, loop) msg.enable_chunking() msg.write(b'da') msg.write(b'ta') @@ -93,9 +88,12 @@ async def test_write_payload_chunked_filter(stream, loop): assert content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n') -async def test_write_payload_chunked_filter_mutiple_chunks(stream, loop): - write = stream.transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_chunked_filter_mutiple_chunks( + protocol, + transport, + loop): + write = transport.write = mock.Mock() + msg = http.PayloadWriter(protocol, transport, loop) msg.enable_chunking() msg.write(b'da') msg.write(b'ta') @@ -113,9 +111,9 @@ async def test_write_payload_chunked_filter_mutiple_chunks(stream, loop): COMPRESSED = b''.join([compressor.compress(b'data'), compressor.flush()]) -async def test_write_payload_deflate_compression(stream, loop): - write = stream.transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_deflate_compression(protocol, transport, loop): + write = transport.write = mock.Mock() + msg = http.PayloadWriter(protocol, transport, loop) msg.enable_compression('deflate') msg.write(b'data') await msg.write_eof() @@ -126,8 +124,12 @@ async def test_write_payload_deflate_compression(stream, loop): assert COMPRESSED == content.split(b'\r\n\r\n', 1)[-1] -async def test_write_payload_deflate_and_chunked(buf, stream, loop): - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_deflate_and_chunked( + buf, + protocol, + transport, + loop): + msg = http.PayloadWriter(protocol, transport, loop) msg.enable_compression('deflate') msg.enable_chunking() @@ -138,8 +140,8 @@ async def test_write_payload_deflate_and_chunked(buf, stream, loop): assert b'6\r\nKI,I\x04\x00\r\n0\r\n\r\n' == buf -def test_write_drain(stream, loop): - msg = http.PayloadWriter(stream, loop) +def test_write_drain(protocol, transport, loop): + msg = http.PayloadWriter(protocol, transport, loop) msg.drain = mock.Mock() msg.write(b'1' * (64 * 1024 * 2), drain=False) assert not msg.drain.called diff --git a/tests/test_tcp_helpers.py b/tests/test_tcp_helpers.py new file mode 100644 index 00000000000..6abe567e837 --- /dev/null +++ b/tests/test_tcp_helpers.py @@ -0,0 +1,134 @@ +import socket +from unittest import mock + +import pytest + +from aiohttp.tcp_helpers import tcp_nodelay, tcp_cork, CORK + + +has_ipv6 = socket.has_ipv6 +if has_ipv6: + # The socket.has_ipv6 flag may be True if Python was built with IPv6 + # support, but the target system still may not have it. + # So let's ensure that we really have IPv6 support. + try: + socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + except OSError: + has_ipv6 = False + + +# nodelay + +def test_tcp_nodelay_exception(loop): + transport = mock.Mock() + s = mock.Mock() + s.setsockopt = mock.Mock() + s.family = socket.AF_INET + s.setsockopt.side_effect = OSError + transport.get_extra_info.return_value = s + assert tcp_nodelay(transport, True) is None + s.setsockopt.assert_called_with( + socket.IPPROTO_TCP, + socket.TCP_NODELAY, + True + ) + + +def test_tcp_nodelay_enable(loop): + transport = mock.Mock() + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + assert tcp_nodelay(transport, True) is True + assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + + +def test_tcp_nodelay_enable_and_disable(loop): + transport = mock.Mock() + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + assert tcp_nodelay(transport, True) is True + assert tcp_nodelay(transport, False) is False + assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + + +@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available") +def test_tcp_nodelay_enable_ipv6(loop): + transport = mock.Mock() + s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + assert tcp_nodelay(transport, True) is True + assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + + +@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'), + reason="requires unix sockets") +def test_tcp_nodelay_enable_unix(loop): + # do not set nodelay for unix socket + transport = mock.Mock() + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + assert tcp_nodelay(transport, True) is None + + +def test_tcp_nodelay_enable_no_socket(loop): + transport = mock.Mock() + transport.get_extra_info.return_value = None + assert tcp_nodelay(transport, True) is None + + +# cork + + +@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") +def test_tcp_cork_enable(loop): + transport = mock.Mock() + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + assert tcp_cork(transport, True) is True + assert s.getsockopt(socket.IPPROTO_TCP, CORK) + + +@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") +def test_set_cork_enable_and_disable(loop): + transport = mock.Mock() + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + assert tcp_cork(transport, True) is True + assert tcp_cork(transport, False) is False + assert not s.getsockopt(socket.IPPROTO_TCP, CORK) + + +@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available") +@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") +def test_set_cork_enable_ipv6(loop): + transport = mock.Mock() + s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + assert tcp_cork(transport, True) is True + assert s.getsockopt(socket.IPPROTO_TCP, CORK) + + +@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'), + reason="requires unix sockets") +@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") +def test_set_cork_enable_unix(loop): + transport = mock.Mock() + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + assert tcp_cork(transport, True) is True + + +@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") +def test_set_cork_enable_no_socket(loop): + transport = mock.Mock() + transport.get_extra_info.return_value = None + assert tcp_cork(transport, True) is None + + +def test_set_cork_exception(loop): + transport = mock.Mock() + s = mock.Mock() + s.setsockopt = mock.Mock() + s.family = socket.AF_INET + s.setsockopt.side_effect = OSError + assert tcp_cork(transport, True) is None diff --git a/tests/test_web_protocol.py b/tests/test_web_protocol.py index 73a4835ee7b..bafefb95d08 100644 --- a/tests/test_web_protocol.py +++ b/tests/test_web_protocol.py @@ -38,6 +38,8 @@ def srv(make_srv, transport): srv = make_srv() srv.connection_made(transport) transport.close.side_effect = partial(srv.connection_lost, None) + srv._drain_helper = mock.Mock() + srv._drain_helper.side_effect = helpers.noop return srv @@ -72,7 +74,7 @@ async def handle(request): @pytest.yield_fixture def writer(srv): - return http.PayloadWriter(srv.writer, srv._loop) + return http.PayloadWriter(srv, srv.transport, srv._loop) @pytest.yield_fixture @@ -83,7 +85,6 @@ def write(chunk): buf.extend(chunk) transport.write.side_effect = write - transport.drain.side_effect = helpers.noop return transport @@ -226,7 +227,7 @@ async def test_bad_method(srv, loop, buf): async def test_data_received_error(srv, loop, buf): - srv.transport = mock.Mock() + transport = srv.transport srv._request_parser = mock.Mock() srv._request_parser.feed_data.side_effect = TypeError @@ -236,7 +237,7 @@ async def test_data_received_error(srv, loop, buf): await asyncio.sleep(0, loop=loop) assert buf.startswith(b'HTTP/1.0 500 Internal Server Error\r\n') - assert srv.transport.close.called + assert transport.close.called assert srv._error_handler is None diff --git a/tests/test_web_sendfile.py b/tests/test_web_sendfile.py index 2bec965893b..7f7520ddb4f 100644 --- a/tests/test_web_sendfile.py +++ b/tests/test_web_sendfile.py @@ -12,7 +12,7 @@ def test_static_handle_eof(loop): in_fd = 31 fut = loop.create_future() m_os.sendfile.return_value = 0 - writer = SendfilePayloadWriter(fake_loop, mock.Mock()) + writer = SendfilePayloadWriter(mock.Mock(), mock.Mock(), fake_loop) writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100) assert fut.done() @@ -28,7 +28,7 @@ def test_static_handle_again(loop): in_fd = 31 fut = loop.create_future() m_os.sendfile.side_effect = BlockingIOError() - writer = SendfilePayloadWriter(fake_loop, mock.Mock()) + writer = SendfilePayloadWriter(mock.Mock(), mock.Mock(), fake_loop) writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100) assert not fut.done() @@ -47,7 +47,7 @@ def test_static_handle_exception(loop): fut = loop.create_future() exc = OSError() m_os.sendfile.side_effect = exc - writer = SendfilePayloadWriter(fake_loop, mock.Mock()) + writer = SendfilePayloadWriter(mock.Mock(), mock.Mock(), fake_loop) writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100) assert fut.done() @@ -63,7 +63,7 @@ def test__sendfile_cb_return_on_cancelling(loop): in_fd = 31 fut = loop.create_future() fut.cancel() - writer = SendfilePayloadWriter(fake_loop, mock.Mock()) + writer = SendfilePayloadWriter(mock.Mock(), mock.Mock(), fake_loop) writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) assert fut.done() assert not fake_loop.add_writer.called diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index a31808549f6..af30b1e3910 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -7,88 +7,97 @@ @pytest.fixture -def stream(): +def protocol(): return mock.Mock() @pytest.fixture -def writer(stream): - return WebSocketWriter(stream, use_mask=False) +def transport(): + return mock.Mock() + + +@pytest.fixture +def writer(protocol, transport): + return WebSocketWriter(protocol, transport, use_mask=False) -def test_pong(stream, writer): +def test_pong(writer): writer.pong() - stream.transport.write.assert_called_with(b'\x8a\x00') + writer.transport.write.assert_called_with(b'\x8a\x00') -def test_ping(stream, writer): +def test_ping(writer): writer.ping() - stream.transport.write.assert_called_with(b'\x89\x00') + writer.transport.write.assert_called_with(b'\x89\x00') -def test_send_text(stream, writer): +def test_send_text(writer): writer.send(b'text') - stream.transport.write.assert_called_with(b'\x81\x04text') + writer.transport.write.assert_called_with(b'\x81\x04text') -def test_send_binary(stream, writer): +def test_send_binary(writer): writer.send('binary', True) - stream.transport.write.assert_called_with(b'\x82\x06binary') + writer.transport.write.assert_called_with(b'\x82\x06binary') -def test_send_binary_long(stream, writer): +def test_send_binary_long(writer): writer.send(b'b' * 127, True) - assert stream.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb') + assert writer.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb') -def test_send_binary_very_long(stream, writer): +def test_send_binary_very_long(writer): writer.send(b'b' * 65537, True) - assert (stream.transport.write.call_args_list[0][0][0] == + assert (writer.transport.write.call_args_list[0][0][0] == b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01') - assert stream.transport.write.call_args_list[1][0][0] == b'b' * 65537 + assert writer.transport.write.call_args_list[1][0][0] == b'b' * 65537 -def test_close(stream, writer): +def test_close(writer): writer.close(1001, 'msg') - stream.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + writer.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') writer.close(1001, b'msg') - stream.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + writer.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') # Test that Service Restart close code is also supported writer.close(1012, b'msg') - stream.transport.write.assert_called_with(b'\x88\x05\x03\xf4msg') + writer.transport.write.assert_called_with(b'\x88\x05\x03\xf4msg') -def test_send_text_masked(stream): - writer = WebSocketWriter(stream, +def test_send_text_masked(protocol, transport): + writer = WebSocketWriter(protocol, + transport, use_mask=True, random=random.Random(123)) writer.send(b'text') - stream.transport.write.assert_called_with(b'\x81\x84\rg\xb3fy\x02\xcb\x12') + writer.transport.write.assert_called_with(b'\x81\x84\rg\xb3fy\x02\xcb\x12') -def test_send_compress_text(stream): - writer = WebSocketWriter(stream, compress=15) +def test_send_compress_text(protocol, transport): + writer = WebSocketWriter(protocol, transport, compress=15) writer.send(b'text') - stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') + writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') writer.send(b'text') - stream.transport.write.assert_called_with(b'\xc1\x05*\x01b\x00\x00') + writer.transport.write.assert_called_with(b'\xc1\x05*\x01b\x00\x00') -def test_send_compress_text_notakeover(stream): - writer = WebSocketWriter(stream, compress=15, notakeover=True) +def test_send_compress_text_notakeover(protocol, transport): + writer = WebSocketWriter(protocol, + transport, + compress=15, + notakeover=True) writer.send(b'text') - stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') + writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') writer.send(b'text') - stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') + writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') -def test_send_compress_text_per_message(stream): - writer = WebSocketWriter(stream) +def test_send_compress_text_per_message(protocol, transport): + writer = WebSocketWriter(protocol, transport) writer.send(b'text', compress=15) - stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') + writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') writer.send(b'text') - stream.transport.write.assert_called_with(b'\x81\x04text') + writer.transport.write.assert_called_with(b'\x81\x04text') writer.send(b'text', compress=15) - stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') + writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') From 6316553291f5dba2acdf397a7628219fff964655 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 10 Jan 2018 01:18:08 +0100 Subject: [PATCH 2/9] Add CHANGES --- CHANGES/2651.removal | 1 + 1 file changed, 1 insertion(+) create mode 100644 CHANGES/2651.removal diff --git a/CHANGES/2651.removal b/CHANGES/2651.removal new file mode 100644 index 00000000000..0b5f76fd8b6 --- /dev/null +++ b/CHANGES/2651.removal @@ -0,0 +1 @@ +Get rid of the legacy class StreamWriter. From 58d14a228fb43eda4a3fbdf5b7f7587197e54732 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 10 Jan 2018 07:36:31 +0100 Subject: [PATCH 3/9] Fixed invalid import order --- aiohttp/client.py | 2 +- aiohttp/tcp_helpers.py | 1 + aiohttp/web_protocol.py | 2 +- tests/test_tcp_helpers.py | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/aiohttp/client.py b/aiohttp/client.py index 73952344eca..f259935c536 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -29,8 +29,8 @@ from .http import WS_KEY, WebSocketReader, WebSocketWriter from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse from .streams import FlowControlDataQueue +from .tcp_helpers import tcp_cork, tcp_nodelay from .tracing import Trace -from .tcp_helpers import tcp_nodelay, tcp_cork __all__ = (client_exceptions.__all__ + # noqa diff --git a/aiohttp/tcp_helpers.py b/aiohttp/tcp_helpers.py index c1cd5031177..8a8d8e0e62f 100644 --- a/aiohttp/tcp_helpers.py +++ b/aiohttp/tcp_helpers.py @@ -3,6 +3,7 @@ import socket from contextlib import suppress + __all__ = ('tcp_keepalive', 'tcp_nodelay', 'tcp_cork') diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index c24254a9449..4b636cd0e34 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -12,10 +12,10 @@ from .http import HttpProcessingError, HttpRequestParser, PayloadWriter from .log import access_logger, server_logger from .streams import EMPTY_PAYLOAD +from .tcp_helpers import tcp_cork, tcp_keepalive, tcp_nodelay from .web_exceptions import HTTPException from .web_request import BaseRequest from .web_response import Response -from .tcp_helpers import tcp_keepalive, tcp_cork, tcp_nodelay __all__ = ('RequestHandler', 'RequestPayloadError') diff --git a/tests/test_tcp_helpers.py b/tests/test_tcp_helpers.py index 6abe567e837..a6969a1cdfc 100644 --- a/tests/test_tcp_helpers.py +++ b/tests/test_tcp_helpers.py @@ -3,7 +3,7 @@ import pytest -from aiohttp.tcp_helpers import tcp_nodelay, tcp_cork, CORK +from aiohttp.tcp_helpers import CORK, tcp_cork, tcp_nodelay has_ipv6 = socket.has_ipv6 From 95eba9c3de7614e83ae9b6a6273591167930f476 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 10 Jan 2018 09:04:57 +0100 Subject: [PATCH 4/9] Fix test broken --- tests/test_http_writer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index dd475d6c3ec..f498e34f7cc 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -153,11 +153,11 @@ def test_write_drain(protocol, transport, loop): assert msg.buffer_size == 0 -def test_write_to_closing_transport(stream, loop): - msg = http.PayloadWriter(stream, loop) +def test_write_to_closing_transport(protocol, transport, loop): + msg = http.PayloadWriter(protocol, transport, loop) msg.write(b'Before closing') - stream.transport.is_closing.return_value = True + transport.is_closing.return_value = True with pytest.raises(asyncio.CancelledError): msg.write(b'After closing') From 87fddb1fe13c50c84797849293e8ea7cf7a04eb1 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 10 Jan 2018 10:48:39 +0100 Subject: [PATCH 5/9] Fix tcp_cork issues --- aiohttp/tcp_helpers.py | 11 +++++++---- tests/test_tcp_helpers.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/aiohttp/tcp_helpers.py b/aiohttp/tcp_helpers.py index 8a8d8e0e62f..81faddd40b6 100644 --- a/aiohttp/tcp_helpers.py +++ b/aiohttp/tcp_helpers.py @@ -49,16 +49,19 @@ def tcp_cork(transport, value): sock = transport.get_extra_info('socket') if CORK is None: - return + return None if sock is None: - return + return None if sock.family not in (socket.AF_INET, socket.AF_INET6): - return + return None value = bool(value) with suppress(OSError): sock.setsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY, False) + socket.IPPROTO_TCP, CORK, value) + return value + + return None diff --git a/tests/test_tcp_helpers.py b/tests/test_tcp_helpers.py index a6969a1cdfc..a96260b418b 100644 --- a/tests/test_tcp_helpers.py +++ b/tests/test_tcp_helpers.py @@ -115,7 +115,7 @@ def test_set_cork_enable_unix(loop): transport = mock.Mock() s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) transport.get_extra_info.return_value = s - assert tcp_cork(transport, True) is True + assert tcp_cork(transport, True) is None @pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") From 4d977fa8d50458b3a41d7bd016d1435191aeea25 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 10 Jan 2018 16:02:43 +0100 Subject: [PATCH 6/9] Test PayloadWriter properties --- tests/test_http_writer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index f498e34f7cc..f208a042183 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -33,6 +33,12 @@ def protocol(loop, transport): return protocol +def test_payloadwriter_properties(transport, protocol, loop): + writer = http.PayloadWriter(protocol, transport, loop) + assert writer.protocol == protocol + assert writer.transport == transport + + async def test_write_payload_eof(transport, protocol, loop): write = transport.write = mock.Mock() msg = http.PayloadWriter(protocol, transport, loop) From 87928bea98956da46663efcd394cb5e33e245b8d Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 10 Jan 2018 22:19:26 +0100 Subject: [PATCH 7/9] Avoid return useless values for tcp_ --- aiohttp/tcp_helpers.py | 16 +++++---------- tests/test_tcp_helpers.py | 43 ++++++++++++++++++++++++--------------- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/aiohttp/tcp_helpers.py b/aiohttp/tcp_helpers.py index 81faddd40b6..3a016901c9d 100644 --- a/aiohttp/tcp_helpers.py +++ b/aiohttp/tcp_helpers.py @@ -29,10 +29,10 @@ def tcp_nodelay(transport, value): sock = transport.get_extra_info('socket') if sock is None: - return None + return if sock.family not in (socket.AF_INET, socket.AF_INET6): - return None + return value = bool(value) @@ -40,28 +40,22 @@ def tcp_nodelay(transport, value): with suppress(OSError): sock.setsockopt( socket.IPPROTO_TCP, socket.TCP_NODELAY, value) - return value - - return None def tcp_cork(transport, value): sock = transport.get_extra_info('socket') if CORK is None: - return None + return if sock is None: - return None + return if sock.family not in (socket.AF_INET, socket.AF_INET6): - return None + return value = bool(value) with suppress(OSError): sock.setsockopt( socket.IPPROTO_TCP, CORK, value) - return value - - return None diff --git a/tests/test_tcp_helpers.py b/tests/test_tcp_helpers.py index a96260b418b..ebe8271d820 100644 --- a/tests/test_tcp_helpers.py +++ b/tests/test_tcp_helpers.py @@ -26,7 +26,7 @@ def test_tcp_nodelay_exception(loop): s.family = socket.AF_INET s.setsockopt.side_effect = OSError transport.get_extra_info.return_value = s - assert tcp_nodelay(transport, True) is None + tcp_nodelay(transport, True) s.setsockopt.assert_called_with( socket.IPPROTO_TCP, socket.TCP_NODELAY, @@ -38,7 +38,7 @@ def test_tcp_nodelay_enable(loop): transport = mock.Mock() s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) transport.get_extra_info.return_value = s - assert tcp_nodelay(transport, True) is True + tcp_nodelay(transport, True) assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) @@ -46,8 +46,9 @@ def test_tcp_nodelay_enable_and_disable(loop): transport = mock.Mock() s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) transport.get_extra_info.return_value = s - assert tcp_nodelay(transport, True) is True - assert tcp_nodelay(transport, False) is False + tcp_nodelay(transport, True) + assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + tcp_nodelay(transport, False) assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) @@ -56,7 +57,7 @@ def test_tcp_nodelay_enable_ipv6(loop): transport = mock.Mock() s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) transport.get_extra_info.return_value = s - assert tcp_nodelay(transport, True) is True + tcp_nodelay(transport, True) assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) @@ -65,15 +66,16 @@ def test_tcp_nodelay_enable_ipv6(loop): def test_tcp_nodelay_enable_unix(loop): # do not set nodelay for unix socket transport = mock.Mock() - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + s = mock.Mock(family=socket.AF_UNIX, type=socket.SOCK_STREAM) transport.get_extra_info.return_value = s - assert tcp_nodelay(transport, True) is None + tcp_nodelay(transport, True) + assert not s.setsockopt.called def test_tcp_nodelay_enable_no_socket(loop): transport = mock.Mock() transport.get_extra_info.return_value = None - assert tcp_nodelay(transport, True) is None + tcp_nodelay(transport, True) # cork @@ -84,7 +86,7 @@ def test_tcp_cork_enable(loop): transport = mock.Mock() s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) transport.get_extra_info.return_value = s - assert tcp_cork(transport, True) is True + tcp_cork(transport, True) assert s.getsockopt(socket.IPPROTO_TCP, CORK) @@ -93,8 +95,9 @@ def test_set_cork_enable_and_disable(loop): transport = mock.Mock() s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) transport.get_extra_info.return_value = s - assert tcp_cork(transport, True) is True - assert tcp_cork(transport, False) is False + tcp_cork(transport, True) + assert s.getsockopt(socket.IPPROTO_TCP, CORK) + tcp_cork(transport, False) assert not s.getsockopt(socket.IPPROTO_TCP, CORK) @@ -104,7 +107,7 @@ def test_set_cork_enable_ipv6(loop): transport = mock.Mock() s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) transport.get_extra_info.return_value = s - assert tcp_cork(transport, True) is True + tcp_cork(transport, True) assert s.getsockopt(socket.IPPROTO_TCP, CORK) @@ -113,22 +116,30 @@ def test_set_cork_enable_ipv6(loop): @pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") def test_set_cork_enable_unix(loop): transport = mock.Mock() - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + s = mock.Mock(family=socket.AF_UNIX, type=socket.SOCK_STREAM) transport.get_extra_info.return_value = s - assert tcp_cork(transport, True) is None + tcp_cork(transport, True) + assert not s.setsockopt.called @pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") def test_set_cork_enable_no_socket(loop): transport = mock.Mock() transport.get_extra_info.return_value = None - assert tcp_cork(transport, True) is None + tcp_cork(transport, True) +@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") def test_set_cork_exception(loop): transport = mock.Mock() s = mock.Mock() s.setsockopt = mock.Mock() s.family = socket.AF_INET s.setsockopt.side_effect = OSError - assert tcp_cork(transport, True) is None + transport.get_extra_info.return_value = s + tcp_cork(transport, True) + s.setsockopt.assert_called_with( + socket.IPPROTO_TCP, + CORK, + True + ) From c9673ad80245dbaf4b537166d1375d2aca3d19ac Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Thu, 11 Jan 2018 00:36:01 +0100 Subject: [PATCH 8/9] Increase coverage http_writer --- tests/test_http_writer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index f208a042183..317794e7ab8 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -167,3 +167,16 @@ def test_write_to_closing_transport(protocol, transport, loop): with pytest.raises(asyncio.CancelledError): msg.write(b'After closing') + + +async def test_drain(protocol, transport, loop): + msg = http.PayloadWriter(protocol, transport, loop) + await msg.drain() + assert protocol._drain_helper.called + + +async def test_drain_no_transport(protocol, transport, loop): + msg = http.PayloadWriter(protocol, transport, loop) + msg._protocol.transport = None + await msg.drain() + assert not protocol._drain_helper.called From 884a4cfac468302d2666d20271706bc801eb671f Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Thu, 11 Jan 2018 07:56:16 +0100 Subject: [PATCH 9/9] Increase coverage web_protocol --- tests/test_web_protocol.py | 45 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/test_web_protocol.py b/tests/test_web_protocol.py index 325b6b7d19a..e68ab144242 100644 --- a/tests/test_web_protocol.py +++ b/tests/test_web_protocol.py @@ -132,6 +132,16 @@ async def test_double_shutdown(srv, transport): assert srv.transport is None +async def test_shutdown_wait_error_handler(loop, srv, transport): + + async def _error_handle(): + pass + + srv._error_handler = loop.create_task(_error_handle()) + await srv.shutdown() + assert srv._error_handler.done() + + async def test_close_after_response(srv, loop, transport): srv.data_received( b'GET / HTTP/1.0\r\n' @@ -738,3 +748,38 @@ def test_data_received_force_close(srv): b'Content-Length: 0\r\n\r\n') assert not srv._messages + + +async def test__process_keepalive(loop, srv): + # wait till the waiter is waiting + await asyncio.sleep(0) + + srv._keepalive_time = 1 + srv._keepalive_timeout = 1 + expired_time = srv._keepalive_time + srv._keepalive_timeout + 1 + with mock.patch.object(loop, "time", return_value=expired_time): + srv._process_keepalive() + assert srv._force_close + + +async def test__process_keepalive_schedule_next(loop, srv): + # wait till the waiter is waiting + await asyncio.sleep(0) + + srv._keepalive_time = 1 + srv._keepalive_timeout = 1 + expire_time = srv._keepalive_time + srv._keepalive_timeout + with mock.patch.object(loop, "time", return_value=expire_time): + with mock.patch.object(loop, "call_at") as call_at_patched: + srv._process_keepalive() + call_at_patched.assert_called_with( + expire_time, + srv._process_keepalive + ) + + +def test__process_keepalive_force_close(loop, srv): + srv._force_close = True + with mock.patch.object(loop, "call_at") as call_at_patched: + srv._process_keepalive() + assert not call_at_patched.called