diff --git a/CHANGES/2651.removal b/CHANGES/2651.removal new file mode 100644 index 00000000000..0b5f76fd8b6 --- /dev/null +++ b/CHANGES/2651.removal @@ -0,0 +1 @@ +Get rid of the legacy class StreamWriter. diff --git a/aiohttp/client.py b/aiohttp/client.py index 369eb195454..f259935c536 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -29,6 +29,7 @@ from .http import WS_KEY, WebSocketReader, WebSocketWriter from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse from .streams import FlowControlDataQueue +from .tcp_helpers import tcp_cork, tcp_nodelay from .tracing import Trace @@ -296,7 +297,8 @@ async def _request(self, method, url, *, 'Connection timeout ' 'to host {0}'.format(url)) from exc - conn.writer.set_tcp_nodelay(True) + tcp_nodelay(conn.transport, True) + tcp_cork(conn.transport, False) try: resp = req.send(conn) try: @@ -575,12 +577,13 @@ async def _ws_connect(self, url, *, notakeover = False proto = resp.connection.protocol + transport = resp.connection.transport reader = FlowControlDataQueue( proto, limit=2 ** 16, loop=self._loop) proto.set_parser(WebSocketReader(reader), reader) - resp.connection.writer.set_tcp_nodelay(True) + tcp_nodelay(transport, True) writer = WebSocketWriter( - resp.connection.writer, use_mask=True, + proto, transport, use_mask=True, compress=compress, notakeover=notakeover) except Exception: resp.close() diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index 5c51224fd9b..cf0d2f83306 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -4,7 +4,7 @@ from .client_exceptions import (ClientOSError, ClientPayloadError, ServerDisconnectedError) -from .http import HttpResponseParser, StreamWriter +from .http import HttpResponseParser from .streams import EMPTY_PAYLOAD, DataQueue @@ -17,7 +17,6 @@ def __init__(self, *, loop=None): self.paused = False self.transport = None - self.writer = None self._should_close = False self._message = None @@ -60,7 +59,6 @@ def is_connected(self): def connection_made(self, transport): self.transport = transport - self.writer = StreamWriter(self, transport, self._loop) def connection_lost(self, exc): if self._payload_parser is not None: @@ -82,7 +80,7 @@ def connection_lost(self, exc): exc = ServerDisconnectedError(uncompleted) DataQueue.set_exception(self, exc) - self.transport = self.writer = None + self.transport = None self._should_close = True self._parser = None self._message = None diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index de32f6622e1..2384ca3c0ca 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -469,7 +469,7 @@ def send(self, conn): if self.url.raw_query_string: path += '?' + self.url.raw_query_string - writer = PayloadWriter(conn.writer, self.loop) + writer = PayloadWriter(conn.protocol, conn.transport, self.loop) if self.compress: writer.enable_compression(self.compress) diff --git a/aiohttp/http.py b/aiohttp/http.py index 4dee43b631c..c372426754d 100644 --- a/aiohttp/http.py +++ b/aiohttp/http.py @@ -12,7 +12,7 @@ WSCloseCode, WSMessage, WSMsgType, ws_ext_gen, ws_ext_parse) from .http_writer import (HttpVersion, HttpVersion10, HttpVersion11, - PayloadWriter, StreamWriter) + PayloadWriter) __all__ = ( @@ -20,7 +20,6 @@ # .http_writer 'PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11', - 'StreamWriter', # .http_parser 'HttpParser', 'HttpRequestParser', 'HttpResponseParser', diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index 5bb51d69f9d..a5ca686f64e 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -513,11 +513,11 @@ def parse_frame(self, buf): class WebSocketWriter: - def __init__(self, stream, *, + def __init__(self, protocol, transport, *, use_mask=False, limit=DEFAULT_LIMIT, random=random.Random(), compress=0, notakeover=False): - self.stream = stream - self.writer = stream.transport + self.protocol = protocol + self.transport = transport self.use_mask = use_mask self.randrange = random.randrange self.compress = compress @@ -572,20 +572,20 @@ def _send_frame(self, message, opcode, compress=None): mask = mask.to_bytes(4, 'big') message = bytearray(message) _websocket_mask(mask, message) - self.writer.write(header + mask + message) + self.transport.write(header + mask + message) self._output_size += len(header) + len(mask) + len(message) else: if len(message) > MSG_SIZE: - self.writer.write(header) - self.writer.write(message) + self.transport.write(header) + self.transport.write(message) else: - self.writer.write(header + message) + self.transport.write(header + message) self._output_size += len(header) + len(message) if self._output_size > self._limit: self._output_size = 0 - return self.stream.drain() + return self.protocol._drain_helper() return noop() diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index b253d7ed946..4b83a3ecd80 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -2,104 +2,24 @@ import asyncio import collections -import socket import zlib -from contextlib import suppress from .abc import AbstractPayloadWriter from .helpers import noop -__all__ = ('PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11', - 'StreamWriter') +__all__ = ('PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11') HttpVersion = collections.namedtuple('HttpVersion', ['major', 'minor']) HttpVersion10 = HttpVersion(1, 0) HttpVersion11 = HttpVersion(1, 1) -if hasattr(socket, 'TCP_CORK'): # pragma: no cover - CORK = socket.TCP_CORK -elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover - CORK = socket.TCP_NOPUSH -else: # pragma: no cover - CORK = None - - -class StreamWriter: +class PayloadWriter(AbstractPayloadWriter): def __init__(self, protocol, transport, loop): self._protocol = protocol - self._loop = loop - self._tcp_nodelay = False - self._tcp_cork = False - self._socket = transport.get_extra_info('socket') - self._waiters = [] - self.transport = transport - - @property - def tcp_nodelay(self): - return self._tcp_nodelay - - def set_tcp_nodelay(self, value): - value = bool(value) - if self._tcp_nodelay == value: - return - if self._socket is None: - return - if self._socket.family not in (socket.AF_INET, socket.AF_INET6): - return - - # socket may be closed already, on windows OSError get raised - with suppress(OSError): - if self._tcp_cork: - if CORK is not None: # pragma: no branch - self._socket.setsockopt(socket.IPPROTO_TCP, CORK, False) - self._tcp_cork = False - - self._socket.setsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY, value) - self._tcp_nodelay = value - - @property - def tcp_cork(self): - return self._tcp_cork - - def set_tcp_cork(self, value): - value = bool(value) - if self._tcp_cork == value: - return - if self._socket is None: - return - if self._socket.family not in (socket.AF_INET, socket.AF_INET6): - return - - with suppress(OSError): - if self._tcp_nodelay: - self._socket.setsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY, False) - self._tcp_nodelay = False - if CORK is not None: # pragma: no branch - self._socket.setsockopt(socket.IPPROTO_TCP, CORK, value) - self._tcp_cork = value - - async def drain(self): - """Flush the write buffer. - - The intended use is to write - - await w.write(data) - await w.drain() - """ - if self._protocol.transport is not None: - await self._protocol._drain_helper() - - -class PayloadWriter(AbstractPayloadWriter): - - def __init__(self, stream, loop): - self._stream = stream - self._transport = None + self._transport = transport self.loop = loop self.length = None @@ -110,11 +30,15 @@ def __init__(self, stream, loop): self._eof = False self._compress = None self._drain_waiter = None - self._transport = self._stream.transport - async def get_transport(self): + @property + def transport(self): return self._transport + @property + def protocol(self): + return self._protocol + def enable_chunking(self): self.chunked = True @@ -204,4 +128,12 @@ async def write_eof(self, chunk=b''): self._transport = None async def drain(self): - await self._stream.drain() + """Flush the write buffer. + + The intended use is to write + + await w.write(data) + await w.drain() + """ + if self._protocol.transport is not None: + await self._protocol._drain_helper() diff --git a/aiohttp/tcp_helpers.py b/aiohttp/tcp_helpers.py new file mode 100644 index 00000000000..3a016901c9d --- /dev/null +++ b/aiohttp/tcp_helpers.py @@ -0,0 +1,61 @@ +"""Helper methods to tune a TCP connection""" + +import socket +from contextlib import suppress + + +__all__ = ('tcp_keepalive', 'tcp_nodelay', 'tcp_cork') + + +if hasattr(socket, 'TCP_CORK'): # pragma: no cover + CORK = socket.TCP_CORK +elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover + CORK = socket.TCP_NOPUSH +else: # pragma: no cover + CORK = None + + +if hasattr(socket, 'SO_KEEPALIVE'): + def tcp_keepalive(transport): + sock = transport.get_extra_info('socket') + if sock is not None: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) +else: + def tcp_keepalive(transport): # pragma: no cover + pass + + +def tcp_nodelay(transport, value): + sock = transport.get_extra_info('socket') + + if sock is None: + return + + if sock.family not in (socket.AF_INET, socket.AF_INET6): + return + + value = bool(value) + + # socket may be closed already, on windows OSError get raised + with suppress(OSError): + sock.setsockopt( + socket.IPPROTO_TCP, socket.TCP_NODELAY, value) + + +def tcp_cork(transport, value): + sock = transport.get_extra_info('socket') + + if CORK is None: + return + + if sock is None: + return + + if sock.family not in (socket.AF_INET, socket.AF_INET6): + return + + value = bool(value) + + with suppress(OSError): + sock.setsockopt( + socket.IPPROTO_TCP, CORK, value) diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index 9eebae3fbb7..57f4b21e3ca 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -54,9 +54,7 @@ def _sendfile_cb(self, fut, out_fd, in_fd, set_result(fut, None) async def sendfile(self, fobj, count): - transport = await self.get_transport() - - out_socket = transport.get_extra_info('socket').dup() + out_socket = self.transport.get_extra_info('socket').dup() out_socket.setblocking(False) out_fd = out_socket.fileno() in_fd = fobj.fileno() @@ -71,7 +69,7 @@ async def sendfile(self, fobj, count): await fut except Exception: server_logger.debug('Socket error') - transport.close() + self.transport.close() finally: out_socket.close() @@ -112,7 +110,8 @@ async def _sendfile_system(self, request, fobj, count): writer = await self._sendfile_fallback(request, fobj, count) else: writer = SendfilePayloadWriter( - request._protocol.writer, + request.protocol, + transport, request.loop ) request._payload_writer = writer diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 76c944150cb..4b636cd0e34 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -1,7 +1,6 @@ import asyncio import asyncio.streams import http.server -import socket import traceback import warnings from collections import deque @@ -10,10 +9,10 @@ from . import helpers, http from .helpers import CeilTimeout -from .http import (HttpProcessingError, HttpRequestParser, PayloadWriter, - StreamWriter) +from .http import HttpProcessingError, HttpRequestParser, PayloadWriter from .log import access_logger, server_logger from .streams import EMPTY_PAYLOAD +from .tcp_helpers import tcp_cork, tcp_keepalive, tcp_nodelay from .web_exceptions import HTTPException from .web_request import BaseRequest from .web_response import Response @@ -25,15 +24,6 @@ 'UNKNOWN', '/', http.HttpVersion10, {}, {}, True, False, False, False, http.URL('/')) -if hasattr(socket, 'SO_KEEPALIVE'): - def tcp_keepalive(server, transport): - sock = transport.get_extra_info('socket') - if sock is not None: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) -else: - def tcp_keepalive(server, transport): # pragma: no cover - pass - class RequestPayloadError(Exception): """Payload parsing error.""" @@ -181,13 +171,12 @@ def connection_made(self, transport): super().connection_made(transport) self.transport = transport - self.writer = StreamWriter(self, transport, self._loop) if self._tcp_keepalive: - tcp_keepalive(self, transport) + tcp_keepalive(transport) - self.writer.set_tcp_cork(False) - self.writer.set_tcp_nodelay(True) + tcp_cork(transport, False) + tcp_nodelay(transport, True) self._manager.connection_made(self, transport) def connection_lost(self, exc): @@ -200,7 +189,7 @@ def connection_lost(self, exc): self._request_factory = None self._request_handler = None self._request_parser = None - self.transport = self.writer = None + self.transport = None if self._keepalive_handle is not None: self._keepalive_handle.cancel() @@ -241,14 +230,14 @@ def data_received(self, data): # something happened during parsing self._error_handler = self._loop.create_task( self.handle_parse_error( - PayloadWriter(self.writer, self._loop), + PayloadWriter(self, self.transport, self._loop), 400, exc, exc.message)) self.close() except Exception as exc: # 500: internal error self._error_handler = self._loop.create_task( self.handle_parse_error( - PayloadWriter(self.writer, self._loop), + PayloadWriter(self, self.transport, self._loop), 500, exc)) self.close() else: @@ -371,7 +360,7 @@ async def start(self): now = loop.time() manager.requests_count += 1 - writer = PayloadWriter(self.writer, loop) + writer = PayloadWriter(self, self.transport, loop) request = self._request_factory( message, payload, self, writer, handler) try: diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 63d74bfdda5..ceaf09b4c56 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -187,7 +187,8 @@ def _pre_start(self, request): self.headers.update(headers) self.force_close() self._compress = compress - writer = WebSocketWriter(request._protocol.writer, + writer = WebSocketWriter(request._protocol, + request._protocol.transport, compress=compress, notakeover=notakeover) diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 5345e73ecfd..12f006bf00c 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -38,6 +38,14 @@ def buf(): return bytearray() +@pytest.fixture +def protocol(loop): + protocol = mock.Mock() + protocol._drain_helper.return_value = loop.create_future() + protocol._drain_helper.return_value.set_result(None) + return protocol + + @pytest.yield_fixture def transport(buf): transport = mock.Mock() @@ -56,22 +64,11 @@ async def write_eof(): @pytest.fixture -def conn(stream): - return mock.Mock(writer=stream) - - -@pytest.fixture -def stream(buf, transport, loop): - stream = mock.Mock() - stream.transport = transport - - def acquire(writer): - writer.set_transport(transport) - - stream.acquire.side_effect = acquire - stream.drain.return_value = loop.create_future() - stream.drain.return_value.set_result(None) - return stream +def conn(transport, protocol): + return mock.Mock( + transport=transport, + protocol=protocol + ) def test_method1(make_request): @@ -845,7 +842,6 @@ def gen(writer): assert asyncio.isfuture(req._writer) await resp.wait_for_close() assert req._writer is None - assert buf.split(b'\r\n\r\n', 1)[1] == \ b'b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n' await req.close() diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index ca395b9efe7..214a7c085a3 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -440,7 +440,7 @@ async def handler(request): await ws.prepare(request) await ws.receive_str() - ws._writer.writer.write(b'01234' * 100) + ws._writer.transport.write(b'01234' * 100) await ws.close() return ws diff --git a/tests/test_http_stream_writer.py b/tests/test_http_stream_writer.py deleted file mode 100644 index b4fdb2288a5..00000000000 --- a/tests/test_http_stream_writer.py +++ /dev/null @@ -1,257 +0,0 @@ -import socket -from unittest import mock - -import pytest - -from aiohttp.http_writer import CORK, StreamWriter - - -has_ipv6 = socket.has_ipv6 -if has_ipv6: - # The socket.has_ipv6 flag may be True if Python was built with IPv6 - # support, but the target system still may not have it. - # So let's ensure that we really have IPv6 support. - try: - socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - except OSError: - has_ipv6 = False - - -# nodelay - -def test_nodelay_and_cork_default(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - assert not writer.tcp_nodelay - assert not writer.tcp_cork - assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -def test_set_nodelay_no_change(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(False) - assert not writer.tcp_nodelay - assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -def test_set_nodelay_exception(loop): - transport = mock.Mock() - s = mock.Mock() - s.setsockopt = mock.Mock() - s.family = socket.AF_INET - s.setsockopt.side_effect = OSError - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert not writer.tcp_nodelay - - -def test_set_nodelay_enable(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert writer.tcp_nodelay - assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -def test_set_nodelay_enable_and_disable(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - writer.set_tcp_nodelay(False) - assert not writer.tcp_nodelay - assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_nodelay_and_cork(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - writer.set_tcp_nodelay(True) - assert writer.tcp_nodelay - assert not writer.tcp_cork - assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available") -def test_set_nodelay_enable_ipv6(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert writer.tcp_nodelay - assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - - -@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'), - reason="requires unix sockets") -def test_set_nodelay_enable_unix(loop): - # do not set nodelay for unix socket - transport = mock.Mock() - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert not writer.tcp_nodelay - - -def test_set_nodelay_enable_no_socket(loop): - transport = mock.Mock() - transport.get_extra_info.return_value = None - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - assert not writer.tcp_nodelay - assert writer._socket is None - - -# cork - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_cork_default(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - assert not writer.tcp_cork - assert not s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_no_change(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(False) - assert not writer.tcp_cork - assert not s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert writer.tcp_cork - assert s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable_and_disable(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - writer.set_tcp_cork(False) - assert not writer.tcp_cork - assert not s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available") -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable_ipv6(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert writer.tcp_cork - assert s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'), - reason="requires unix sockets") -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable_unix(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert not writer.tcp_cork - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_cork_enable_no_socket(loop): - transport = mock.Mock() - transport.get_extra_info.return_value = None - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert not writer.tcp_cork - assert writer._socket is None - - -def test_set_cork_exception(loop): - transport = mock.Mock() - s = mock.Mock() - s.setsockopt = mock.Mock() - s.family = socket.AF_INET - s.setsockopt.side_effect = OSError - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - assert not writer.tcp_cork - - -# cork and nodelay interference - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_enabling_cork_disables_nodelay(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_nodelay(True) - writer.set_tcp_cork(True) - assert not writer.tcp_nodelay - assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - assert writer.tcp_cork - assert s.getsockopt(socket.IPPROTO_TCP, CORK) - - -@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") -def test_set_enabling_nodelay_disables_cork(loop): - transport = mock.Mock() - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - transport.get_extra_info.return_value = s - proto = mock.Mock() - writer = StreamWriter(proto, transport, loop) - writer.set_tcp_cork(True) - writer.set_tcp_nodelay(True) - assert writer.tcp_nodelay - assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - assert not writer.tcp_cork - assert not s.getsockopt(socket.IPPROTO_TCP, CORK) diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index c54f609e3f8..317794e7ab8 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -26,21 +26,22 @@ def write(chunk): @pytest.fixture -def stream(transport, loop): - stream = mock.Mock(transport=transport) +def protocol(loop, transport): + protocol = mock.Mock(transport=transport) + protocol._drain_helper.return_value = loop.create_future() + protocol._drain_helper.return_value.set_result(None) + return protocol - def acquire(writer): - writer.set_transport(transport) - stream.acquire = acquire - stream.drain.return_value = loop.create_future() - stream.drain.return_value.set_result(None) - return stream +def test_payloadwriter_properties(transport, protocol, loop): + writer = http.PayloadWriter(protocol, transport, loop) + assert writer.protocol == protocol + assert writer.transport == transport -async def test_write_payload_eof(stream, loop): - write = stream.transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_eof(transport, protocol, loop): + write = transport.write = mock.Mock() + msg = http.PayloadWriter(protocol, transport, loop) msg.write(b'data1') msg.write(b'data2') @@ -50,8 +51,8 @@ async def test_write_payload_eof(stream, loop): assert b'data1data2' == content.split(b'\r\n\r\n', 1)[-1] -async def test_write_payload_chunked(buf, stream, loop): - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_chunked(buf, protocol, transport, loop): + msg = http.PayloadWriter(protocol, transport, loop) msg.enable_chunking() msg.write(b'data') await msg.write_eof() @@ -59,8 +60,8 @@ async def test_write_payload_chunked(buf, stream, loop): assert b'4\r\ndata\r\n0\r\n\r\n' == buf -async def test_write_payload_chunked_multiple(buf, stream, loop): - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_chunked_multiple(buf, protocol, transport, loop): + msg = http.PayloadWriter(protocol, transport, loop) msg.enable_chunking() msg.write(b'data1') msg.write(b'data2') @@ -69,10 +70,10 @@ async def test_write_payload_chunked_multiple(buf, stream, loop): assert b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n' == buf -async def test_write_payload_length(stream, loop): - write = stream.transport.write = mock.Mock() +async def test_write_payload_length(protocol, transport, loop): + write = transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) + msg = http.PayloadWriter(protocol, transport, loop) msg.length = 2 msg.write(b'd') msg.write(b'ata') @@ -82,10 +83,10 @@ async def test_write_payload_length(stream, loop): assert b'da' == content.split(b'\r\n\r\n', 1)[-1] -async def test_write_payload_chunked_filter(stream, loop): - write = stream.transport.write = mock.Mock() +async def test_write_payload_chunked_filter(protocol, transport, loop): + write = transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) + msg = http.PayloadWriter(protocol, transport, loop) msg.enable_chunking() msg.write(b'da') msg.write(b'ta') @@ -95,9 +96,12 @@ async def test_write_payload_chunked_filter(stream, loop): assert content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n') -async def test_write_payload_chunked_filter_mutiple_chunks(stream, loop): - write = stream.transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_chunked_filter_mutiple_chunks( + protocol, + transport, + loop): + write = transport.write = mock.Mock() + msg = http.PayloadWriter(protocol, transport, loop) msg.enable_chunking() msg.write(b'da') msg.write(b'ta') @@ -115,9 +119,9 @@ async def test_write_payload_chunked_filter_mutiple_chunks(stream, loop): COMPRESSED = b''.join([compressor.compress(b'data'), compressor.flush()]) -async def test_write_payload_deflate_compression(stream, loop): - write = stream.transport.write = mock.Mock() - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_deflate_compression(protocol, transport, loop): + write = transport.write = mock.Mock() + msg = http.PayloadWriter(protocol, transport, loop) msg.enable_compression('deflate') msg.write(b'data') await msg.write_eof() @@ -128,8 +132,12 @@ async def test_write_payload_deflate_compression(stream, loop): assert COMPRESSED == content.split(b'\r\n\r\n', 1)[-1] -async def test_write_payload_deflate_and_chunked(buf, stream, loop): - msg = http.PayloadWriter(stream, loop) +async def test_write_payload_deflate_and_chunked( + buf, + protocol, + transport, + loop): + msg = http.PayloadWriter(protocol, transport, loop) msg.enable_compression('deflate') msg.enable_chunking() @@ -140,8 +148,8 @@ async def test_write_payload_deflate_and_chunked(buf, stream, loop): assert b'6\r\nKI,I\x04\x00\r\n0\r\n\r\n' == buf -def test_write_drain(stream, loop): - msg = http.PayloadWriter(stream, loop) +def test_write_drain(protocol, transport, loop): + msg = http.PayloadWriter(protocol, transport, loop) msg.drain = mock.Mock() msg.write(b'1' * (64 * 1024 * 2), drain=False) assert not msg.drain.called @@ -151,11 +159,24 @@ def test_write_drain(stream, loop): assert msg.buffer_size == 0 -def test_write_to_closing_transport(stream, loop): - msg = http.PayloadWriter(stream, loop) +def test_write_to_closing_transport(protocol, transport, loop): + msg = http.PayloadWriter(protocol, transport, loop) msg.write(b'Before closing') - stream.transport.is_closing.return_value = True + transport.is_closing.return_value = True with pytest.raises(asyncio.CancelledError): msg.write(b'After closing') + + +async def test_drain(protocol, transport, loop): + msg = http.PayloadWriter(protocol, transport, loop) + await msg.drain() + assert protocol._drain_helper.called + + +async def test_drain_no_transport(protocol, transport, loop): + msg = http.PayloadWriter(protocol, transport, loop) + msg._protocol.transport = None + await msg.drain() + assert not protocol._drain_helper.called diff --git a/tests/test_tcp_helpers.py b/tests/test_tcp_helpers.py new file mode 100644 index 00000000000..ebe8271d820 --- /dev/null +++ b/tests/test_tcp_helpers.py @@ -0,0 +1,145 @@ +import socket +from unittest import mock + +import pytest + +from aiohttp.tcp_helpers import CORK, tcp_cork, tcp_nodelay + + +has_ipv6 = socket.has_ipv6 +if has_ipv6: + # The socket.has_ipv6 flag may be True if Python was built with IPv6 + # support, but the target system still may not have it. + # So let's ensure that we really have IPv6 support. + try: + socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + except OSError: + has_ipv6 = False + + +# nodelay + +def test_tcp_nodelay_exception(loop): + transport = mock.Mock() + s = mock.Mock() + s.setsockopt = mock.Mock() + s.family = socket.AF_INET + s.setsockopt.side_effect = OSError + transport.get_extra_info.return_value = s + tcp_nodelay(transport, True) + s.setsockopt.assert_called_with( + socket.IPPROTO_TCP, + socket.TCP_NODELAY, + True + ) + + +def test_tcp_nodelay_enable(loop): + transport = mock.Mock() + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + tcp_nodelay(transport, True) + assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + + +def test_tcp_nodelay_enable_and_disable(loop): + transport = mock.Mock() + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + tcp_nodelay(transport, True) + assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + tcp_nodelay(transport, False) + assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + + +@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available") +def test_tcp_nodelay_enable_ipv6(loop): + transport = mock.Mock() + s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + tcp_nodelay(transport, True) + assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + + +@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'), + reason="requires unix sockets") +def test_tcp_nodelay_enable_unix(loop): + # do not set nodelay for unix socket + transport = mock.Mock() + s = mock.Mock(family=socket.AF_UNIX, type=socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + tcp_nodelay(transport, True) + assert not s.setsockopt.called + + +def test_tcp_nodelay_enable_no_socket(loop): + transport = mock.Mock() + transport.get_extra_info.return_value = None + tcp_nodelay(transport, True) + + +# cork + + +@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") +def test_tcp_cork_enable(loop): + transport = mock.Mock() + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + tcp_cork(transport, True) + assert s.getsockopt(socket.IPPROTO_TCP, CORK) + + +@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") +def test_set_cork_enable_and_disable(loop): + transport = mock.Mock() + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + tcp_cork(transport, True) + assert s.getsockopt(socket.IPPROTO_TCP, CORK) + tcp_cork(transport, False) + assert not s.getsockopt(socket.IPPROTO_TCP, CORK) + + +@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available") +@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") +def test_set_cork_enable_ipv6(loop): + transport = mock.Mock() + s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + tcp_cork(transport, True) + assert s.getsockopt(socket.IPPROTO_TCP, CORK) + + +@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'), + reason="requires unix sockets") +@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") +def test_set_cork_enable_unix(loop): + transport = mock.Mock() + s = mock.Mock(family=socket.AF_UNIX, type=socket.SOCK_STREAM) + transport.get_extra_info.return_value = s + tcp_cork(transport, True) + assert not s.setsockopt.called + + +@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") +def test_set_cork_enable_no_socket(loop): + transport = mock.Mock() + transport.get_extra_info.return_value = None + tcp_cork(transport, True) + + +@pytest.mark.skipif(CORK is None, reason="TCP_CORK or TCP_NOPUSH required") +def test_set_cork_exception(loop): + transport = mock.Mock() + s = mock.Mock() + s.setsockopt = mock.Mock() + s.family = socket.AF_INET + s.setsockopt.side_effect = OSError + transport.get_extra_info.return_value = s + tcp_cork(transport, True) + s.setsockopt.assert_called_with( + socket.IPPROTO_TCP, + CORK, + True + ) diff --git a/tests/test_web_protocol.py b/tests/test_web_protocol.py index 00a772afe9f..e68ab144242 100644 --- a/tests/test_web_protocol.py +++ b/tests/test_web_protocol.py @@ -38,6 +38,8 @@ def srv(make_srv, transport): srv = make_srv() srv.connection_made(transport) transport.close.side_effect = partial(srv.connection_lost, None) + srv._drain_helper = mock.Mock() + srv._drain_helper.side_effect = helpers.noop return srv @@ -72,7 +74,7 @@ async def handle(request): @pytest.yield_fixture def writer(srv): - return http.PayloadWriter(srv.writer, srv._loop) + return http.PayloadWriter(srv, srv.transport, srv._loop) @pytest.yield_fixture @@ -83,7 +85,6 @@ def write(chunk): buf.extend(chunk) transport.write.side_effect = write - transport.drain.side_effect = helpers.noop transport.is_closing.return_value = False return transport @@ -131,6 +132,16 @@ async def test_double_shutdown(srv, transport): assert srv.transport is None +async def test_shutdown_wait_error_handler(loop, srv, transport): + + async def _error_handle(): + pass + + srv._error_handler = loop.create_task(_error_handle()) + await srv.shutdown() + assert srv._error_handler.done() + + async def test_close_after_response(srv, loop, transport): srv.data_received( b'GET / HTTP/1.0\r\n' @@ -227,7 +238,7 @@ async def test_bad_method(srv, loop, buf): async def test_data_received_error(srv, loop, buf): - srv.transport = mock.Mock() + transport = srv.transport srv._request_parser = mock.Mock() srv._request_parser.feed_data.side_effect = TypeError @@ -237,7 +248,7 @@ async def test_data_received_error(srv, loop, buf): await asyncio.sleep(0, loop=loop) assert buf.startswith(b'HTTP/1.0 500 Internal Server Error\r\n') - assert srv.transport.close.called + assert transport.close.called assert srv._error_handler is None @@ -737,3 +748,38 @@ def test_data_received_force_close(srv): b'Content-Length: 0\r\n\r\n') assert not srv._messages + + +async def test__process_keepalive(loop, srv): + # wait till the waiter is waiting + await asyncio.sleep(0) + + srv._keepalive_time = 1 + srv._keepalive_timeout = 1 + expired_time = srv._keepalive_time + srv._keepalive_timeout + 1 + with mock.patch.object(loop, "time", return_value=expired_time): + srv._process_keepalive() + assert srv._force_close + + +async def test__process_keepalive_schedule_next(loop, srv): + # wait till the waiter is waiting + await asyncio.sleep(0) + + srv._keepalive_time = 1 + srv._keepalive_timeout = 1 + expire_time = srv._keepalive_time + srv._keepalive_timeout + with mock.patch.object(loop, "time", return_value=expire_time): + with mock.patch.object(loop, "call_at") as call_at_patched: + srv._process_keepalive() + call_at_patched.assert_called_with( + expire_time, + srv._process_keepalive + ) + + +def test__process_keepalive_force_close(loop, srv): + srv._force_close = True + with mock.patch.object(loop, "call_at") as call_at_patched: + srv._process_keepalive() + assert not call_at_patched.called diff --git a/tests/test_web_sendfile.py b/tests/test_web_sendfile.py index 2bec965893b..7f7520ddb4f 100644 --- a/tests/test_web_sendfile.py +++ b/tests/test_web_sendfile.py @@ -12,7 +12,7 @@ def test_static_handle_eof(loop): in_fd = 31 fut = loop.create_future() m_os.sendfile.return_value = 0 - writer = SendfilePayloadWriter(fake_loop, mock.Mock()) + writer = SendfilePayloadWriter(mock.Mock(), mock.Mock(), fake_loop) writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100) assert fut.done() @@ -28,7 +28,7 @@ def test_static_handle_again(loop): in_fd = 31 fut = loop.create_future() m_os.sendfile.side_effect = BlockingIOError() - writer = SendfilePayloadWriter(fake_loop, mock.Mock()) + writer = SendfilePayloadWriter(mock.Mock(), mock.Mock(), fake_loop) writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100) assert not fut.done() @@ -47,7 +47,7 @@ def test_static_handle_exception(loop): fut = loop.create_future() exc = OSError() m_os.sendfile.side_effect = exc - writer = SendfilePayloadWriter(fake_loop, mock.Mock()) + writer = SendfilePayloadWriter(mock.Mock(), mock.Mock(), fake_loop) writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100) assert fut.done() @@ -63,7 +63,7 @@ def test__sendfile_cb_return_on_cancelling(loop): in_fd = 31 fut = loop.create_future() fut.cancel() - writer = SendfilePayloadWriter(fake_loop, mock.Mock()) + writer = SendfilePayloadWriter(mock.Mock(), mock.Mock(), fake_loop) writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) assert fut.done() assert not fake_loop.add_writer.called diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index a31808549f6..af30b1e3910 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -7,88 +7,97 @@ @pytest.fixture -def stream(): +def protocol(): return mock.Mock() @pytest.fixture -def writer(stream): - return WebSocketWriter(stream, use_mask=False) +def transport(): + return mock.Mock() + + +@pytest.fixture +def writer(protocol, transport): + return WebSocketWriter(protocol, transport, use_mask=False) -def test_pong(stream, writer): +def test_pong(writer): writer.pong() - stream.transport.write.assert_called_with(b'\x8a\x00') + writer.transport.write.assert_called_with(b'\x8a\x00') -def test_ping(stream, writer): +def test_ping(writer): writer.ping() - stream.transport.write.assert_called_with(b'\x89\x00') + writer.transport.write.assert_called_with(b'\x89\x00') -def test_send_text(stream, writer): +def test_send_text(writer): writer.send(b'text') - stream.transport.write.assert_called_with(b'\x81\x04text') + writer.transport.write.assert_called_with(b'\x81\x04text') -def test_send_binary(stream, writer): +def test_send_binary(writer): writer.send('binary', True) - stream.transport.write.assert_called_with(b'\x82\x06binary') + writer.transport.write.assert_called_with(b'\x82\x06binary') -def test_send_binary_long(stream, writer): +def test_send_binary_long(writer): writer.send(b'b' * 127, True) - assert stream.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb') + assert writer.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb') -def test_send_binary_very_long(stream, writer): +def test_send_binary_very_long(writer): writer.send(b'b' * 65537, True) - assert (stream.transport.write.call_args_list[0][0][0] == + assert (writer.transport.write.call_args_list[0][0][0] == b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01') - assert stream.transport.write.call_args_list[1][0][0] == b'b' * 65537 + assert writer.transport.write.call_args_list[1][0][0] == b'b' * 65537 -def test_close(stream, writer): +def test_close(writer): writer.close(1001, 'msg') - stream.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + writer.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') writer.close(1001, b'msg') - stream.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + writer.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') # Test that Service Restart close code is also supported writer.close(1012, b'msg') - stream.transport.write.assert_called_with(b'\x88\x05\x03\xf4msg') + writer.transport.write.assert_called_with(b'\x88\x05\x03\xf4msg') -def test_send_text_masked(stream): - writer = WebSocketWriter(stream, +def test_send_text_masked(protocol, transport): + writer = WebSocketWriter(protocol, + transport, use_mask=True, random=random.Random(123)) writer.send(b'text') - stream.transport.write.assert_called_with(b'\x81\x84\rg\xb3fy\x02\xcb\x12') + writer.transport.write.assert_called_with(b'\x81\x84\rg\xb3fy\x02\xcb\x12') -def test_send_compress_text(stream): - writer = WebSocketWriter(stream, compress=15) +def test_send_compress_text(protocol, transport): + writer = WebSocketWriter(protocol, transport, compress=15) writer.send(b'text') - stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') + writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') writer.send(b'text') - stream.transport.write.assert_called_with(b'\xc1\x05*\x01b\x00\x00') + writer.transport.write.assert_called_with(b'\xc1\x05*\x01b\x00\x00') -def test_send_compress_text_notakeover(stream): - writer = WebSocketWriter(stream, compress=15, notakeover=True) +def test_send_compress_text_notakeover(protocol, transport): + writer = WebSocketWriter(protocol, + transport, + compress=15, + notakeover=True) writer.send(b'text') - stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') + writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') writer.send(b'text') - stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') + writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') -def test_send_compress_text_per_message(stream): - writer = WebSocketWriter(stream) +def test_send_compress_text_per_message(protocol, transport): + writer = WebSocketWriter(protocol, transport) writer.send(b'text', compress=15) - stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') + writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') writer.send(b'text') - stream.transport.write.assert_called_with(b'\x81\x04text') + writer.transport.write.assert_called_with(b'\x81\x04text') writer.send(b'text', compress=15) - stream.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') + writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')