From 6d3866e3b1d3d34c79fa2f709ba2c0170fc70a41 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 13 Feb 2017 17:47:27 -0800 Subject: [PATCH] drop parser --- aiohttp/__init__.py | 2 - aiohttp/_ws_impl.py | 187 +---------- aiohttp/client_proto.py | 21 +- aiohttp/client_reqrep.py | 2 - aiohttp/connector.py | 5 +- aiohttp/parsers.py | 523 ------------------------------ aiohttp/protocol.py | 202 +++--------- aiohttp/pytest_plugin.py | 1 + aiohttp/server.py | 39 ++- aiohttp/streams.py | 138 +++++++- aiohttp/web_server.py | 2 +- aiohttp/web_ws.py | 4 +- tests/test_client_functional.py | 5 +- tests/test_client_request.py | 5 +- tests/test_flowcontrol_streams.py | 60 ++-- tests/test_http_parser.py | 423 +++++++----------------- tests/test_parser_buffer.py | 252 -------------- tests/test_protocol.py | 219 +++++++------ tests/test_py35/test_client.py | 6 +- tests/test_server.py | 17 +- tests/test_stream_parser.py | 373 --------------------- tests/test_stream_protocol.py | 40 --- tests/test_stream_writer.py | 56 ++-- tests/test_web_exceptions.py | 2 +- tests/test_web_response.py | 1 + tests/test_websocket_parser.py | 362 +++++++-------------- tests/test_websocket_writer.py | 46 +-- tests/test_wsgi.py | 10 +- 28 files changed, 666 insertions(+), 2337 deletions(-) delete mode 100644 aiohttp/parsers.py delete mode 100644 tests/test_parser_buffer.py delete mode 100644 tests/test_stream_parser.py delete mode 100644 tests/test_stream_protocol.py diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index 48da00f85fc..6b652a5f9ae 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -13,7 +13,6 @@ from .client_reqrep import * # noqa from .errors import * # noqa from .helpers import * # noqa -from .parsers import * # noqa from .streams import * # noqa from .multipart import * # noqa from .client_ws import ClientWebSocketResponse # noqa @@ -30,7 +29,6 @@ client_reqrep.__all__ + # noqa errors.__all__ + # noqa helpers.__all__ + # noqa - parsers.__all__ + # noqa protocol.__all__ + # noqa connector.__all__ + # noqa streams.__all__ + # noqa diff --git a/aiohttp/_ws_impl.py b/aiohttp/_ws_impl.py index 7ac99dc481f..96744730c25 100644 --- a/aiohttp/_ws_impl.py +++ b/aiohttp/_ws_impl.py @@ -14,7 +14,7 @@ from aiohttp import errors, hdrs from aiohttp.log import ws_logger -__all__ = ('WebSocketParser', 'WebSocketWriter', 'do_handshake', +__all__ = ('WebSocketReader', 'WebSocketWriter', 'do_handshake', 'WSMessage', 'WebSocketError', 'WSMsgType', 'WSCloseCode') @@ -102,107 +102,6 @@ def __init__(self, code, message): super().__init__(message) -def WebSocketParser(out, buf): - while True: - fin, opcode, payload = yield from parse_frame(buf) - - if opcode == WSMsgType.CLOSE: - if len(payload) >= 2: - close_code = UNPACK_CLOSE_CODE(payload[:2])[0] - if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - 'Invalid close code: {}'.format(close_code)) - try: - close_message = payload[2:].decode('utf-8') - except UnicodeDecodeError as exc: - raise WebSocketError( - WSCloseCode.INVALID_TEXT, - 'Invalid UTF-8 text message') from exc - msg = WSMessage(WSMsgType.CLOSE, close_code, close_message) - elif payload: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - 'Invalid close frame: {} {} {!r}'.format( - fin, opcode, payload)) - else: - msg = WSMessage(WSMsgType.CLOSE, 0, '') - - out.feed_data(msg, 0) - - elif opcode == WSMsgType.PING: - out.feed_data(WSMessage(WSMsgType.PING, payload, ''), len(payload)) - - elif opcode == WSMsgType.PONG: - out.feed_data(WSMessage(WSMsgType.PONG, payload, ''), len(payload)) - - elif opcode not in (WSMsgType.TEXT, WSMsgType.BINARY): - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "Unexpected opcode={!r}".format(opcode)) - else: - # load text/binary - data = [payload] - - while not fin: - fin, _opcode, payload = yield from parse_frame(buf, True) - - # We can receive ping/close in the middle of - # text message, Case 5.* - if _opcode == WSMsgType.PING: - out.feed_data( - WSMessage(WSMsgType.PING, payload, ''), len(payload)) - fin, _opcode, payload = yield from parse_frame(buf, True) - elif _opcode == WSMsgType.CLOSE: - if len(payload) >= 2: - close_code = UNPACK_CLOSE_CODE(payload[:2])[0] - if (close_code not in ALLOWED_CLOSE_CODES and - close_code < 3000): - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - 'Invalid close code: {}'.format(close_code)) - try: - close_message = payload[2:].decode('utf-8') - except UnicodeDecodeError as exc: - raise WebSocketError( - WSCloseCode.INVALID_TEXT, - 'Invalid UTF-8 text message') from exc - msg = WSMessage(WSMsgType.CLOSE, close_code, - close_message) - elif payload: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - 'Invalid close frame: {} {} {!r}'.format( - fin, opcode, payload)) - else: - msg = WSMessage(WSMsgType.CLOSE, 0, '') - - out.feed_data(msg, 0) - fin, _opcode, payload = yield from parse_frame(buf, True) - - if _opcode != WSMsgType.CONTINUATION: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - 'The opcode in non-fin frame is expected ' - 'to be zero, got {!r}'.format(_opcode)) - else: - data.append(payload) - - if opcode == WSMsgType.TEXT: - try: - text = b''.join(data).decode('utf-8') - out.feed_data(WSMessage(WSMsgType.TEXT, text, ''), - len(text)) - except UnicodeDecodeError as exc: - raise WebSocketError( - WSCloseCode.INVALID_TEXT, - 'Invalid UTF-8 text message') from exc - else: - data = b''.join(data) - out.feed_data( - WSMessage(WSMsgType.BINARY, data, ''), len(data)) - - native_byteorder = sys.byteorder @@ -240,70 +139,6 @@ def _websocket_mask_python(mask, data): _websocket_mask = _websocket_mask_python -def parse_frame(buf, continuation=False): - """Return the next frame from the socket.""" - # read header - data = yield from buf.read(2) - first_byte, second_byte = data - - fin = (first_byte >> 7) & 1 - rsv1 = (first_byte >> 6) & 1 - rsv2 = (first_byte >> 5) & 1 - rsv3 = (first_byte >> 4) & 1 - opcode = first_byte & 0xf - - # frame-fin = %x0 ; more frames of this message follow - # / %x1 ; final frame of this message - # frame-rsv1 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise - # frame-rsv2 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise - # frame-rsv3 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise - if rsv1 or rsv2 or rsv3: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - 'Received frame with non-zero reserved bits') - - if opcode > 0x7 and fin == 0: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - 'Received fragmented control frame') - - if fin == 0 and opcode == WSMsgType.CONTINUATION and not continuation: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - 'Received new fragment frame with non-zero ' - 'opcode {!r}'.format(opcode)) - - has_mask = (second_byte >> 7) & 1 - length = (second_byte) & 0x7f - - # Control frames MUST have a payload length of 125 bytes or less - if opcode > 0x7 and length > 125: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "Control frame payload cannot be larger than 125 bytes") - - # read payload - if length == 126: - data = yield from buf.read(2) - length = UNPACK_LEN2(data)[0] - elif length > 126: - data = yield from buf.read(8) - length = UNPACK_LEN3(data)[0] - - if has_mask: - mask = yield from buf.read(4) - - if length: - payload = yield from buf.read(length) - else: - payload = bytearray() - - if has_mask: - payload = _websocket_mask(bytes(mask), payload) - - return fin, opcode, payload - - class WSParserState(IntEnum): READ_HEADER = 1 READ_PAYLOAD_LENGTH = 2 @@ -320,6 +155,7 @@ def __init__(self, queue): self._partial = [] self._state = WSParserState.READ_HEADER + self._opcode = None self._frame_fin = False self._frame_opcode = None self._frame_payload = bytearray() @@ -346,7 +182,6 @@ def feed_data(self, data): def _feed_data(self, data): for fin, opcode, payload in self.parse_frame(data): - if opcode == WSMsgType.CLOSE: if len(payload) >= 2: close_code = UNPACK_CLOSE_CODE(payload[:2])[0] @@ -380,7 +215,8 @@ def _feed_data(self, data): self.queue.feed_data( WSMessage(WSMsgType.PONG, payload, ''), len(payload)) - elif opcode not in (WSMsgType.TEXT, WSMsgType.BINARY): + elif opcode not in ( + WSMsgType.TEXT, WSMsgType.BINARY) and not self._opcode: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, "Unexpected opcode={!r}".format(opcode)) @@ -389,6 +225,8 @@ def _feed_data(self, data): if not fin: # got partial frame payload + if opcode != WSMsgType.CONTINUATION: + self._opcode = opcode self._partial.append(payload) else: # previous frame was non finished @@ -400,6 +238,9 @@ def _feed_data(self, data): 'The opcode in non-fin frame is expected ' 'to be zero, got {!r}'.format(opcode)) + if opcode == WSMsgType.CONTINUATION: + opcode = self._opcode + self._partial.append(payload) if opcode == WSMsgType.TEXT: @@ -416,6 +257,7 @@ def _feed_data(self, data): self.queue.feed_data( WSMessage(WSMsgType.BINARY, data, ''), len(data)) + self._start_opcode = None self._partial.clear() return False, b'' @@ -565,9 +407,10 @@ def parse_frame(self, buf, continuation=False): class WebSocketWriter: - def __init__(self, writer, *, + def __init__(self, stream, *, use_mask=False, limit=DEFAULT_LIMIT, random=random.Random()): - self.writer = writer + self.stream = stream + self.writer = stream.transport self.use_mask = use_mask self.randrange = random.randrange self._closing = False @@ -610,7 +453,7 @@ def _send_frame(self, message, opcode): if self._output_size > self._limit: self._output_size = 0 - return self.writer.drain() + return self.stream.drain() return () @@ -720,6 +563,6 @@ def do_handshake(method, headers, transport, # response code, headers, parser, writer, protocol return (101, response_headers, - WebSocketParser, + None, WebSocketWriter(transport, limit=write_buffer_size), protocol) diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index d9d8d412929..280e6b297bc 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -1,12 +1,11 @@ import asyncio import asyncio.streams -import socket -from . import errors, hdrs, streams +from . import errors, hdrs from .errors import ServerDisconnectedError -from .streams import DataQueue, FlowControlStreamReader, EmptyStreamReader -from .parsers import StreamParser, StreamWriter -from .protocol import HttpResponseParser, HttpPayloadParser +from .protocol import HttpPayloadParser, HttpResponseParser +from .streams import (DataQueue, EmptyStreamReader, FlowControlStreamReader, + StreamWriter) EMPTY_PAYLOAD = EmptyStreamReader() @@ -47,7 +46,7 @@ def is_connected(self): def connection_made(self, transport): self.transport = transport - self.writer = StreamWriter(transport, self, None, self._loop) + self.writer = StreamWriter(self, transport, self._loop) def connection_lost(self, exc): self.transport = self.writer = None @@ -149,7 +148,8 @@ def data_received(self, data, # calculate payload empty_payload = True - if (((length is not None and length > 0) or msg.chunked) and + if (((length is not None and length > 0) or + msg.chunked) and (not self._skip_payload and msg.code not in self._skip_status_codes)): @@ -157,9 +157,12 @@ def data_received(self, data, payload = FlowControlStreamReader( self, timer=self._timer, loop=self._loop) payload_parser = HttpPayloadParser( - msg, readall=self._read_until_eof) + payload, length=length, + chunked=msg.chunked, code=msg.code, + compression=msg.compression, + readall=self._read_until_eof) - if payload_parser.start(length, payload): + if not payload_parser.done: empty_payload = False self._payload = payload self._payload_parser = payload_parser diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index c48837a607d..ae66a8eec7c 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -679,8 +679,6 @@ def release(self, *, consume=False): self._closed = True if self._connection is not None: self._connection.release() - #if self._reader is not None: - #self._reader.unset_parser() self._connection = None self._cleanup_writer() self._notify_content() diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 8aab22c5f4a..321e6aa9a88 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -7,14 +7,11 @@ from hashlib import md5, sha1, sha256 from types import MappingProxyType -import aiohttp - from . import hdrs, helpers from .client import ClientRequest from .client_proto import HttpClientProtocol from .errors import (ClientOSError, ClientTimeoutError, FingerprintMismatch, - HttpProxyError, ProxyConnectionError, - ServerDisconnectedError) + HttpProxyError, ProxyConnectionError) from .helpers import SimpleCookie, is_ip_address, sentinel from .resolver import DefaultResolver diff --git a/aiohttp/parsers.py b/aiohttp/parsers.py deleted file mode 100644 index 5aad7ce3fa6..00000000000 --- a/aiohttp/parsers.py +++ /dev/null @@ -1,523 +0,0 @@ -"""Parser is a generator function (NOT coroutine). - -Parser receives data with generator's send() method and sends data to -destination DataQueue. Parser receives ParserBuffer and DataQueue objects -as a parameters of the parser call, all subsequent send() calls should -send bytes objects. Parser sends parsed `term` to destination buffer with -DataQueue.feed_data() method. DataQueue object should implement two methods. -feed_data() - parser uses this method to send parsed protocol data. -feed_eof() - parser uses this method for indication of end of parsing stream. -To indicate end of incoming data stream EofStream exception should be sent -into parser. Parser could throw exceptions. - -There are three stages: - - * Data flow chain: - - 1. Application creates StreamParser object for storing incoming data. - 2. StreamParser creates ParserBuffer as internal data buffer. - 3. Application create parser and set it into stream buffer: - - parser = HttpRequestParser() - data_queue = stream.set_parser(parser) - - 3. At this stage StreamParser creates DataQueue object and passes it - and internal buffer into parser as an arguments. - - def set_parser(self, parser): - output = DataQueue() - self.p = parser(output, self._input) - return output - - 4. Application waits data on output.read() - - while True: - msg = yield from output.read() - ... - - * Data flow: - - 1. asyncio's transport reads data from socket and sends data to protocol - with data_received() call. - 2. Protocol sends data to StreamParser with feed_data() call. - 3. StreamParser sends data into parser with generator's send() method. - 4. Parser processes incoming data and sends parsed data - to DataQueue with feed_data() - 5. Application received parsed data from DataQueue.read() - - * Eof: - - 1. StreamParser receives eof with feed_eof() call. - 2. StreamParser throws EofStream exception into parser. - 3. Then it unsets parser. - -_SocketSocketTransport -> - -> "protocol" -> StreamParser -> "parser" -> DataQueue <- "application" - -""" - -import asyncio -import asyncio.streams -import socket - -from . import errors -from .streams import EofStream, FlowControlDataQueue - -__all__ = ('EofStream', 'StreamParser', 'StreamProtocol', - 'ParserBuffer', 'StreamWriter') - -DEFAULT_LIMIT = 2 ** 16 - -if hasattr(socket, 'TCP_CORK'): # pragma: no cover - CORK = socket.TCP_CORK -elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover - CORK = socket.TCP_NOPUSH -else: # pragma: no cover - CORK = None - - -class StreamParser: - """StreamParser manages incoming bytes stream and protocol parsers. - - StreamParser uses ParserBuffer as internal buffer. - - set_parser() sets current parser, it creates DataQueue object - and sends ParserBuffer and DataQueue into parser generator. - - unset_parser() sends EofStream into parser and then removes it. - """ - - def __init__(self, *, loop=None, buf=None, - limit=DEFAULT_LIMIT, eof_exc_class=RuntimeError): - self._loop = loop - self._eof = False - self._exception = None - self._parser = None - self._output = None - self._limit = limit - self._eof_exc_class = eof_exc_class - self._buffer = buf if buf is not None else ParserBuffer() - - self.paused = False - self.transport = None - - @property - def output(self): - return self._output - - def set_transport(self, transport): - assert transport is None or self.transport is None, \ - 'Transport already set' - self.transport = transport - - def at_eof(self): - return self._eof - - def exception(self): - return self._exception - - def set_exception(self, exc): - if isinstance(exc, ConnectionError): - exc, old_exc = self._eof_exc_class(), exc - exc.__cause__ = old_exc - exc.__context__ = old_exc - - self._exception = exc - - if self._output is not None: - self._output.set_exception(exc) - self._output = None - self._parser = None - - def feed_data(self, data): - """send data to current parser or store in buffer.""" - if data is None: - return - - if self._parser: - try: - self._parser.send(data) - except StopIteration: - self._output.feed_eof() - self._output = None - self._parser = None - except Exception as exc: - self._output.set_exception(exc) - self._output = None - self._parser = None - else: - self._buffer.feed_data(data) - - def feed_eof(self): - """send eof to all parsers, recursively.""" - if self._parser: - try: - if self._buffer: - self._parser.send(b'') - self._parser.throw(EofStream()) - except StopIteration: - self._output.feed_eof() - except EofStream: - self._output.set_exception(self._eof_exc_class()) - except Exception as exc: - self._output.set_exception(exc) - - self._parser = None - self._output = None - - self._eof = True - - def set_parser(self, parser, output=None): - """set parser to stream. return parser's DataQueue.""" - if self._parser: - self.unset_parser() - - if output is None: - output = FlowControlDataQueue( - self, limit=self._limit, loop=self._loop) - - if self._exception: - output.set_exception(self._exception) - return output - - # init parser - p = parser(output, self._buffer) - - try: - # initialize parser with data and parser buffers - next(p) - except StopIteration: - pass - except Exception as exc: - output.set_exception(exc) - else: - # parser still require more data - self._parser = p - self._output = output - self._output._allow_pause = True # stricktly internal use! - - if self._eof: - self.unset_parser() - - return output - - def unset_parser(self): - """unset parser, send eof to the parser and then remove it.""" - if self._parser is None: - return - - # TODO: write test - if self._loop.is_closed(): - # TODO: log something - return - - try: - self._parser.throw(EofStream()) - except StopIteration: - self._output.feed_eof() - except EofStream: - self._output.set_exception(self._eof_exc_class()) - except Exception as exc: - self._output.set_exception(exc) - finally: - self._output = None - self._parser = None - - -class StreamWriter(asyncio.streams.StreamWriter): - - def __init__(self, transport, protocol, reader, loop): - self._transport = transport - self._protocol = protocol - self._reader = reader - self._loop = loop - self._tcp_nodelay = False - self._tcp_cork = False - self._socket = transport.get_extra_info('socket') - self._waiters = [] - self.available = True - - def acquire(self, cb): - if self.available: - self.available = False - cb(self) - else: - self._waiters.append(cb) - - def release(self): - if self._waiters: - self.available = False - cb = self._waiters.pop(0) - cb(self) - else: - self.available = True - - @property - def tcp_nodelay(self): - return self._tcp_nodelay - - def set_tcp_nodelay(self, value): - value = bool(value) - if self._tcp_nodelay == value: - return - if self._socket is None: - return - if self._socket.family not in (socket.AF_INET, socket.AF_INET6): - return - - # socket may be closed already, on windows OSError get raised - try: - if self._tcp_cork: - if CORK is not None: # pragma: no branch - self._socket.setsockopt(socket.IPPROTO_TCP, CORK, False) - self._tcp_cork = False - - self._socket.setsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY, value) - self._tcp_nodelay = value - except OSError: - pass - - @property - def tcp_cork(self): - return self._tcp_cork - - def set_tcp_cork(self, value): - value = bool(value) - if self._tcp_cork == value: - return - if self._socket is None: - return - if self._socket.family not in (socket.AF_INET, socket.AF_INET6): - return - - try: - if self._tcp_nodelay: - self._socket.setsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY, False) - self._tcp_nodelay = False - if CORK is not None: # pragma: no branch - self._socket.setsockopt(socket.IPPROTO_TCP, CORK, value) - self._tcp_cork = value - except OSError: - pass - - -class StreamProtocol(asyncio.streams.FlowControlMixin, asyncio.Protocol): - """Helper class to adapt between Protocol and StreamReader.""" - - def __init__(self, *, loop=None, disconnect_error=RuntimeError, **kwargs): - super().__init__(loop=loop) - - self.transport = None - self.writer = None - self.buffer = ParserBuffer() - self.reader = StreamParser( - buf=self.buffer, loop=loop, - eof_exc_class=disconnect_error, **kwargs) - - def is_connected(self): - return self.transport is not None - - def connection_made(self, transport): - self.transport = transport - self.reader.set_transport(transport) - self.writer = StreamWriter(transport, self, self.reader, self._loop) - - def connection_lost(self, exc): - self.transport = self.writer = None - self.reader.set_transport(None) - - if exc is None: - self.reader.feed_eof() - else: - self.reader.set_exception(exc) - - super().connection_lost(exc) - - def data_received(self, data): - self.reader.feed_data(data) - - def eof_received(self): - self.reader.feed_eof() - - -class _ParserBufferHelper: - - __slots__ = ('exception', 'data') - - def __init__(self, exception, data): - self.exception = exception - self.data = data - - -class ParserBuffer: - """ParserBuffer is NOT a bytearray extension anymore. - - ParserBuffer provides helper methods for parsers. - """ - __slots__ = ('_helper', '_writer', '_data') - - def __init__(self, *args): - self._data = bytearray(*args) - self._helper = _ParserBufferHelper(None, self._data) - self._writer = self._feed_data(self._helper) - next(self._writer) - - def exception(self): - return self._helper.exception - - def set_exception(self, exc): - self._helper.exception = exc - - @staticmethod - def _feed_data(helper): - while True: - chunk = yield - if chunk: - helper.data.extend(chunk) - - if helper.exception: - raise helper.exception - - def feed_data(self, data): - if not self._helper.exception: - self._writer.send(data) - - def read(self, size): - """read() reads specified amount of bytes.""" - - while True: - if self._helper.exception: - raise self._helper.exception - - if len(self._data) >= size: - data = self._data[:size] - del self._data[:size] - return data - - self._writer.send((yield)) - - def readsome(self, size=None): - """reads size of less amount of bytes.""" - - while True: - if self._helper.exception: - raise self._helper.exception - - length = len(self._data) - if length > 0: - if size is None or length < size: - size = length - - data = self._data[:size] - del self._data[:size] - return data - - self._writer.send((yield)) - - def readuntil(self, stop, limit=None): - assert isinstance(stop, bytes) and stop, \ - 'bytes is required: {!r}'.format(stop) - - stop_len = len(stop) - - while True: - if self._helper.exception: - raise self._helper.exception - - pos = self._data.find(stop) - if pos >= 0: - end = pos + stop_len - size = end - if limit is not None and size > limit: - raise errors.LineLimitExceededParserError( - 'Line is too long.', limit) - - data = self._data[:size] - del self._data[:size] - return data - else: - if limit is not None and len(self._data) > limit: - raise errors.LineLimitExceededParserError( - 'Line is too long.', limit) - - self._writer.send((yield)) - - def wait(self, size): - """wait() waits for specified amount of bytes - then returns data without changing internal buffer.""" - - while True: - if self._helper.exception: - raise self._helper.exception - - if len(self._data) >= size: - return self._data[:size] - - self._writer.send((yield)) - - def waituntil(self, stop, limit=None): - """waituntil() reads until `stop` bytes sequence.""" - assert isinstance(stop, bytes) and stop, \ - 'bytes is required: {!r}'.format(stop) - - stop_len = len(stop) - - while True: - if self._helper.exception: - raise self._helper.exception - - pos = self._data.find(stop) - if pos >= 0: - size = pos + stop_len - if limit is not None and size > limit: - raise errors.LineLimitExceededParserError( - 'Line is too long. %s' % bytes(self._data), limit) - - return self._data[:size] - else: - if limit is not None and len(self._data) > limit: - raise errors.LineLimitExceededParserError( - 'Line is too long. %s' % bytes(self._data), limit) - - self._writer.send((yield)) - - def skip(self, size): - """skip() skips specified amount of bytes.""" - - while len(self._data) < size: - if self._helper.exception: - raise self._helper.exception - - self._writer.send((yield)) - - del self._data[:size] - - def skipuntil(self, stop): - """skipuntil() reads until `stop` bytes sequence.""" - assert isinstance(stop, bytes) and stop, \ - 'bytes is required: {!r}'.format(stop) - - stop_len = len(stop) - - while True: - if self._helper.exception: - raise self._helper.exception - - stop_line = self._data.find(stop) - if stop_line >= 0: - size = stop_line + stop_len - del self._data[:size] - return - - self._writer.send((yield)) - - def extend(self, data): - self._data.extend(data) - - def __len__(self): - return len(self._data) - - def __bytes__(self): - return bytes(self._data) diff --git a/aiohttp/protocol.py b/aiohttp/protocol.py index 4c9515d674a..94fc8c276d5 100644 --- a/aiohttp/protocol.py +++ b/aiohttp/protocol.py @@ -8,6 +8,7 @@ import sys import zlib from abc import ABC, abstractmethod +from enum import IntEnum from wsgiref.handlers import format_date_time from multidict import CIMultiDict, istr @@ -40,10 +41,12 @@ PARSE_CHUNKED = 2 PARSE_UNTIL_EOF = 3 -PARSE_CHUNKED_SIZE = 0 -PARSE_CHUNKED_CHUNK = 1 -PARSE_CHUNKED_CHUNK_EOF = 2 -PARSE_CHUNKED_TRAILERS = 3 + +class ChunkState(IntEnum): + PARSE_CHUNKED_SIZE = 0 + PARSE_CHUNKED_CHUNK = 1 + PARSE_CHUNKED_CHUNK_EOF = 2 + PARSE_CHUNKED_TRAILERS = 3 HttpVersion = collections.namedtuple( @@ -219,19 +222,6 @@ def parse_message(self, lines): method, path, version, headers, raw_headers, close, compression, upgrade, chunked) - def __call__(self, out, buf): - # read HTTP message (request line + headers) - try: - raw_data = yield from buf.readuntil( - b'\r\n\r\n', self.max_headers) - except errors.LineLimitExceededParserError as exc: - raise errors.LineTooLong('request header', exc.limit) from None - - lines = raw_data.split(b'\r\n') - - out.feed_data(self.parse_message(lines), len(raw_data)) - out.feed_eof() - class HttpResponseParser(HttpParser): """Read response status line and headers. @@ -277,86 +267,71 @@ def parse_message(self, lines): version, status, reason.strip(), headers, raw_headers, close, compression, upgrade, chunked) - def __call__(self, out, buf): - # read HTTP message (response line + headers) - try: - raw_data = yield from buf.readuntil( - b'\r\n\r\n', self.max_line_size + self.max_headers) - except errors.LineLimitExceededParserError as exc: - raise errors.LineTooLong('response header', exc.limit) from None - - lines = raw_data.split(b'\r\n') - - out.feed_data(self.parse_message(lines), len(raw_data)) - out.feed_eof() - class HttpPayloadParser: - def __init__(self, message, length=None, compression=True, + def __init__(self, payload, + length=None, chunked=False, compression=None, + code=None, method=None, readall=False, response_with_body=True): - self.message = message - self.length = length - self.compression = compression - self.readall = readall - self.response_with_body = response_with_body - self.payload = None + self.payload = payload + self._length = 0 self._type = PARSE_NONE - self._chunk = PARSE_CHUNKED_SIZE + self._chunk = ChunkState.PARSE_CHUNKED_SIZE self._chunk_size = 0 self._chunk_tail = b'' + self.done = False - def start(self, length, payload): # payload decompression wrapper - if (self.response_with_body and - self.compression and self.message.compression): - payload = DeflateBuffer(payload, self.message.compression) + if (response_with_body and compression): + payload = DeflateBuffer(payload, compression) # payload parser - if not self.response_with_body: + if not response_with_body: # don't parse payload if it's not expected to be received self._type = PARSE_NONE payload.feed_eof() - return False - elif self.message.chunked: + self.done = True + + elif chunked: self._type = PARSE_CHUNKED elif length is not None: self._type = PARSE_LENGTH - self.length = length + self._length = length + if self._length == 0: + payload.feed_eof() + self.done = True else: - if self.readall and getattr(self.message, 'code', 0) != 204: - self.length = None + if readall and code != 204: self._type = PARSE_UNTIL_EOF - elif getattr(self.message, 'method', None) in ('PUT', 'POST'): + elif method in ('PUT', 'POST'): internal_logger.warning( # pragma: no cover 'Content-Length or Transfer-Encoding header is required') self._type = PARSE_NONE payload.feed_eof() - return False + self.done = True self.payload = payload - return True def feed_eof(self): if self._type == PARSE_UNTIL_EOF: self.payload.feed_eof() def feed_data(self, chunk, SEP=b'\r\n', CHUNK_EXT=b';'): - # Read specified amount of bytes if self._type == PARSE_LENGTH: - required = self.length + required = self._length chunk_len = len(chunk) if required >= chunk_len: - self.length = required - chunk_len + self._length = required - chunk_len self.payload.feed_data(chunk, chunk_len) - if self.length == 0: + if self._length == 0: self.payload.feed_eof() return True, b'' else: - self.length = 0 + self._length = 0 self.payload.feed_data(chunk[:required], required) self.payload.feed_eof() return True, chunk[required:] @@ -365,11 +340,12 @@ def feed_data(self, chunk, SEP=b'\r\n', CHUNK_EXT=b';'): elif self._type == PARSE_CHUNKED: if self._chunk_tail: chunk = self._chunk_tail + chunk + self._chunk_tail = b'' while chunk: # read next chunk size - if self._chunk == PARSE_CHUNKED_SIZE: + if self._chunk == ChunkState.PARSE_CHUNKED_SIZE: pos = chunk.find(SEP) if pos >= 0: if pos > 0: @@ -384,28 +360,29 @@ def feed_data(self, chunk, SEP=b'\r\n', CHUNK_EXT=b';'): try: size = int(size, 16) except ValueError: - raise errors.TransferEncodingError( - chunk[:pos]) from None + exc = errors.TransferEncodingError(chunk[:pos]) + self.payload.set_exception(exc) + raise exc from None chunk = chunk[pos+2:] if size == 0: # eof marker - self._chunk = PARSE_CHUNKED_TRAILERS + self._chunk = ChunkState.PARSE_CHUNKED_TRAILERS else: - self._chunk = PARSE_CHUNKED_CHUNK + self._chunk = ChunkState.PARSE_CHUNKED_CHUNK self._chunk_size = size else: self._chunk_tail = chunk return False, None # read chunk and feed buffer - if self._chunk == PARSE_CHUNKED_CHUNK: + if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK: required = self._chunk_size chunk_len = len(chunk) if required >= chunk_len: self._chunk_size = required - chunk_len if self._chunk_size == 0: - self._chunk = PARSE_CHUNKED_CHUNK_EOF + self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF self.payload.feed_data(chunk, chunk_len) return False, None @@ -413,19 +390,19 @@ def feed_data(self, chunk, SEP=b'\r\n', CHUNK_EXT=b';'): self._chunk_size = 0 self.payload.feed_data(chunk[:required], required) chunk = chunk[required:] - self._chunk = PARSE_CHUNKED_CHUNK_EOF + self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF # toss the CRLF at the end of the chunk - if self._chunk == PARSE_CHUNKED_CHUNK_EOF: + if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK_EOF: if chunk[:2] == SEP: chunk = chunk[2:] - self._chunk = PARSE_CHUNKED_SIZE + self._chunk = ChunkState.PARSE_CHUNKED_SIZE else: self._chunk_tail = chunk return False, None # read and discard trailer up to the CRLF terminator - if self._chunk == PARSE_CHUNKED_TRAILERS: + if self._chunk == ChunkState.PARSE_CHUNKED_TRAILERS: pos = chunk.find(SEP) if pos >= 0: self.payload.feed_eof() @@ -440,93 +417,6 @@ def feed_data(self, chunk, SEP=b'\r\n', CHUNK_EXT=b';'): return False, None - def __call__(self, out, buf): - # payload params - length = self.message.headers.get(hdrs.CONTENT_LENGTH, self.length) - if hdrs.SEC_WEBSOCKET_KEY1 in self.message.headers: - length = 8 - - # payload decompression wrapper - if (self.response_with_body and - self.compression and self.message.compression): - out = DeflateBuffer(out, self.message.compression) - - # payload parser - if not self.response_with_body: - # don't parse payload if it's not expected to be received - pass - - elif 'chunked' in self.message.headers.get( - hdrs.TRANSFER_ENCODING, ''): - yield from self.parse_chunked_payload(out, buf) - - elif length is not None: - try: - length = int(length) - except ValueError: - raise errors.InvalidHeader(hdrs.CONTENT_LENGTH) from None - - if length < 0: - raise errors.InvalidHeader(hdrs.CONTENT_LENGTH) - elif length > 0: - yield from self.parse_length_payload(out, buf, length) - else: - if self.readall and getattr(self.message, 'code', 0) != 204: - yield from self.parse_eof_payload(out, buf) - elif getattr(self.message, 'method', None) in ('PUT', 'POST'): - internal_logger.warning( # pragma: no cover - 'Content-Length or Transfer-Encoding header is required') - - out.feed_eof() - - def parse_chunked_payload(self, out, buf): - """Chunked transfer encoding parser.""" - while True: - # read next chunk size - line = yield from buf.readuntil(b'\r\n', 8192) - - i = line.find(b';') - if i >= 0: - line = line[:i] # strip chunk-extensions - else: - line = line.strip() - try: - size = int(line, 16) - except ValueError: - raise errors.TransferEncodingError(line) from None - - if size == 0: # eof marker - break - - # read chunk and feed buffer - while size: - chunk = yield from buf.readsome(size) - out.feed_data(chunk, len(chunk)) - size = size - len(chunk) - - # toss the CRLF at the end of the chunk - yield from buf.skip(2) - - # read and discard trailer up to the CRLF terminator - yield from buf.skipuntil(b'\r\n') - - def parse_length_payload(self, out, buf, length=0): - """Read specified amount of bytes.""" - required = length - while required: - chunk = yield from buf.readsome(required) - out.feed_data(chunk, len(chunk)) - required -= len(chunk) - - def parse_eof_payload(self, out, buf): - """Read all bytes until eof.""" - try: - while True: - chunk = yield from buf.readsome() - out.feed_data(chunk, len(chunk)) - except aiohttp.EofStream: - pass - class DeflateBuffer: """DeflateStream decompress stream and feed data into specified stream.""" @@ -580,7 +470,7 @@ def __init__(self, stream, loop): self._drain_waiter = None if self._stream.available: - self._transport = self._stream + self._transport = self._stream.transport self._stream.available = False else: self._stream.acquire(self.set_transport) @@ -708,7 +598,7 @@ def drain(self): if self._buffer: self._transport.write(b''.join(self._buffer)) self._buffer.clear() - yield from self._transport.drain() + yield from self._stream.drain() else: if self._buffer: if self._drain_waiter is None: diff --git a/aiohttp/pytest_plugin.py b/aiohttp/pytest_plugin.py index 13fce812c45..808067600e2 100644 --- a/aiohttp/pytest_plugin.py +++ b/aiohttp/pytest_plugin.py @@ -52,6 +52,7 @@ def pytest_pyfunc_call(pyfuncitem): def loop(): """Return an instance of the event loop.""" with loop_context() as _loop: + _loop.set_debug(True) yield _loop diff --git a/aiohttp/server.py b/aiohttp/server.py index 6c03265f510..3db645d901b 100644 --- a/aiohttp/server.py +++ b/aiohttp/server.py @@ -1,6 +1,7 @@ """simple HTTP server.""" import asyncio +import asyncio.streams import http.server import socket import traceback @@ -14,6 +15,7 @@ from aiohttp.helpers import TimeService, create_future, ensure_future from aiohttp.log import access_logger, server_logger from aiohttp.protocol import HttpPayloadParser +from aiohttp.streams import StreamWriter __all__ = ('ServerHttpProtocol',) @@ -42,7 +44,7 @@ def tcp_keepalive(server, transport): # pragma: no cover EMPTY_PAYLOAD = streams.EmptyStreamReader() -class ServerHttpProtocol(aiohttp.StreamProtocol): +class ServerHttpProtocol(asyncio.streams.FlowControlMixin, asyncio.Protocol): """Simple HTTP protocol implementation. ServerHttpProtocol handles incoming HTTP request. It reads request line, @@ -108,9 +110,7 @@ def __init__(self, *, loop=None, warnings.warn( 'slow_request_timeout is deprecated', DeprecationWarning) - super().__init__( - loop=loop, - disconnect_error=errors.ClientDisconnectedError, **kwargs) + super().__init__(loop=loop) self._loop = loop if loop is not None else asyncio.get_event_loop() if time_service is not None: @@ -142,6 +142,8 @@ def __init__(self, *, loop=None, max_field_size=max_field_size, max_headers=max_headers) + self.transport = None + self.logger = logger self.debug = debug self.access_log = access_log @@ -204,6 +206,9 @@ def shutdown(self, timeout=15.0): def connection_made(self, transport): super().connection_made(transport) + self.transport = transport + self.writer = StreamWriter(self, transport, self._loop) + if self._tcp_keepalive: tcp_keepalive(self, transport) @@ -213,6 +218,7 @@ def connection_lost(self, exc): super().connection_lost(exc) self._closing = True + self.transport = self.writer = None if self._payload_parser is not None: self._payload_parser.feed_eof() @@ -234,6 +240,9 @@ def set_parser(self, parser): self._payload_parser = parser + def eof_received(self): + pass + def data_received(self, data, SEP=b'\r\n', CONTENT_LENGTH=hdrs.CONTENT_LENGTH, @@ -316,20 +325,22 @@ def data_received(self, data, if ((length is not None and length > 0) or msg.chunked): payload = streams.FlowControlStreamReader( - self.reader, loop=self._loop) - payload_parser = HttpPayloadParser(msg) - - if payload_parser.start(length, payload): + self, loop=self._loop) + payload_parser = HttpPayloadParser( + payload, length=length, + chunked=msg.chunked, method=msg.method, + compression=msg.compression) + if not payload_parser.done: empty_payload = False self._payload_parser = payload_parser elif msg.method == METH_CONNECT: empty_payload = False payload = streams.FlowControlStreamReader( - self.reader, loop=self._loop) - payload_parser = HttpPayloadParser( - msg, readall=True) - payload_parser.start(length, payload) - self._payload_parser = payload_parser + self, loop=self._loop) + self._payload_parser = HttpPayloadParser( + payload, method=msg.method, + compression=msg.compression, + readall=True) else: payload = EMPTY_PAYLOAD @@ -360,7 +371,7 @@ def data_received(self, data, elif self._payload_parser is None and self._conn_upgraded: assert not self._message_lines if data: - super().data_received(data) + self._message_tail += data # feed payload else: diff --git a/aiohttp/streams.py b/aiohttp/streams.py index 897d623b8eb..8230947759f 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -1,6 +1,7 @@ import asyncio import collections import functools +import socket import sys import traceback @@ -8,7 +9,7 @@ from .log import internal_logger __all__ = ( - 'EofStream', 'StreamReader', 'DataQueue', 'ChunksQueue', + 'EofStream', 'StreamReader', 'StreamWriter', 'DataQueue', 'ChunksQueue', 'FlowControlStreamReader', 'FlowControlDataQueue', 'FlowControlChunksQueue') @@ -18,10 +19,119 @@ DEFAULT_LIMIT = 2 ** 16 +if hasattr(socket, 'TCP_CORK'): # pragma: no cover + CORK = socket.TCP_CORK +elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover + CORK = socket.TCP_NOPUSH +else: # pragma: no cover + CORK = None + + class EofStream(Exception): """eof stream indication.""" +class StreamWriter: + + def __init__(self, protocol, transport, loop): + self._protocol = protocol + self._loop = loop + self._tcp_nodelay = False + self._tcp_cork = False + self._socket = transport.get_extra_info('socket') + self._waiters = [] + self.available = True + self.transport = transport + + def acquire(self, cb): + if self.available: + self.available = False + cb(self.transport) + else: + self._waiters.append(cb) + + def release(self): + if self._waiters: + self.available = False + cb = self._waiters.pop(0) + cb(self) + else: + self.available = True + + def is_connected(self): + return self.transport is not None + + @property + def tcp_nodelay(self): + return self._tcp_nodelay + + def set_tcp_nodelay(self, value): + value = bool(value) + if self._tcp_nodelay == value: + return + if self._socket is None: + return + if self._socket.family not in (socket.AF_INET, socket.AF_INET6): + return + + # socket may be closed already, on windows OSError get raised + try: + if self._tcp_cork: + if CORK is not None: # pragma: no branch + self._socket.setsockopt(socket.IPPROTO_TCP, CORK, False) + self._tcp_cork = False + + self._socket.setsockopt( + socket.IPPROTO_TCP, socket.TCP_NODELAY, value) + self._tcp_nodelay = value + except OSError: + pass + + @property + def tcp_cork(self): + return self._tcp_cork + + def set_tcp_cork(self, value): + value = bool(value) + if self._tcp_cork == value: + return + if self._socket is None: + return + if self._socket.family not in (socket.AF_INET, socket.AF_INET6): + return + + try: + if self._tcp_nodelay: + self._socket.setsockopt( + socket.IPPROTO_TCP, socket.TCP_NODELAY, False) + self._tcp_nodelay = False + if CORK is not None: # pragma: no branch + self._socket.setsockopt(socket.IPPROTO_TCP, CORK, value) + self._tcp_cork = value + except OSError: + pass + + @asyncio.coroutine + def drain(self): + """Flush the write buffer. + + The intended use is to write + + w.write(data) + yield from w.drain() + """ + if self.transport is not None: + if self.transport.is_closing(): + # Yield to the event loop so connection_lost() may be + # called. Without this, _drain_helper() would return + # immediately, and code that calls + # write(...); yield from drain() + # in a loop would never call connection_lost(), so it + # would not see an error when the socket is closed. + yield + yield from self._protocol._drain_helper() + + if PY_35: class AsyncStreamIterator: @@ -533,7 +643,7 @@ def __init__(self, stream, limit=DEFAULT_LIMIT, *args, **kwargs): self._allow_pause = False # resume transport reading - if stream.paused: + if stream._paused: try: self._stream.transport.resume_reading() except (AttributeError, NotImplementedError): @@ -543,14 +653,14 @@ def __init__(self, stream, limit=DEFAULT_LIMIT, *args, **kwargs): self._allow_pause = True def _check_buffer_size(self): - if self._stream.paused: + if self._stream._paused: if self._buffer_size < self._b_limit: try: self._stream.transport.resume_reading() except (AttributeError, NotImplementedError): pass else: - self._stream.paused = False + self._stream._paused = False else: if self._buffer_size > self._b_limit: try: @@ -558,21 +668,21 @@ def _check_buffer_size(self): except (AttributeError, NotImplementedError): pass else: - self._stream.paused = True + self._stream._paused = True def feed_data(self, data, size=0): has_waiter = self._waiter is not None and not self._waiter.cancelled() super().feed_data(data) - if (self._allow_pause and not self._stream.paused and + if (self._allow_pause and not self._stream._paused and not has_waiter and self._buffer_size > self._b_limit): try: self._stream.transport.pause_reading() except (AttributeError, NotImplementedError): pass else: - self._stream.paused = True + self._stream._paused = True @maybe_resume @asyncio.coroutine @@ -612,13 +722,13 @@ def __init__(self, stream, *, limit=DEFAULT_LIMIT, loop=None): self._allow_pause = False # resume transport reading - if stream.paused: + if stream._paused: try: self._stream.transport.resume_reading() except (AttributeError, NotImplementedError): pass else: - self._stream.paused = False + self._stream._paused = False self._allow_pause = True def feed_data(self, data, size): @@ -626,27 +736,27 @@ def feed_data(self, data, size): super().feed_data(data, size) - if (self._allow_pause and not self._stream.paused and + if (self._allow_pause and not self._stream._paused and not has_waiter and self._size > self._limit): try: self._stream.transport.pause_reading() except (AttributeError, NotImplementedError): pass else: - self._stream.paused = True + self._stream._paused = True @asyncio.coroutine def read(self): result = yield from super().read() - if self._stream.paused: + if self._stream._paused: if self._size < self._limit: try: self._stream.transport.resume_reading() except (AttributeError, NotImplementedError): pass else: - self._stream.paused = False + self._stream._paused = False else: if self._size > self._limit: try: @@ -654,7 +764,7 @@ def read(self): except (AttributeError, NotImplementedError): pass else: - self._stream.paused = True + self._stream._paused = True return result diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index f966565250b..b35a20d2201 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -152,7 +152,7 @@ def connection_lost(self, handler, exc=None): def _make_request(self, message, payload, protocol): return BaseRequest( message, payload, - protocol.transport, protocol.reader, protocol.writer, + protocol.transport, protocol, protocol.writer, protocol.time_service, protocol._request_handler, loop=self._loop) @asyncio.coroutine diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index deca823d363..c6566e3a590 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -107,7 +107,7 @@ def prepare(self, request): def _pre_start(self, request): try: status, headers, parser, writer, protocol = do_handshake( - request.method, request.headers, request.transport_pair[1], + request.method, request.headers, request._writer, self._protocols) except HttpProcessingError as err: if err.code == 405: @@ -133,7 +133,7 @@ def _post_start(self, request, parser, protocol, writer): self._loop = request.app.loop self._writer = writer self._reader = FlowControlDataQueue( - request._reader.reader, limit=2 ** 16, loop=self._loop) + request._reader, limit=2 ** 16, loop=self._loop) request._reader.set_parser(WebSocketReader(self._reader)) def can_prepare(self, request): diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index f15e537dcde..896b8d1f253 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -611,7 +611,6 @@ def handler(request): yield from resp.read() - @asyncio.coroutine def test_readline_error_on_conn_close(loop, test_client): @@ -1644,9 +1643,9 @@ def test_request_conn_error(loop): yield from client.close() -#@pytest.mark.xfail +@pytest.mark.xfail @asyncio.coroutine -def _test_broken_connection(loop, test_client): +def test_broken_connection(loop, test_client): @asyncio.coroutine def handler(request): request.transport.close() diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 5d0e8a6f559..b8a57dea757 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -47,6 +47,7 @@ def write(chunk): transport.acquire.side_effect = acquire transport.write.side_effect = write transport.writer.write.side_effect = write + transport.writer.transport.write.side_effect = write transport.writer.drain.return_value = () transport.drain.return_value = () @@ -848,7 +849,7 @@ def exc(): helpers.ensure_future(exc(), loop=loop) conn = mock.Mock(acquire=acquire) - resp = req.send(conn) + req.send(conn) yield from req._writer # assert conn.close.called assert conn.protocol.set_exception.called @@ -891,7 +892,7 @@ def exc(): helpers.ensure_future(exc(), loop=loop) connection, _ = transport - resp = req.send(connection) + req.send(connection) yield from req._writer # assert connection.close.called assert connection.protocol.set_exception.called diff --git a/tests/test_flowcontrol_streams.py b/tests/test_flowcontrol_streams.py index e6e20fbc9cd..c6cb1c56d4a 100644 --- a/tests/test_flowcontrol_streams.py +++ b/tests/test_flowcontrol_streams.py @@ -8,7 +8,7 @@ class TestFlowControlStreamReader(unittest.TestCase): def setUp(self): - self.stream = mock.Mock(paused=False) + self.stream = mock.Mock(_paused=False) self.transp = self.stream.transport self.loop = asyncio.new_event_loop() asyncio.set_event_loop(None) @@ -24,7 +24,7 @@ def _make_one(self, allow_pause=True, *args, **kwargs): def test_read(self): r = self._make_one() - r._stream.paused = True + r._stream._paused = True r.feed_data(b'da', 2) res = self.loop.run_until_complete(r.read(1)) self.assertEqual(res, b'd') @@ -33,7 +33,7 @@ def test_read(self): def test_pause_on_read(self): r = self._make_one() r.feed_data(b'test', 4) - r._stream.paused = False + r._stream._paused = False res = self.loop.run_until_complete(r.read(1)) self.assertEqual(res, b't') @@ -41,7 +41,7 @@ def test_pause_on_read(self): def test_readline(self): r = self._make_one() - r._stream.paused = True + r._stream._paused = True r.feed_data(b'data\n', 5) res = self.loop.run_until_complete(r.readline()) self.assertEqual(res, b'data\n') @@ -49,7 +49,7 @@ def test_readline(self): def test_readany(self): r = self._make_one() - r._stream.paused = True + r._stream._paused = True r.feed_data(b'data', 4) res = self.loop.run_until_complete(r.readany()) self.assertEqual(res, b'data') @@ -57,7 +57,7 @@ def test_readany(self): def test_readexactly(self): r = self._make_one() - r._stream.paused = True + r._stream._paused = True r.feed_data(b'data', 4) res = self.loop.run_until_complete(r.readexactly(3)) self.assertEqual(res, b'dat') @@ -65,77 +65,77 @@ def test_readexactly(self): def test_feed_data(self): r = self._make_one() - r._stream.paused = False + r._stream._paused = False r.feed_data(b'datadata', 8) self.assertTrue(self.transp.pause_reading.called) def test_feed_data_no_allow_pause(self): r = self._make_one() r._allow_pause = False - r._stream.paused = False + r._stream._paused = False r.feed_data(b'datadata', 8) self.assertFalse(self.transp.pause_reading.called) def test_read_nowait(self): r = self._make_one() - r._stream.paused = False + r._stream._paused = False r.feed_data(b'data1', 5) r.feed_data(b'data2', 5) r.feed_data(b'data3', 5) - self.assertTrue(self.stream.paused) + self.assertTrue(self.stream._paused) res = self.loop.run_until_complete(r.read(5)) self.assertTrue(res == b'data1') # _buffer_size > _buffer_limit self.assertTrue(self.transp.pause_reading.call_count == 1) self.assertTrue(self.transp.resume_reading.call_count == 0) - self.assertTrue(self.stream.paused) + self.assertTrue(self.stream._paused) - r._stream.paused = False + r._stream._paused = False res = r.read_nowait(5) self.assertTrue(res == b'data2') # _buffer_size > _buffer_limit self.assertTrue(self.transp.pause_reading.call_count == 2) self.assertTrue(self.transp.resume_reading.call_count == 0) - self.assertTrue(self.stream.paused) + self.assertTrue(self.stream._paused) res = r.read_nowait(5) self.assertTrue(res == b'data3') # _buffer_size < _buffer_limit self.assertTrue(self.transp.pause_reading.call_count == 2) self.assertTrue(self.transp.resume_reading.call_count == 1) - self.assertTrue(not self.stream.paused) + self.assertTrue(not self.stream._paused) res = r.read_nowait(5) self.assertTrue(res == b'') # _buffer_size < _buffer_limit self.assertTrue(self.transp.pause_reading.call_count == 2) self.assertTrue(self.transp.resume_reading.call_count == 1) - self.assertTrue(not self.stream.paused) + self.assertTrue(not self.stream._paused) def test_rudimentary_transport(self): self.transp.resume_reading.side_effect = NotImplementedError() self.transp.pause_reading.side_effect = NotImplementedError() - self.stream.paused = True + self.stream._paused = True r = self._make_one() self.assertTrue(self.transp.pause_reading.call_count == 0) self.assertTrue(self.transp.resume_reading.call_count == 1) - self.assertTrue(self.stream.paused) + self.assertTrue(self.stream._paused) r.feed_data(b'data', 4) res = self.loop.run_until_complete(r.read(4)) self.assertTrue(self.transp.pause_reading.call_count == 0) self.assertTrue(self.transp.resume_reading.call_count == 2) - self.assertTrue(self.stream.paused) + self.assertTrue(self.stream._paused) self.assertTrue(res == b'data') - self.stream.paused = False + self.stream._paused = False r.feed_data(b'data', 4) res = self.loop.run_until_complete(r.read(1)) self.assertTrue(self.transp.pause_reading.call_count == 2) self.assertTrue(self.transp.resume_reading.call_count == 2) - self.assertTrue(not self.stream.paused) + self.assertTrue(not self.stream._paused) self.assertTrue(res == b'd') @@ -143,11 +143,11 @@ class FlowControlMixin: def test_resume_on_init(self): stream = mock.Mock() - stream.paused = True + stream._paused = True streams.FlowControlDataQueue(stream, limit=1, loop=self.loop) self.assertTrue(stream.transport.resume_reading.called) - self.assertFalse(stream.paused) + self.assertFalse(stream._paused) def test_no_transport_in_init(self): stream = mock.Mock() @@ -189,38 +189,38 @@ def cb(): def test_resume_on_read(self): out = self._make_one() out.feed_data(object(), 100) - self.assertTrue(self.stream.paused) + self.assertTrue(self.stream._paused) self.loop.run_until_complete(out.read()) self.assertTrue(self.stream.transport.resume_reading.called) - self.assertFalse(self.stream.paused) + self.assertFalse(self.stream._paused) def test_resume_on_read_no_transport(self): item = object() out = self._make_one() out.feed_data(item, 100) - self.assertTrue(self.stream.paused) + self.assertTrue(self.stream._paused) self.stream.transport = None res = self.loop.run_until_complete(out.read()) self.assertIs(res, item) - self.assertTrue(self.stream.paused) + self.assertTrue(self.stream._paused) def test_no_resume_on_read(self): out = self._make_one() out.feed_data(object(), 100) out.feed_data(object(), 100) out.feed_data(object(), 100) - self.assertTrue(self.stream.paused) + self.assertTrue(self.stream._paused) self.stream.transport.reset_mock() self.loop.run_until_complete(out.read()) self.assertFalse(self.stream.transport.resume_reading.called) - self.assertTrue(self.stream.paused) + self.assertTrue(self.stream._paused) def test_pause_on_read(self): out = self._make_one() @@ -228,12 +228,12 @@ def test_pause_on_read(self): out._buffer.append((object(), 100)) out._buffer.append((object(), 100)) out._size = 300 - self.stream.paused = False + self.stream._paused = False self.loop.run_until_complete(out.read()) self.assertTrue(self.stream.transport.pause_reading.called) - self.assertTrue(self.stream.paused) + self.assertTrue(self.stream._paused) def test_no_pause_on_read(self): item = object() diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index 1c968fd9cb6..b41326e2e76 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -196,223 +196,115 @@ def setUp(self): def test_parse_eof_payload(self): out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(None).parse_eof_payload(out, buf) - next(p) - p.send(b'data') - try: - p.throw(aiohttp.EofStream()) - except StopIteration: - pass + p = protocol.HttpPayloadParser(out, readall=True) + p.feed_data(b'data') + p.feed_eof() + self.assertTrue(out.is_eof()) self.assertEqual([(bytearray(b'data'), 4)], list(out._buffer)) def test_parse_length_payload(self): out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(None).parse_length_payload(out, buf, 4) - next(p) - p.send(b'da') - p.send(b't') - try: - p.send(b'aline') - except StopIteration: - pass + p = protocol.HttpPayloadParser(out, length=4) + p.feed_data(b'da') + p.feed_data(b't') + eof, tail = p.feed_data(b'aline') self.assertEqual(3, len(out._buffer)) self.assertEqual(b'data', b''.join(d for d, _ in out._buffer)) - self.assertEqual(b'line', bytes(buf)) + self.assertEqual(b'line', tail) - def test_parse_length_payload_eof(self): + def _test_parse_length_payload_eof(self): out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(None).parse_length_payload(out, buf, 4) - next(p) - p.send(b'da') - self.assertRaises(aiohttp.EofStream, p.throw, aiohttp.EofStream) + + p = protocol.HttpPayloadParser(None) + p.start(4, out) + + p.feed_data(b'da') + p.feed_eof() def test_parse_chunked_payload(self): out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(None).parse_chunked_payload(out, buf) - next(p) - try: - p.send(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') - except StopIteration: - pass + p = protocol.HttpPayloadParser(out, chunked=True) + eof, tail = p.feed_data(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') self.assertEqual(b'dataline', b''.join(d for d, _ in out._buffer)) - self.assertEqual(b'', bytes(buf)) + self.assertEqual(b'', tail) + self.assertTrue(eof) + self.assertTrue(out.is_eof()) def test_parse_chunked_payload_chunks(self): out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(None).parse_chunked_payload(out, buf) - next(p) - p.send(b'4\r\ndata\r') - p.send(b'\n4') - p.send(b'\r') - p.send(b'\n') - p.send(b'line\r\n0\r\n') - self.assertRaises(StopIteration, p.send, b'test\r\n') + p = protocol.HttpPayloadParser(out, chunked=True) + p.feed_data(b'4\r\ndata\r') + p.feed_data(b'\n4') + p.feed_data(b'\r') + p.feed_data(b'\n') + p.feed_data(b'line\r\n0\r\n') + eof, tail = p.feed_data(b'test\r\n') self.assertEqual(b'dataline', b''.join(d for d, _ in out._buffer)) + self.assertTrue(eof) - def test_parse_chunked_payload_incomplete(self): + def test_parse_chunked_payload_chunk_extension(self): out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(None).parse_chunked_payload(out, buf) - next(p) - p.send(b'4\r\ndata\r\n') - self.assertRaises(aiohttp.EofStream, p.throw, aiohttp.EofStream) - - def test_parse_chunked_payload_extension(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(None).parse_chunked_payload(out, buf) - next(p) - try: - p.send(b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') - except StopIteration: - pass + p = protocol.HttpPayloadParser(out, chunked=True) + eof, tail = p.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') self.assertEqual(b'dataline', b''.join(d for d, _ in out._buffer)) + self.assertTrue(eof) def test_parse_chunked_payload_size_error(self): out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(None).parse_chunked_payload(out, buf) - next(p) - self.assertRaises(errors.TransferEncodingError, p.send, b'blah\r\n') - - def test_http_payload_parser_length_broken(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), - CIMultiDict([('CONTENT-LENGTH', 'qwe')]), - [(b'CONTENT-LENGTH', b'qwe')], - None, None, False, False) - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(msg)(out, buf) - self.assertRaises(errors.InvalidHeader, next, p) - - def test_http_payload_parser_length_wrong(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), - CIMultiDict([('CONTENT-LENGTH', '-1')]), - [(b'CONTENT-LENGTH', b'-1')], - None, None, False, False) - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(msg)(out, buf) - self.assertRaises(errors.InvalidHeader, next, p) + p = protocol.HttpPayloadParser(out, chunked=True) + self.assertRaises( + errors.TransferEncodingError, p.feed_data, b'blah\r\n') + self.assertIsInstance(out.exception(), errors.TransferEncodingError) def test_http_payload_parser_length(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), - CIMultiDict([('CONTENT-LENGTH', '2')]), - [(b'CONTENT-LENGTH', b'2')], - None, None, False, False) out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(msg)(out, buf) - next(p) - try: - p.send(b'1245') - except StopIteration: - pass + p = protocol.HttpPayloadParser(out, length=2) + eof, tail = p.feed_data(b'1245') + self.assertTrue(eof) self.assertEqual(b'12', b''.join(d for d, _ in out._buffer)) - self.assertEqual(b'45', bytes(buf)) - - def test_http_payload_parser_no_length(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), CIMultiDict(), [], None, None, False, False) - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(msg, readall=False)(out, buf) - self.assertRaises(StopIteration, next, p) - self.assertEqual(b'', b''.join(out._buffer)) - self.assertTrue(out._eof) + self.assertEqual(b'45', tail) _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) def test_http_payload_parser_deflate(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), - CIMultiDict([('CONTENT-LENGTH', str(len(self._COMPRESSED)))]), - [(b'CONTENT-LENGTH', str(len(self._COMPRESSED)).encode('ascii'))], - None, 'deflate', False, False) - + length = len(self._COMPRESSED) out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(msg)(out, buf) - next(p) - self.assertRaises(StopIteration, p.send, self._COMPRESSED) + p = protocol.HttpPayloadParser( + out, length=length, compression='deflate') + p.feed_data(self._COMPRESSED) self.assertEqual(b'data', b''.join(d for d, _ in out._buffer)) - - def test_http_payload_parser_deflate_disabled(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), - CIMultiDict([('CONTENT-LENGTH', len(self._COMPRESSED))]), - [(b'CONTENT-LENGTH', str(len(self._COMPRESSED)).encode('ascii'))], - None, 'deflate', False, False) - - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(msg, compression=False)(out, buf) - next(p) - self.assertRaises(StopIteration, p.send, self._COMPRESSED) - self.assertEqual(self._COMPRESSED, b''.join(d for d, _ in out._buffer)) - - def test_http_payload_parser_websocket(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), - CIMultiDict([('SEC-WEBSOCKET-KEY1', '13')]), - [(b'SEC-WEBSOCKET-KEY1', b'13')], - None, None, False, False) - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(msg)(out, buf) - next(p) - self.assertRaises(StopIteration, p.send, b'1234567890') - self.assertEqual(b'12345678', b''.join(d for d, _ in out._buffer)) + self.assertTrue(out.is_eof()) def test_http_payload_parser_chunked(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), - CIMultiDict([('TRANSFER-ENCODING', 'chunked')]), - [(b'TRANSFER-ENCODING', b'chunked')], - None, None, False, True) out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(msg)(out, buf) - next(p) - self.assertRaises(StopIteration, p.send, - b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + parser = protocol.HttpPayloadParser(out, chunked=True) + assert not parser.done + + parser.feed_data(b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') self.assertEqual(b'dataline', b''.join(d for d, _ in out._buffer)) + self.assertTrue(out.is_eof()) def test_http_payload_parser_eof(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), CIMultiDict(), [], None, None, False, False) out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(msg, readall=True)(out, buf) - next(p) - p.send(b'data') - p.send(b'line') - self.assertRaises(StopIteration, p.throw, aiohttp.EofStream()) + p = protocol.HttpPayloadParser(out, readall=True) + assert not p.done + + p.feed_data(b'data') + p.feed_data(b'line') + p.feed_eof() self.assertEqual(b'dataline', b''.join(d for d, _ in out._buffer)) + self.assertTrue(out.is_eof()) def test_http_payload_parser_length_zero(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), - CIMultiDict([('CONTENT-LENGTH', '0')]), - [(b'CONTENT-LENGTH', b'0')], - None, None, False, False) out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpPayloadParser(msg)(out, buf) - self.assertRaises(StopIteration, next, p) - self.assertEqual(b'', b''.join(out._buffer)) + p = protocol.HttpPayloadParser(out, length=0) + self.assertTrue(p.done) + self.assertTrue(out.is_eof()) class TestParseRequest(unittest.TestCase): @@ -421,43 +313,26 @@ def setUp(self): self.stream = mock.Mock() asyncio.set_event_loop(None) - def test_http_request_parser_max_headers(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpRequestParser(8190, 20, 8190)(out, buf) - next(p) + def _test_http_request_parser_max_headers(self): + p = protocol.HttpRequestParser(8190, 20, 8190) self.assertRaises( errors.LineTooLong, - p.send, - b'get /path HTTP/1.1\r\ntest: line\r\ntest2: data\r\n\r\n') + p.parse_message, + b'get /path HTTP/1.1\r\ntest: line\r\ntest2: data\r\n\r\n' + .split(b'\r\n')) def test_http_request_parser(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpRequestParser()(out, buf) - next(p) - try: - p.send(b'get /path HTTP/1.1\r\n\r\n') - except StopIteration: - pass - result = out._buffer[0][0] + p = protocol.HttpRequestParser() + result = p.parse_message(b'get /path HTTP/1.1\r\n\r\n'.split(b'\r\n')) self.assertEqual( ('GET', '/path', (1, 1), CIMultiDict(), [], False, None, False, False), result) def test_http_request_parser_utf8(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpRequestParser()(out, buf) - next(p) + p = protocol.HttpRequestParser() msg = 'get /path HTTP/1.1\r\nx-test:тест\r\n\r\n'.encode('utf-8') - try: - p.send(msg) - except StopIteration: - pass - result, length = out._buffer[0] - self.assertEqual(len(msg), length) + result = p.parse_message(msg.split(b'\r\n')) self.assertEqual( ('GET', '/path', (1, 1), CIMultiDict([('X-TEST', 'тест')]), @@ -466,17 +341,9 @@ def test_http_request_parser_utf8(self): result) def test_http_request_parser_non_utf8(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpRequestParser()(out, buf) - next(p) + p = protocol.HttpRequestParser() msg = 'get /path HTTP/1.1\r\nx-test:тест\r\n\r\n'.encode('cp1251') - try: - p.send(msg) - except StopIteration: - pass - result, length = out._buffer[0] - self.assertEqual(len(msg), length) + result = p.parse_message(msg.split(b'\r\n')) self.assertEqual( ('GET', '/path', (1, 1), CIMultiDict([('X-TEST', 'тест'.encode('cp1251').decode( @@ -487,56 +354,34 @@ def test_http_request_parser_non_utf8(self): def test_http_request_parser_eof(self): # HttpRequestParser does fail on EofStream() - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpRequestParser()(out, buf) - next(p) - p.send(b'get /path HTTP/1.1\r\n') - try: - p.throw(aiohttp.EofStream()) - except aiohttp.EofStream: - pass - self.assertFalse(out._buffer) + p = protocol.HttpRequestParser() + p.parse_message(b'get /path HTTP/1.1\r\n'.split(b'\r\n')) def test_http_request_parser_two_slashes(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpRequestParser()(out, buf) - next(p) - try: - p.send(b'get //path HTTP/1.1\r\n\r\n') - except StopIteration: - pass + p = protocol.HttpRequestParser() + result = p.parse_message( + b'get //path HTTP/1.1\r\n\r\n'.split(b'\r\n')) self.assertEqual( ('GET', '//path', (1, 1), CIMultiDict(), [], False, None, False, False), - out._buffer[0][0]) + result) def test_http_request_parser_bad_status_line(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpRequestParser()(out, buf) - next(p) + p = protocol.HttpRequestParser() self.assertRaises( - errors.BadStatusLine, p.send, b'\r\n\r\n') + errors.BadStatusLine, p.parse_message, b'\r\n\r\n'.split(b'\r\n')) def test_http_request_parser_bad_method(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpRequestParser()(out, buf) - next(p) + p = protocol.HttpRequestParser() self.assertRaises( - errors.BadStatusLine, - p.send, b'!12%()+=~$ /get HTTP/1.1\r\n\r\n') + errors.BadStatusLine, p.parse_message, + b'!12%()+=~$ /get HTTP/1.1\r\n\r\n'.split(b'\r\n')) def test_http_request_parser_bad_version(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpRequestParser()(out, buf) - next(p) + p = protocol.HttpRequestParser() self.assertRaises( errors.BadStatusLine, - p.send, b'GET //get HT/11\r\n\r\n') + p.parse_message, b'GET //get HT/11\r\n\r\n'.split(b'\r\n')) class TestParseResponse(unittest.TestCase): @@ -546,99 +391,59 @@ def setUp(self): asyncio.set_event_loop(None) def test_http_response_parser_utf8(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpResponseParser()(out, buf) - next(p) + p = protocol.HttpResponseParser() msg = 'HTTP/1.1 200 Ok\r\nx-test:тест\r\n\r\n'.encode('utf-8') - try: - p.send(msg) - except StopIteration: - pass - v, s, r, h = out._buffer[0][0][:4] - self.assertEqual(v, (1, 1)) - self.assertEqual(s, 200) - self.assertEqual(r, 'Ok') - self.assertEqual(h, CIMultiDict([('X-TEST', 'тест')])) + result = p.parse_message(msg.split(b'\r\n')) + self.assertEqual(result.version, (1, 1)) + self.assertEqual(result.code, 200) + self.assertEqual(result.reason, 'Ok') + self.assertEqual(result.headers, CIMultiDict([('X-TEST', 'тест')])) def test_http_response_parser_bad_status_line(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpResponseParser()(out, buf) - next(p) - self.assertRaises(errors.BadStatusLine, p.send, b'\r\n\r\n') + p = protocol.HttpResponseParser() + self.assertRaises( + errors.BadStatusLine, p.parse_message, b'\r\n\r\n'.split(b'\r\n')) - def test_http_response_parser_bad_status_line_too_long(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() + def _test_http_response_parser_bad_status_line_too_long(self): p = protocol.HttpResponseParser( - max_headers=2, max_line_size=2)(out, buf) - next(p) + max_headers=2, max_line_size=2) self.assertRaises( - errors.LineTooLong, p.send, b'HTTP/1.1 200 Ok\r\n\r\n') - - def test_http_response_parser_bad_status_line_eof(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpResponseParser()(out, buf) - next(p) - self.assertRaises(aiohttp.EofStream, p.throw, aiohttp.EofStream()) + errors.LineTooLong, + p.parse_message, b'HTTP/1.1 200 Ok\r\n\r\n'.split(b'\r\n')) def test_http_response_parser_bad_version(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpResponseParser()(out, buf) - next(p) + p = protocol.HttpResponseParser() with self.assertRaises(errors.BadStatusLine) as cm: - p.send(b'HT/11 200 Ok\r\n\r\n') + p.parse_message(b'HT/11 200 Ok\r\n\r\n'.split(b'\r\n')) self.assertEqual('HT/11 200 Ok', cm.exception.args[0]) def test_http_response_parser_no_reason(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpResponseParser()(out, buf) - next(p) - try: - p.send(b'HTTP/1.1 200\r\n\r\n') - except StopIteration: - pass - v, s, r = out._buffer[0][0][:3] - self.assertEqual(v, (1, 1)) - self.assertEqual(s, 200) - self.assertEqual(r, '') + p = protocol.HttpResponseParser() + result = p.parse_message(b'HTTP/1.1 200\r\n\r\n'.split(b'\r\n')) + self.assertEqual(result.version, (1, 1)) + self.assertEqual(result.code, 200) + self.assertEqual(result.reason, '') def test_http_response_parser_bad(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpResponseParser()(out, buf) - next(p) + p = protocol.HttpResponseParser() with self.assertRaises(errors.BadStatusLine) as cm: - p.send(b'HTT/1\r\n\r\n') + p.parse_message(b'HTT/1\r\n\r\n'.split(b'\r\n')) self.assertIn('HTT/1', str(cm.exception)) def test_http_response_parser_code_under_100(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpResponseParser()(out, buf) - next(p) + p = protocol.HttpResponseParser() with self.assertRaises(errors.BadStatusLine) as cm: - p.send(b'HTTP/1.1 99 test\r\n\r\n') + p.parse_message(b'HTTP/1.1 99 test\r\n\r\n'.split(b'\r\n')) self.assertIn('HTTP/1.1 99 test', str(cm.exception)) def test_http_response_parser_code_above_999(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpResponseParser()(out, buf) - next(p) + p = protocol.HttpResponseParser() with self.assertRaises(errors.BadStatusLine) as cm: - p.send(b'HTTP/1.1 9999 test\r\n\r\n') + p.parse_message(b'HTTP/1.1 9999 test\r\n\r\n'.split(b'\r\n')) self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) def test_http_response_parser_code_not_int(self): - out = aiohttp.FlowControlDataQueue(self.stream) - buf = aiohttp.ParserBuffer() - p = protocol.HttpResponseParser()(out, buf) - next(p) + p = protocol.HttpResponseParser() with self.assertRaises(errors.BadStatusLine) as cm: - p.send(b'HTTP/1.1 ttt test\r\n\r\n') + p.parse_message(b'HTTP/1.1 ttt test\r\n\r\n'.split(b'\r\n')) self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) diff --git a/tests/test_parser_buffer.py b/tests/test_parser_buffer.py deleted file mode 100644 index cbff6df6c8b..00000000000 --- a/tests/test_parser_buffer.py +++ /dev/null @@ -1,252 +0,0 @@ -from unittest import mock - -import pytest - -from aiohttp import errors, parsers - - -@pytest.fixture -def stream(): - return mock.Mock() - - -@pytest.fixture -def buf(): - return parsers.ParserBuffer() - - -def test_feed_data(buf): - buf.feed_data(b'') - assert len(buf) == 0 - - buf.feed_data(b'data') - assert len(buf) == 4 - assert bytes(buf), b'data' - - -def test_feed_data_after_exception(buf): - buf.feed_data(b'data') - - exc = ValueError() - buf.set_exception(exc) - buf.feed_data(b'more') - assert len(buf) == 4 - assert bytes(buf) == b'data' - - -def test_read_exc(buf): - p = buf.read(3) - next(p) - p.send(b'1') - - exc = ValueError() - buf.set_exception(exc) - assert buf.exception() is exc - with pytest.raises(ValueError): - p.send(b'1') - - -def test_read_exc_multiple(buf): - p = buf.read(3) - next(p) - p.send(b'1') - - exc = ValueError() - buf.set_exception(exc) - assert buf.exception() is exc - - p = buf.read(3) - with pytest.raises(ValueError): - next(p) - - -def test_read(buf): - p = buf.read(3) - next(p) - p.send(b'1') - try: - p.send(b'234') - except StopIteration as exc: - res = exc.value - - assert res == b'123' - assert b'4' == bytes(buf) - - -def test_readsome(buf): - p = buf.readsome(3) - next(p) - try: - p.send(b'1') - except StopIteration as exc: - res = exc.value - assert res == b'1' - - p = buf.readsome(2) - next(p) - try: - p.send(b'234') - except StopIteration as exc: - res = exc.value - assert res == b'23' - assert b'4' == bytes(buf) - - -def test_readsome_exc(buf): - buf.set_exception(ValueError()) - - p = buf.readsome(3) - with pytest.raises(ValueError): - next(p) - - -def test_wait(buf): - p = buf.wait(3) - next(p) - p.send(b'1') - try: - p.send(b'234') - except StopIteration as exc: - res = exc.value - - assert res == b'123' - assert b'1234' == bytes(buf) - - -def test_wait_exc(buf): - buf.set_exception(ValueError()) - - p = buf.wait(3) - with pytest.raises(ValueError): - next(p) - - -def test_skip(buf): - p = buf.skip(3) - next(p) - p.send(b'1') - try: - p.send(b'234') - except StopIteration as exc: - res = exc.value - - assert res is None - assert b'4' == bytes(buf) - - -def test_skip_exc(buf): - buf.set_exception(ValueError()) - p = buf.skip(3) - with pytest.raises(ValueError): - next(p) - - -def test_readuntil_limit(buf): - p = buf.readuntil(b'\n', 4) - next(p) - p.send(b'1') - p.send(b'234') - with pytest.raises(errors.LineLimitExceededParserError): - p.send(b'5') - - -def test_readuntil_limit2(buf): - p = buf.readuntil(b'\n', 4) - next(p) - with pytest.raises(errors.LineLimitExceededParserError): - p.send(b'12345\n6') - - -def test_readuntil_limit3(buf): - p = buf.readuntil(b'\n', 4) - next(p) - with pytest.raises(errors.LineLimitExceededParserError): - p.send(b'12345\n6') - - -def test_readuntil(buf): - p = buf.readuntil(b'\n', 4) - next(p) - p.send(b'123') - try: - p.send(b'\n456') - except StopIteration as exc: - res = exc.value - - assert res == b'123\n' - assert b'456' == bytes(buf) - - -def test_readuntil_exc(buf): - buf.set_exception(ValueError()) - p = buf.readuntil(b'\n', 4) - with pytest.raises(ValueError): - next(p) - - -def test_waituntil_limit(buf): - p = buf.waituntil(b'\n', 4) - next(p) - p.send(b'1') - p.send(b'234') - with pytest.raises(errors.LineLimitExceededParserError): - p.send(b'5') - - -def test_waituntil_limit2(buf): - p = buf.waituntil(b'\n', 4) - next(p) - with pytest.raises(errors.LineLimitExceededParserError): - p.send(b'12345\n6') - - -def test_waituntil_limit3(buf): - p = buf.waituntil(b'\n', 4) - next(p) - with pytest.raises(errors.LineLimitExceededParserError): - p.send(b'12345\n6') - - -def test_waituntil(buf): - p = buf.waituntil(b'\n', 4) - next(p) - p.send(b'123') - try: - p.send(b'\n456') - except StopIteration as exc: - res = exc.value - - assert res == b'123\n' - assert b'123\n456' == bytes(buf) - - -def test_waituntil_exc(buf): - buf.set_exception(ValueError()) - p = buf.waituntil(b'\n', 4) - with pytest.raises(ValueError): - next(p) - - -def test_skipuntil(buf): - p = buf.skipuntil(b'\n') - next(p) - p.send(b'123') - try: - p.send(b'\n456\n') - except StopIteration: - pass - assert b'456\n' == bytes(buf) - - p = buf.skipuntil(b'\n') - try: - next(p) - except StopIteration: - pass - assert b'' == bytes(buf) - - -def test_skipuntil_exc(buf): - buf.set_exception(ValueError()) - p = buf.skipuntil(b'\n') - with pytest.raises(ValueError): - next(p) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index fa219ed0dda..5a0e411e376 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -10,42 +10,42 @@ @pytest.fixture -def transport(): - transport = mock.Mock() +def stream(): + stream = mock.Mock() def acquire(cb): - cb(transport) + cb(stream) - transport.acquire = acquire - transport.drain.return_value = () - return transport + stream.acquire = acquire + stream.drain.return_value = () + return stream compressor = zlib.compressobj(wbits=-zlib.MAX_WBITS) COMPRESSED = b''.join([compressor.compress(b'data'), compressor.flush()]) -def test_start_request(transport): +def test_start_request(stream): msg = protocol.Request( - transport, 'GET', '/index.html', close=True) + stream, 'GET', '/index.html', close=True) - assert msg._transport is transport + assert msg._transport is stream.transport assert msg.closing assert msg.status_line == 'GET /index.html HTTP/1.1\r\n' -def test_start_response(transport): - msg = protocol.Response(transport, 200, close=True) +def test_start_response(stream): + msg = protocol.Response(stream, 200, close=True) - assert msg._transport is transport + assert msg._transport is stream.transport assert msg.status == 200 assert msg.reason == "OK" assert msg.closing assert msg.status_line == 'HTTP/1.1 200 OK\r\n' -def test_start_response_with_reason(transport): - msg = protocol.Response(transport, 333, close=True, +def test_start_response_with_reason(stream): + msg = protocol.Response(stream, 333, close=True, reason="My Reason") assert msg.status == 333 @@ -53,30 +53,30 @@ def test_start_response_with_reason(transport): assert msg.status_line == 'HTTP/1.1 333 My Reason\r\n' -def test_start_response_with_unknown_reason(transport): - msg = protocol.Response(transport, 777, close=True) +def test_start_response_with_unknown_reason(stream): + msg = protocol.Response(stream, 777, close=True) assert msg.status == 777 assert msg.reason == "777" assert msg.status_line == 'HTTP/1.1 777 777\r\n' -def test_force_close(transport): - msg = protocol.Response(transport, 200) +def test_force_close(stream): + msg = protocol.Response(stream, 200) assert not msg.closing msg.force_close() assert msg.closing -def test_force_chunked(transport): - msg = protocol.Response(transport, 200) +def test_force_chunked(stream): + msg = protocol.Response(stream, 200) assert not msg.chunked msg.enable_chunking() assert msg.chunked -def test_keep_alive(transport): - msg = protocol.Response(transport, 200, close=True) +def test_keep_alive(stream): + msg = protocol.Response(stream, 200, close=True) assert not msg.keep_alive() msg.keepalive = True assert msg.keep_alive() @@ -85,41 +85,41 @@ def test_keep_alive(transport): assert not msg.keep_alive() -def test_keep_alive_http10(transport): - msg = protocol.Response(transport, 200, http_version=(1, 0)) +def test_keep_alive_http10(stream): + msg = protocol.Response(stream, 200, http_version=(1, 0)) assert not msg.keepalive assert not msg.keep_alive() - msg = protocol.Response(transport, 200, http_version=(1, 1)) + msg = protocol.Response(stream, 200, http_version=(1, 1)) assert msg.keepalive is None -def test_add_header(transport): - msg = protocol.Response(transport, 200) +def test_add_header(stream): + msg = protocol.Response(stream, 200) assert [] == list(msg.headers) msg.add_header('content-type', 'plain/html') assert [('Content-Type', 'plain/html')] == list(msg.headers.items()) -def test_add_header_with_spaces(transport): - msg = protocol.Response(transport, 200) +def test_add_header_with_spaces(stream): + msg = protocol.Response(stream, 200) assert [] == list(msg.headers) msg.add_header('content-type', ' plain/html ') assert [('Content-Type', 'plain/html')] == list(msg.headers.items()) -def test_add_header_non_ascii(transport): - msg = protocol.Response(transport, 200) +def test_add_header_non_ascii(stream): + msg = protocol.Response(stream, 200) assert [] == list(msg.headers) with pytest.raises(AssertionError): msg.add_header('тип-контента', 'текст/плейн') -def test_add_header_invalid_value_type(transport): - msg = protocol.Response(transport, 200) +def test_add_header_invalid_value_type(stream): + msg = protocol.Response(stream, 200) assert [] == list(msg.headers) with pytest.raises(AssertionError): @@ -129,44 +129,44 @@ def test_add_header_invalid_value_type(transport): msg.add_header(list('content-type'), 'text/plain') -def test_add_headers(transport): - msg = protocol.Response(transport, 200) +def test_add_headers(stream): + msg = protocol.Response(stream, 200) assert [] == list(msg.headers) msg.add_headers(('content-type', 'plain/html')) assert [('Content-Type', 'plain/html')] == list(msg.headers.items()) -def test_add_headers_length(transport): - msg = protocol.Response(transport, 200) +def test_add_headers_length(stream): + msg = protocol.Response(stream, 200) assert msg.length is None msg.add_headers(('content-length', '42')) assert 42 == msg.length -def test_add_headers_upgrade(transport): - msg = protocol.Response(transport, 200) +def test_add_headers_upgrade(stream): + msg = protocol.Response(stream, 200) assert not msg.upgrade msg.add_headers(('connection', 'upgrade')) assert msg.upgrade -def test_add_headers_upgrade_websocket(transport): - msg = protocol.Response(transport, 200) +def test_add_headers_upgrade_websocket(stream): + msg = protocol.Response(stream, 200) msg.add_headers(('upgrade', 'test')) assert not msg.websocket assert [('Upgrade', 'test')] == list(msg.headers.items()) - msg = protocol.Response(transport, 200) + msg = protocol.Response(stream, 200) msg.add_headers(('upgrade', 'websocket')) assert msg.websocket assert [('Upgrade', 'websocket')] == list(msg.headers.items()) -def test_add_headers_connection_keepalive(transport): - msg = protocol.Response(transport, 200) +def test_add_headers_connection_keepalive(stream): + msg = protocol.Response(stream, 200) msg.add_headers(('connection', 'keep-alive')) assert [] == list(msg.headers) @@ -176,16 +176,16 @@ def test_add_headers_connection_keepalive(transport): assert not msg.keepalive -def test_add_headers_hop_headers(transport): - msg = protocol.Response(transport, 200) +def test_add_headers_hop_headers(stream): + msg = protocol.Response(stream, 200) msg.HOP_HEADERS = (hdrs.TRANSFER_ENCODING,) msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) assert [] == list(msg.headers) -def test_default_headers_http_10(transport): - msg = protocol.Response(transport, 200, +def test_default_headers_http_10(stream): + msg = protocol.Response(stream, 200, http_version=protocol.HttpVersion10) msg._add_default_headers() @@ -193,52 +193,52 @@ def test_default_headers_http_10(transport): assert 'keep-alive' == msg.headers['CONNECTION'] -def test_default_headers_http_11(transport): - msg = protocol.Response(transport, 200) +def test_default_headers_http_11(stream): + msg = protocol.Response(stream, 200) msg._add_default_headers() assert 'DATE' in msg.headers assert 'CONNECTION' not in msg.headers -def test_default_headers_server(transport): - msg = protocol.Response(transport, 200) +def test_default_headers_server(stream): + msg = protocol.Response(stream, 200) msg._add_default_headers() assert 'SERVER' in msg.headers -def test_default_headers_chunked(transport): - msg = protocol.Response(transport, 200) +def test_default_headers_chunked(stream): + msg = protocol.Response(stream, 200) msg._add_default_headers() assert 'TRANSFER-ENCODING' not in msg.headers - msg = protocol.Response(transport, 200) + msg = protocol.Response(stream, 200) msg.enable_chunking() msg.send_headers() assert 'TRANSFER-ENCODING' in msg.headers -def test_default_headers_connection_upgrade(transport): - msg = protocol.Response(transport, 200) +def test_default_headers_connection_upgrade(stream): + msg = protocol.Response(stream, 200) msg.upgrade = True msg._add_default_headers() assert msg.headers['Connection'] == 'Upgrade' -def test_default_headers_connection_close(transport): - msg = protocol.Response(transport, 200) +def test_default_headers_connection_close(stream): + msg = protocol.Response(stream, 200) msg.force_close() msg._add_default_headers() assert msg.headers['Connection'] == 'close' -def test_default_headers_connection_keep_alive_http_10(transport): - msg = protocol.Response(transport, 200, +def test_default_headers_connection_keep_alive_http_10(stream): + msg = protocol.Response(stream, 200, http_version=protocol.HttpVersion10) msg.keepalive = True msg._add_default_headers() @@ -246,8 +246,8 @@ def test_default_headers_connection_keep_alive_http_10(transport): assert msg.headers['Connection'] == 'keep-alive' -def test_default_headers_connection_keep_alive_11(transport): - msg = protocol.Response(transport, 200, +def test_default_headers_connection_keep_alive_11(stream): + msg = protocol.Response(stream, 200, http_version=protocol.HttpVersion11) msg.keepalive = True msg._add_default_headers() @@ -255,8 +255,8 @@ def test_default_headers_connection_keep_alive_11(transport): assert 'Connection' not in msg.headers -def test_send_headers(transport): - msg = protocol.Response(transport, 200) +def test_send_headers(stream): + msg = protocol.Response(stream, 200) msg.add_headers(('content-type', 'plain/html')) assert not msg.is_headers_sent() @@ -269,8 +269,8 @@ def test_send_headers(transport): assert msg.is_headers_sent() -def test_send_headers_non_ascii(transport): - msg = protocol.Response(transport, 200) +def test_send_headers_non_ascii(stream): + msg = protocol.Response(stream, 200) msg.add_headers(('x-header', 'текст')) assert not msg.is_headers_sent() @@ -284,8 +284,8 @@ def test_send_headers_non_ascii(transport): assert msg.is_headers_sent() -def test_send_headers_nomore_add(transport): - msg = protocol.Response(transport, 200) +def test_send_headers_nomore_add(stream): + msg = protocol.Response(stream, 200) msg.add_headers(('content-type', 'plain/html')) msg.send_headers() @@ -293,44 +293,44 @@ def test_send_headers_nomore_add(transport): msg.add_header('content-type', 'plain/html') -def test_prepare_length(transport): - msg = protocol.Response(transport, 200) +def test_prepare_length(stream): + msg = protocol.Response(stream, 200) msg.add_headers(('content-length', '42')) msg.send_headers() assert msg.length == 42 -def test_prepare_chunked_force(transport): - msg = protocol.Response(transport, 200) +def test_prepare_chunked_force(stream): + msg = protocol.Response(stream, 200) msg.enable_chunking() msg.add_headers(('content-length', '42')) msg.send_headers() assert msg.chunked -def test_prepare_chunked_no_length(transport): - msg = protocol.Response(transport, 200) +def test_prepare_chunked_no_length(stream): + msg = protocol.Response(stream, 200) msg.send_headers() assert msg.chunked -def test_prepare_eof(transport): - msg = protocol.Response(transport, 200, http_version=(1, 0)) +def test_prepare_eof(stream): + msg = protocol.Response(stream, 200, http_version=(1, 0)) msg.send_headers() assert msg.length is None -def test_write_auto_send_headers(transport): - msg = protocol.Response(transport, 200, http_version=(1, 0)) +def test_write_auto_send_headers(stream): + msg = protocol.Response(stream, 200, http_version=(1, 0)) msg.send_headers() msg.write(b'data1') assert msg.headers_sent -def test_write_payload_eof(transport): - write = transport.write = mock.Mock() - msg = protocol.Response(transport, 200, http_version=(1, 0)) +def test_write_payload_eof(stream): + write = stream.transport.write = mock.Mock() + msg = protocol.Response(stream, 200, http_version=(1, 0)) msg.send_headers() msg.write(b'data1') @@ -344,10 +344,10 @@ def test_write_payload_eof(transport): @asyncio.coroutine -def test_write_payload_chunked(transport, loop): - write = transport.write = mock.Mock() +def test_write_payload_chunked(stream, loop): + write = stream.transport.write = mock.Mock() - msg = protocol.Response(transport, 200, loop=loop) + msg = protocol.Response(stream, 200, loop=loop) msg.enable_chunking() msg.send_headers() @@ -359,10 +359,10 @@ def test_write_payload_chunked(transport, loop): @asyncio.coroutine -def test_write_payload_chunked_multiple(transport, loop): - write = transport.write = mock.Mock() +def test_write_payload_chunked_multiple(stream, loop): + write = stream.transport.write = mock.Mock() - msg = protocol.Response(transport, 200, loop=loop) + msg = protocol.Response(stream, 200, loop=loop) msg.enable_chunking() msg.send_headers() @@ -376,10 +376,10 @@ def test_write_payload_chunked_multiple(transport, loop): @asyncio.coroutine -def test_write_payload_length(transport, loop): - write = transport.write = mock.Mock() +def test_write_payload_length(stream, loop): + write = stream.transport.write = mock.Mock() - msg = protocol.Response(transport, 200, loop=loop) + msg = protocol.Response(stream, 200, loop=loop) msg.add_headers(('content-length', '2')) msg.send_headers() @@ -392,10 +392,10 @@ def test_write_payload_length(transport, loop): @asyncio.coroutine -def test_write_payload_chunked_filter(transport, loop): - write = transport.write = mock.Mock() +def test_write_payload_chunked_filter(stream, loop): + write = stream.transport.write = mock.Mock() - msg = protocol.Response(transport, 200, loop=loop) + msg = protocol.Response(stream, 200, loop=loop) msg.send_headers() msg.enable_chunking() @@ -408,9 +408,9 @@ def test_write_payload_chunked_filter(transport, loop): @asyncio.coroutine -def test_write_payload_chunked_filter_mutiple_chunks(transport, loop): - write = transport.write = mock.Mock() - msg = protocol.Response(transport, 200, loop=loop) +def test_write_payload_chunked_filter_mutiple_chunks(stream, loop): + write = stream.transport.write = mock.Mock() + msg = protocol.Response(stream, 200, loop=loop) msg.send_headers() msg.enable_chunking() @@ -427,9 +427,9 @@ def test_write_payload_chunked_filter_mutiple_chunks(transport, loop): @asyncio.coroutine -def test_write_payload_deflate_compression(transport, loop): - write = transport.write = mock.Mock() - msg = protocol.Response(transport, 200, loop=loop) +def test_write_payload_deflate_compression(stream, loop): + write = stream.transport.write = mock.Mock() + msg = protocol.Response(stream, 200, loop=loop) msg.add_headers(('content-length', '{}'.format(len(COMPRESSED)))) msg.send_headers() @@ -440,14 +440,13 @@ def test_write_payload_deflate_compression(transport, loop): chunks = [c[1][0] for c in list(write.mock_calls)] assert all(chunks) content = b''.join(chunks) - print(content) assert COMPRESSED == content.split(b'\r\n\r\n', 1)[-1] @asyncio.coroutine -def test_write_payload_deflate_and_chunked(transport, loop): - write = transport.write = mock.Mock() - msg = protocol.Response(transport, 200, loop=loop) +def test_write_payload_deflate_and_chunked(stream, loop): + write = stream.transport.write = mock.Mock() + msg = protocol.Response(stream, 200, loop=loop) msg.send_headers() msg.enable_compression('deflate') @@ -464,8 +463,8 @@ def test_write_payload_deflate_and_chunked(transport, loop): content.split(b'\r\n\r\n', 1)[-1]) -def test_write_drain(transport, loop): - msg = protocol.Response(transport, 200, http_version=(1, 0), loop=loop) +def test_write_drain(stream, loop): + msg = protocol.Response(stream, 200, http_version=(1, 0), loop=loop) msg.drain = mock.Mock() msg.send_headers() msg.write(b'1' * (64 * 1024 * 2), drain=False) @@ -476,16 +475,16 @@ def test_write_drain(transport, loop): assert msg.buffer_size == 0 -def test_dont_override_request_headers_with_default_values(transport, loop): +def test_dont_override_request_headers_with_default_values(stream, loop): msg = protocol.Request( - transport, 'GET', '/index.html', close=True, loop=loop) + stream, 'GET', '/index.html', close=True, loop=loop) msg.add_header('USER-AGENT', 'custom') msg._add_default_headers() assert 'custom' == msg.headers['USER-AGENT'] -def test_dont_override_response_headers_with_default_values(transport, loop): - msg = protocol.Response(transport, 200, http_version=(1, 0), loop=loop) +def test_dont_override_response_headers_with_default_values(stream, loop): + msg = protocol.Response(stream, 200, http_version=(1, 0), loop=loop) msg.add_header('DATE', 'now') msg.add_header('SERVER', 'custom') msg._add_default_headers() diff --git a/tests/test_py35/test_client.py b/tests/test_py35/test_client.py index 066b9bb38ff..b23964d0256 100644 --- a/tests/test_py35/test_client.py +++ b/tests/test_py35/test_client.py @@ -1,3 +1,5 @@ +import asyncio + import pytest import aiohttp @@ -16,7 +18,7 @@ async def handler(request): resp = web.StreamResponse(headers={'content-length': '100'}) await resp.prepare(request) await resp.drain() - await asycio.sleep(0.1, loop=request.app.loop) + await asyncio.sleep(0.1, loop=request.app.loop) return resp app = web.Application(loop=loop) @@ -52,7 +54,7 @@ async def handler(request): resp = web.StreamResponse(headers={'content-length': '100'}) await resp.prepare(request) await resp.drain() - await asycio.sleep(0.1, loop=request.app.loop) + await asyncio.sleep(0.1, loop=request.app.loop) return resp app = web.Application(loop=loop) diff --git a/tests/test_server.py b/tests/test_server.py index 4885d4ba0ca..91ff6867048 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -56,6 +56,7 @@ def write(chunk): transport.acquire.side_effect = acquire transport.write.side_effect = write + transport.transport.write.side_effect = write transport.drain.return_value = () return (transport, buf) @@ -79,7 +80,7 @@ def test_handle_request(srv, writer): yield from srv.handle_request(message, mock.Mock()) content = b''.join( - [c[1][0] for c in list(srv.writer.write.mock_calls)]) + [c[1][0] for c in list(srv.writer.transport.write.mock_calls)]) assert content.startswith(b'HTTP/1.1 404 Not Found\r\n') @@ -174,7 +175,7 @@ def test_data_received(srv): def test_eof_received(srv): srv.connection_made(mock.Mock()) srv.eof_received() - assert srv.reader._eof + # assert srv.reader._eof @asyncio.coroutine @@ -263,7 +264,7 @@ def test_handle_error(srv, writer): yield from srv.handle_error(404, headers=(('X-Server', 'asyncio'),)) content = b''.join( - [c[1][0] for c in list(srv.writer.write.mock_calls)]) + [c[1][0] for c in list(srv.writer.transport.write.mock_calls)]) assert b'HTTP/1.1 404 Not Found' in content assert b'X-Server: asyncio' in content assert not srv._keepalive @@ -284,7 +285,7 @@ def test_handle_error__utf(make_srv, writer): yield from srv.handle_error(exc=exc) content = b''.join( - [c[1][0] for c in list(srv.writer.write.mock_calls)]) + [c[1][0] for c in list(srv.writer.transport.write.mock_calls)]) assert b'HTTP/1.1 500 Internal Server Error' in content assert b'Content-Type: text/html; charset=utf-8' in content pattern = escape("raise RuntimeError('что-то пошло не так')") @@ -298,10 +299,10 @@ def test_handle_error__utf(make_srv, writer): def test_handle_error_traceback_exc(make_srv, transport): log = mock.Mock() srv = make_srv(debug=True, logger=log) - transport, buf = transport - srv.transport = transport + stream, buf = transport + srv.transport = stream srv.transport.get_extra_info.return_value = '127.0.0.1' - srv.writer = transport + srv.writer = stream srv._request_handlers.append(mock.Mock()) with mock.patch('aiohttp.server.traceback') as m_trace: @@ -326,7 +327,7 @@ def test_handle_error_debug(srv, writer): yield from srv.handle_error(999, exc=exc) content = b''.join( - [c[1][0] for c in list(srv.writer.write.mock_calls)]) + [c[1][0] for c in list(srv.writer.transport.write.mock_calls)]) assert b'HTTP/1.1 500 Internal' in content assert b'Traceback (most recent call last):' in content diff --git a/tests/test_stream_parser.py b/tests/test_stream_parser.py deleted file mode 100644 index 8aa286f88b3..00000000000 --- a/tests/test_stream_parser.py +++ /dev/null @@ -1,373 +0,0 @@ -"""Tests for parsers.py""" - -from unittest import mock - -import pytest - -from aiohttp import parsers - -DATA = b'line1\nline2\nline3\n' - - -class LinesParser: - """Lines parser. - Lines parser splits a bytes stream into a chunks of data, each chunk ends - with \\n symbol.""" - - def __init__(self): - pass - - def __call__(self, out, buf): - try: - while True: - chunk = yield from buf.readuntil(b'\n', 0xffff) - out.feed_data(chunk, len(chunk)) - except parsers.EofStream: - pass - - -@pytest.fixture -def lines_parser(): - return LinesParser() - - -def test_at_eof(loop): - proto = parsers.StreamParser(loop=loop) - assert not proto.at_eof() - - proto.feed_eof() - assert proto.at_eof() - - -def test_exception(loop): - stream = parsers.StreamParser(loop=loop) - assert stream.exception() is None - - exc = ValueError() - stream.set_exception(exc) - assert stream.exception() is exc - - -def test_exception_connection_error(loop): - stream = parsers.StreamParser(loop=loop) - assert stream.exception() is None - - exc = ConnectionError() - stream.set_exception(exc) - assert stream.exception() is not exc - assert isinstance(stream.exception(), RuntimeError) - assert stream.exception().__cause__ is exc - assert stream.exception().__context__ is exc - - -def test_exception_waiter(loop, lines_parser): - - stream = parsers.StreamParser(loop=loop) - - stream._parser = lines_parser - buf = stream._output = parsers.FlowControlDataQueue( - stream, loop=loop) - - exc = ValueError() - stream.set_exception(exc) - assert buf.exception() is exc - - -def test_feed_data(loop): - stream = parsers.StreamParser(loop=loop) - - stream.feed_data(DATA) - assert DATA == bytes(stream._buffer) - - -def test_feed_none_data(loop): - stream = parsers.StreamParser(loop=loop) - - stream.feed_data(None) - assert b'' == bytes(stream._buffer) - - -def test_set_parser_unset_prev(loop, lines_parser): - stream = parsers.StreamParser(loop=loop) - stream.set_parser(lines_parser) - - unset = stream.unset_parser = mock.Mock() - stream.set_parser(lines_parser) - - assert unset.called - - -def test_set_parser_exception(loop, lines_parser): - stream = parsers.StreamParser(loop=loop) - - exc = ValueError() - stream.set_exception(exc) - s = stream.set_parser(lines_parser) - assert s.exception() is exc - - -def test_set_parser_feed_existing(loop, lines_parser): - stream = parsers.StreamParser(loop=loop) - stream.feed_data(b'line1') - stream.feed_data(b'\r\nline2\r\ndata') - s = stream.set_parser(lines_parser) - - assert ([(bytearray(b'line1\r\n'), 7), (bytearray(b'line2\r\n'), 7)] == - list(s._buffer)) - assert b'data' == bytes(stream._buffer) - assert stream._parser is not None - - stream.unset_parser() - assert stream._parser is None - assert b'data' == bytes(stream._buffer) - assert s._eof - - -def test_set_parser_feed_existing_exc(loop): - def p(out, buf): - yield from buf.read(1) - raise ValueError() - - stream = parsers.StreamParser(loop=loop) - stream.feed_data(b'line1') - s = stream.set_parser(p) - assert isinstance(s.exception(), ValueError) - - -def test_set_parser_feed_existing_eof(loop, lines_parser): - stream = parsers.StreamParser(loop=loop) - stream.feed_data(b'line1') - stream.feed_data(b'\r\nline2\r\ndata') - stream.feed_eof() - s = stream.set_parser(lines_parser) - - assert ([(bytearray(b'line1\r\n'), 7), (bytearray(b'line2\r\n'), 7)] == - list(s._buffer)) - assert b'data' == bytes(stream._buffer) - assert stream._parser is None - - -def test_set_parser_feed_existing_eof_exc(loop): - def p(out, buf): - try: - while True: - yield # read chunk - except parsers.EofStream: - raise ValueError() - - stream = parsers.StreamParser(loop=loop) - stream.feed_data(b'line1') - stream.feed_eof() - s = stream.set_parser(p) - assert isinstance(s.exception(), ValueError) - - -def test_set_parser_feed_existing_eof_unhandled_eof(loop): - def p(out, buf): - while True: - yield # read chunk - - stream = parsers.StreamParser(loop=loop) - stream.feed_data(b'line1') - stream.feed_eof() - s = stream.set_parser(p) - assert not s.is_eof() - assert isinstance(s.exception(), RuntimeError) - - -def test_set_parser_unset(loop, lines_parser): - stream = parsers.StreamParser(loop=loop) - s = stream.set_parser(lines_parser) - - stream.feed_data(b'line1\r\nline2\r\n') - assert ([(bytearray(b'line1\r\n'), 7), (bytearray(b'line2\r\n'), 7)] == - list(s._buffer)) - assert b'' == bytes(stream._buffer) - stream.unset_parser() - assert s._eof - assert b'' == bytes(stream._buffer) - - -def test_set_parser_feed_existing_stop(loop): - def LinesParser(out, buf): - try: - chunk = yield from buf.readuntil(b'\n') - out.feed_data(chunk, len(chunk)) - - chunk = yield from buf.readuntil(b'\n') - out.feed_data(chunk, len(chunk)) - finally: - out.feed_eof() - - stream = parsers.StreamParser(loop=loop) - stream.feed_data(b'line1') - stream.feed_data(b'\r\nline2\r\ndata') - s = stream.set_parser(LinesParser) - - assert b'line1\r\nline2\r\n' == b''.join(d for d, _ in s._buffer) - assert b'data' == bytes(stream._buffer) - assert stream._parser is None - assert s._eof - - -def test_feed_parser(loop, lines_parser): - stream = parsers.StreamParser(loop=loop) - s = stream.set_parser(lines_parser) - assert s._allow_pause - - stream.feed_data(b'line1') - stream.feed_data(b'\r\nline2\r\ndata') - assert b'data' == bytes(stream._buffer) - - stream.feed_eof() - assert ([(bytearray(b'line1\r\n'), 7), (bytearray(b'line2\r\n'), 7)] == - list(s._buffer)) - assert b'data' == bytes(stream._buffer) - assert s.is_eof() - - -def test_feed_parser_exc(loop): - def p(out, buf): - yield # read chunk - raise ValueError() - - stream = parsers.StreamParser(loop=loop) - s = stream.set_parser(p) - - stream.feed_data(b'line1') - assert isinstance(s.exception(), ValueError) - assert b'' == bytes(stream._buffer) - - -def test_feed_parser_stop(loop): - def p(out, buf): - yield # chunk - - stream = parsers.StreamParser(loop=loop) - stream.set_parser(p) - - stream.feed_data(b'line1') - assert stream._parser is None - assert b'' == bytes(stream._buffer) - - -def test_feed_eof_exc(loop): - def p(out, buf): - try: - while True: - yield # read chunk - except parsers.EofStream: - raise ValueError() - - stream = parsers.StreamParser(loop=loop) - s = stream.set_parser(p) - - stream.feed_data(b'line1') - assert s.exception() is None - - stream.feed_eof() - assert isinstance(s.exception(), ValueError) - - -def test_feed_eof_stop(loop): - def p(out, buf): - try: - while True: - yield # read chunk - except parsers.EofStream: - out.feed_eof() - - stream = parsers.StreamParser(loop=loop) - s = stream.set_parser(p) - - stream.feed_data(b'line1') - stream.feed_eof() - assert s._eof - - -def test_feed_eof_unhandled_eof(loop): - def p(out, buf): - while True: - yield # read chunk - - stream = parsers.StreamParser(loop=loop) - s = stream.set_parser(p) - - stream.feed_data(b'line1') - stream.feed_eof() - assert not s.is_eof() - assert isinstance(s.exception(), RuntimeError) - - -def test_feed_parser2(loop, lines_parser): - stream = parsers.StreamParser(loop=loop) - s = stream.set_parser(lines_parser) - - stream.feed_data(b'line1\r\nline2\r\n') - stream.feed_eof() - assert ([(bytearray(b'line1\r\n'), 7), (bytearray(b'line2\r\n'), 7)] == - list(s._buffer)) - assert b'' == bytes(stream._buffer) - assert s._eof - - -def test_unset_parser_eof_exc(loop): - def p(out, buf): - try: - while True: - yield # read chunk - except parsers.EofStream: - raise ValueError() - - stream = parsers.StreamParser(loop=loop) - s = stream.set_parser(p) - - stream.feed_data(b'line1') - stream.unset_parser() - assert isinstance(s.exception(), ValueError) - assert stream._parser is None - - -def test_unset_parser_eof_unhandled_eof(loop): - def p(out, buf): - while True: - yield # read chunk - - stream = parsers.StreamParser(loop=loop) - s = stream.set_parser(p) - - stream.feed_data(b'line1') - stream.unset_parser() - assert isinstance(s.exception(), RuntimeError) - assert not s.is_eof() - - -def test_unset_parser_stop(loop): - def p(out, buf): - try: - while True: - yield # read chunk - except parsers.EofStream: - out.feed_eof() - - stream = parsers.StreamParser(loop=loop) - s = stream.set_parser(p) - - stream.feed_data(b'line1') - stream.unset_parser() - assert s._eof - - -def test_eof_exc(loop): - def p(out, buf): - while True: - yield # read chunk - - class CustomEofErr(Exception): - pass - - stream = parsers.StreamParser(eof_exc_class=CustomEofErr, loop=loop) - s = stream.set_parser(p) - - stream.feed_eof() - assert isinstance(s.exception(), CustomEofErr) diff --git a/tests/test_stream_protocol.py b/tests/test_stream_protocol.py deleted file mode 100644 index 7f9a7d0ce1b..00000000000 --- a/tests/test_stream_protocol.py +++ /dev/null @@ -1,40 +0,0 @@ -from unittest import mock - -from aiohttp import parsers - - -def test_connection_made(loop): - tr = mock.Mock() - - proto = parsers.StreamProtocol(loop=loop) - assert proto.transport is None - - proto.connection_made(tr) - assert proto.transport is tr - - -def test_connection_lost(loop): - proto = parsers.StreamProtocol(loop=loop) - proto.connection_made(mock.Mock()) - proto.connection_lost(None) - assert proto.transport is None - assert proto.writer is None - assert proto.reader._eof - - -def test_connection_lost_exc(loop): - proto = parsers.StreamProtocol(loop=loop) - proto.connection_made(mock.Mock()) - - exc = ValueError() - proto.connection_lost(exc) - assert proto.reader.exception() is exc - - -def test_data_received(loop): - proto = parsers.StreamProtocol(loop=loop) - proto.connection_made(mock.Mock()) - proto.reader = mock.Mock() - - proto.data_received(b'data') - proto.reader.feed_data.assert_called_with(b'data') diff --git a/tests/test_stream_writer.py b/tests/test_stream_writer.py index 5c93b0f9d51..cc464598eee 100644 --- a/tests/test_stream_writer.py +++ b/tests/test_stream_writer.py @@ -3,7 +3,7 @@ import pytest -from aiohttp.parsers import CORK, StreamWriter +from aiohttp.streams import CORK, StreamWriter has_ipv6 = socket.has_ipv6 if has_ipv6: @@ -23,8 +23,7 @@ def test_nodelay_default(loop): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) assert not writer.tcp_nodelay assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) @@ -34,8 +33,7 @@ def test_set_nodelay_no_change(loop): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_nodelay(False) assert not writer.tcp_nodelay assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) @@ -49,8 +47,7 @@ def test_set_nodelay_exception(loop): s.setsockopt.side_effect = OSError transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_nodelay(True) assert not writer.tcp_nodelay @@ -60,8 +57,7 @@ def test_set_nodelay_enable(loop): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_nodelay(True) assert writer.tcp_nodelay assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) @@ -72,8 +68,7 @@ def test_set_nodelay_enable_and_disable(loop): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_nodelay(True) writer.set_tcp_nodelay(False) assert not writer.tcp_nodelay @@ -86,8 +81,7 @@ def test_set_nodelay_enable_ipv6(loop): s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_nodelay(True) assert writer.tcp_nodelay assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) @@ -101,8 +95,7 @@ def test_set_nodelay_enable_unix(loop): s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_nodelay(True) assert not writer.tcp_nodelay @@ -111,8 +104,7 @@ def test_set_nodelay_enable_no_socket(loop): transport = mock.Mock() transport.get_extra_info.return_value = None proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_nodelay(True) assert not writer.tcp_nodelay assert writer._socket is None @@ -126,8 +118,7 @@ def test_cork_default(loop): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) assert not writer.tcp_cork assert not s.getsockopt(socket.IPPROTO_TCP, CORK) @@ -138,8 +129,7 @@ def test_set_cork_no_change(loop): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_cork(False) assert not writer.tcp_cork assert not s.getsockopt(socket.IPPROTO_TCP, CORK) @@ -151,8 +141,7 @@ def test_set_cork_enable(loop): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_cork(True) assert writer.tcp_cork assert s.getsockopt(socket.IPPROTO_TCP, CORK) @@ -164,8 +153,7 @@ def test_set_cork_enable_and_disable(loop): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_cork(True) writer.set_tcp_cork(False) assert not writer.tcp_cork @@ -179,8 +167,7 @@ def test_set_cork_enable_ipv6(loop): s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_cork(True) assert writer.tcp_cork assert s.getsockopt(socket.IPPROTO_TCP, CORK) @@ -194,8 +181,7 @@ def test_set_cork_enable_unix(loop): s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_cork(True) assert not writer.tcp_cork @@ -205,8 +191,7 @@ def test_set_cork_enable_no_socket(loop): transport = mock.Mock() transport.get_extra_info.return_value = None proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_cork(True) assert not writer.tcp_cork assert writer._socket is None @@ -219,8 +204,7 @@ def test_set_cork_exception(loop): s.family = (socket.AF_INET,) s.setsockopt.side_effect = OSError proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_cork(True) assert not writer.tcp_cork @@ -233,8 +217,7 @@ def test_set_enabling_cork_disables_nodelay(loop): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_nodelay(True) writer.set_tcp_cork(True) assert not writer.tcp_nodelay @@ -249,8 +232,7 @@ def test_set_enabling_nodelay_disables_cork(loop): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) transport.get_extra_info.return_value = s proto = mock.Mock() - reader = mock.Mock() - writer = StreamWriter(transport, proto, reader, loop) + writer = StreamWriter(proto, transport, loop) writer.set_tcp_cork(True) writer.set_tcp_nodelay(True) assert writer.tcp_nodelay diff --git a/tests/test_web_exceptions.py b/tests/test_web_exceptions.py index 10c34d0829a..6a7dcdbcd3f 100644 --- a/tests/test_web_exceptions.py +++ b/tests/test_web_exceptions.py @@ -28,7 +28,7 @@ def acquire(cb): def append(data): buf.extend(data) - writer.write.side_effect = append + writer.transport.write.side_effect = append app = mock.Mock() app._debug = False app.on_response_prepare = signals.Signal(app) diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 50ff305ce0f..20a882bcb8c 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -38,6 +38,7 @@ def write(chunk): transport.acquire.side_effect = acquire transport.write.side_effect = write + transport.transport.write.side_effect = write transport.drain.return_value = () return (transport, buf) diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index d7f61c6480a..00763b191bc 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -7,7 +7,7 @@ import aiohttp from aiohttp import WebSocketError, WSCloseCode, WSMessage, WSMsgType, _ws_impl from aiohttp._ws_impl import (PACK_CLOSE_CODE, PACK_LEN1, PACK_LEN2, PACK_LEN3, - WebSocketParser, _websocket_mask, parse_frame) + WebSocketReader, _websocket_mask) def build_frame(message, opcode, use_mask=False, noheader=False): @@ -52,282 +52,206 @@ def build_close_frame(code=1000, message=b'', noheader=False): opcode=WSMsgType.CLOSE, noheader=noheader) -@pytest.fixture() -def buf(): - return aiohttp.ParserBuffer() - - @pytest.fixture() def out(loop): return aiohttp.DataQueue(loop=loop) @pytest.fixture() -def parser(buf, out): - return WebSocketParser(out, buf) +def parser(out): + return WebSocketReader(out) -def test_parse_frame(buf): - p = parse_frame(buf) - next(p) - p.send(struct.pack('!BB', 0b00000001, 0b00000001)) - try: - p.send(b'1') - except StopIteration as exc: - fin, opcode, payload = exc.value +def test_parse_frame(parser): + parser.parse_frame(struct.pack('!BB', 0b00000001, 0b00000001)) + res = parser.parse_frame(b'1') + fin, opcode, payload = res[0] assert (0, 1, b'1') == (fin, opcode, payload) -def test_parse_frame_length0(buf): - p = parse_frame(buf) - next(p) - try: - p.send(struct.pack('!BB', 0b00000001, 0b00000000)) - except StopIteration as exc: - fin, opcode, payload = exc.value +def test_parse_frame_length0(parser): + fin, opcode, payload = parser.parse_frame( + struct.pack('!BB', 0b00000001, 0b00000000))[0] assert (0, 1, b'') == (fin, opcode, payload) -def test_parse_frame_length2(buf): - p = parse_frame(buf) - next(p) - p.send(struct.pack('!BB', 0b00000001, 126)) - p.send(struct.pack('!H', 4)) - try: - p.send(b'1234') - except StopIteration as exc: - fin, opcode, payload = exc.value +def test_parse_frame_length2(parser): + parser.parse_frame(struct.pack('!BB', 0b00000001, 126)) + parser.parse_frame(struct.pack('!H', 4)) + res = parser.parse_frame(b'1234') + fin, opcode, payload = res[0] assert (0, 1, b'1234') == (fin, opcode, payload) -def test_parse_frame_length4(buf): - p = parse_frame(buf) - next(p) - p.send(struct.pack('!BB', 0b00000001, 127)) - p.send(struct.pack('!Q', 4)) - try: - p.send(b'1234') - except StopIteration as exc: - fin, opcode, payload = exc.value +def test_parse_frame_length4(parser): + parser.parse_frame(struct.pack('!BB', 0b00000001, 127)) + parser.parse_frame(struct.pack('!Q', 4)) + fin, opcode, payload = parser.parse_frame(b'1234')[0] assert (0, 1, b'1234') == (fin, opcode, payload) -def test_parse_frame_mask(buf): - p = parse_frame(buf) - next(p) - p.send(struct.pack('!BB', 0b00000001, 0b10000001)) - p.send(b'0001') - try: - p.send(b'1') - except StopIteration as exc: - fin, opcode, payload = exc.value +def test_parse_frame_mask(parser): + parser.parse_frame(struct.pack('!BB', 0b00000001, 0b10000001)) + parser.parse_frame(b'0001') + fin, opcode, payload = parser.parse_frame(b'1')[0] assert (0, 1, b'\x01') == (fin, opcode, payload) -def test_parse_frame_header_reversed_bits(buf): - p = parse_frame(buf) - next(p) +def test_parse_frame_header_reversed_bits(out, parser): with pytest.raises(WebSocketError): - p.send(struct.pack('!BB', 0b01100000, 0b00000000)) + parser.parse_frame(struct.pack('!BB', 0b01100000, 0b00000000)) + raise out.exception() -def test_parse_frame_header_control_frame(buf): - p = parse_frame(buf) - next(p) +def test_parse_frame_header_control_frame(out, parser): with pytest.raises(WebSocketError): - p.send(struct.pack('!BB', 0b00001000, 0b00000000)) + parser.parse_frame(struct.pack('!BB', 0b00001000, 0b00000000)) + raise out.exception() -def test_parse_frame_header_continuation(buf): - p = parse_frame(buf) - next(p) +def test_parse_frame_header_continuation(out, parser): with pytest.raises(WebSocketError): - p.send(struct.pack('!BB', 0b00000000, 0b00000000)) + parser._frame_fin = True + parser.parse_frame(struct.pack('!BB', 0b00000000, 0b00000000)) + raise out.exception() -def test_parse_frame_header_new_data_err(buf): - p = parse_frame(buf) - next(p) +def _test_parse_frame_header_new_data_err(out, parser): with pytest.raises(WebSocketError): - p.send(struct.pack('!BB', 0b000000000, 0b00000000)) + parser.parse_frame(struct.pack('!BB', 0b000000000, 0b00000000)) + raise out.exception() -def test_parse_frame_header_payload_size(buf): - p = parse_frame(buf) - next(p) +def test_parse_frame_header_payload_size(out, parser): with pytest.raises(WebSocketError): - p.send(struct.pack('!BB', 0b10001000, 0b01111110)) + parser.parse_frame(struct.pack('!BB', 0b10001000, 0b01111110)) + raise out.exception() def test_ping_frame(out, parser): - def parse_frame(buf): - yield - return (1, WSMsgType.PING, b'data') - - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - parser.send(b'') + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [(1, WSMsgType.PING, b'data')] + + parser.feed_data(b'') res = out._buffer[0] assert res == ((WSMsgType.PING, b'data', ''), 4) def test_pong_frame(out, parser): - def parse_frame(buf): - yield - return (1, WSMsgType.PONG, b'data') - - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - parser.send(b'') + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [(1, WSMsgType.PONG, b'data')] + + parser.feed_data(b'') res = out._buffer[0] assert res == ((WSMsgType.PONG, b'data', ''), 4) def test_close_frame(out, parser): - def parse_frame(buf): - yield - return (1, WSMsgType.CLOSE, b'') - - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - parser.send(b'') + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b'')] + parser.feed_data(b'') res = out._buffer[0] assert res == ((WSMsgType.CLOSE, 0, ''), 0) def test_close_frame_info(out, parser): - def parse_frame(buf): - yield - return (1, WSMsgType.CLOSE, b'0112345') - - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - parser.send(b'') + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b'0112345')] + + parser.feed_data(b'') res = out._buffer[0] assert res == (WSMessage(WSMsgType.CLOSE, 12337, '12345'), 0) def test_close_frame_invalid(out, parser): - def parse_frame(buf): - yield - return (1, WSMsgType.CLOSE, b'1') + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [(1, WSMsgType.CLOSE, b'1')] + parser.feed_data(b'') - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - with pytest.raises(WebSocketError) as ctx: - next(parser) + assert isinstance(out.exception(), WebSocketError) + assert out.exception().code == WSCloseCode.PROTOCOL_ERROR - assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR +def test_close_frame_invalid_2(out, parser): + data = build_close_frame(code=1) -def test_close_frame_invalid_2(buf, parser): - buf.extend(build_close_frame(code=1)) with pytest.raises(WebSocketError) as ctx: - next(parser) + parser._feed_data(data) assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR -def test_close_frame_unicode_err(buf, parser): - buf.extend(build_close_frame( - code=1000, message=b'\xf4\x90\x80\x80')) +def test_close_frame_unicode_err(parser): + data = build_close_frame( + code=1000, message=b'\xf4\x90\x80\x80') + with pytest.raises(WebSocketError) as ctx: - next(parser) + parser._feed_data(data) assert ctx.value.code == WSCloseCode.INVALID_TEXT def test_unknown_frame(out, parser): - def parse_frame(buf): - yield - return (1, WSMsgType.CONTINUATION, b'') + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [(1, WSMsgType.CONTINUATION, b'')] - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - - with pytest.raises(WebSocketError): - parser.send(b'') + with pytest.raises(WebSocketError): + parser.feed_data(b'') + raise out.exception() -def test_simple_text(buf, out, parser): - buf.extend(build_frame(b'text', WSMsgType.TEXT)) - next(parser) - parser.send(b'') +def test_simple_text(out, parser): + data = build_frame(b'text', WSMsgType.TEXT) + parser._feed_data(data) res = out._buffer[0] assert res == ((WSMsgType.TEXT, 'text', ''), 4) -def test_simple_text_unicode_err(buf, parser): - buf.extend( - build_frame(b'\xf4\x90\x80\x80', WSMsgType.TEXT)) +def test_simple_text_unicode_err(parser): + data = build_frame(b'\xf4\x90\x80\x80', WSMsgType.TEXT) + with pytest.raises(WebSocketError) as ctx: - next(parser) + parser._feed_data(data) assert ctx.value.code == WSCloseCode.INVALID_TEXT def test_simple_binary(out, parser): - def parse_frame(buf): - yield - return (1, WSMsgType.BINARY, b'binary') - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - parser.send(b'') + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [(1, WSMsgType.BINARY, b'binary')] + + parser.feed_data(b'') res = out._buffer[0] assert res == ((WSMsgType.BINARY, b'binary', ''), 6) def test_continuation(out, parser): - cur = 0 - - def parse_frame(buf, cont=False): - nonlocal cur - yield - if cur == 0: - cur = 1 - return (0, WSMsgType.TEXT, b'line1') - else: - return (1, WSMsgType.CONTINUATION, b'line2') + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [ + (0, WSMsgType.TEXT, b'line1'), + (1, WSMsgType.CONTINUATION, b'line2')] + + parser._feed_data(b'') - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - parser.send(b'') - parser.send(b'') res = out._buffer[0] assert res == (WSMessage(WSMsgType.TEXT, 'line1line2', ''), 10) def test_continuation_with_ping(out, parser): - frames = [ + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [ (0, WSMsgType.TEXT, b'line1'), (0, WSMsgType.PING, b''), (1, WSMsgType.CONTINUATION, b'line2'), ] - def parse_frame(buf, cont=False): - yield - return frames.pop(0) - - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - parser.send(b'') - parser.send(b'') - parser.send(b'') + parser.feed_data(b'') res = out._buffer[0] assert res == (WSMessage(WSMsgType.PING, b'', ''), 0) res = out._buffer[1] @@ -335,129 +259,81 @@ def parse_frame(buf, cont=False): def test_continuation_err(out, parser): - cur = 0 - - def parse_frame(buf, cont=False): - nonlocal cur - yield - if cur == 0: - cur = 1 - return (0, WSMsgType.TEXT, b'line1') - else: - return (1, WSMsgType.TEXT, b'line2') + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [ + (0, WSMsgType.TEXT, b'line1'), + (1, WSMsgType.TEXT, b'line2')] - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - parser.send(b'') - with pytest.raises(WebSocketError): - parser.send(b'') + with pytest.raises(WebSocketError): + parser._feed_data(b'') def test_continuation_with_close(out, parser): - frames = [ + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [ (0, WSMsgType.TEXT, b'line1'), (0, WSMsgType.CLOSE, build_close_frame(1002, b'test', noheader=True)), (1, WSMsgType.CONTINUATION, b'line2'), ] - def parse_frame(buf, cont=False): - yield - return frames.pop(0) - - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - parser.send(b'') - parser.send(b'') - parser.send(b'') - res = out._buffer[0] + parser.feed_data(b'') + res = out._buffer[0] assert res, (WSMessage(WSMsgType.CLOSE, 1002, 'test'), 0) res = out._buffer[1] assert res == (WSMessage(WSMsgType.TEXT, 'line1line2', ''), 10) def test_continuation_with_close_unicode_err(out, parser): - frames = [ + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [ (0, WSMsgType.TEXT, b'line1'), (0, WSMsgType.CLOSE, build_close_frame(1000, b'\xf4\x90\x80\x80', noheader=True)), (1, WSMsgType.CONTINUATION, b'line2')] - def parse_frame(buf, cont=False): - yield - return frames.pop(0) - - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - parser.send(b'') - with pytest.raises(WebSocketError) as ctx: - parser.send(b'') + with pytest.raises(WebSocketError) as ctx: + parser._feed_data(b'') assert ctx.value.code == WSCloseCode.INVALID_TEXT def test_continuation_with_close_bad_code(out, parser): - frames = [ + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [ (0, WSMsgType.TEXT, b'line1'), (0, WSMsgType.CLOSE, build_close_frame(1, b'test', noheader=True)), (1, WSMsgType.CONTINUATION, b'line2')] - def parse_frame(buf, cont=False): - yield - return frames.pop(0) - - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - parser.send(b'') - with pytest.raises(WebSocketError) as ctx: - parser.send(b'') + with pytest.raises(WebSocketError) as ctx: + parser._feed_data(b'') - assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR + assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR def test_continuation_with_close_bad_payload(out, parser): - frames = [ + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [ (0, WSMsgType.TEXT, b'line1'), (0, WSMsgType.CLOSE, b'1'), (1, WSMsgType.CONTINUATION, b'line2')] - def parse_frame(buf, cont=False): - yield - return frames.pop(0) - - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - parser.send(b'') - with pytest.raises(WebSocketError) as ctx: - parser.send(b'') + with pytest.raises(WebSocketError) as ctx: + parser._feed_data(b'') - assert ctx.value.code, WSCloseCode.PROTOCOL_ERROR + assert ctx.value.code, WSCloseCode.PROTOCOL_ERROR def test_continuation_with_close_empty(out, parser): - frames = [ + parser.parse_frame = mock.Mock() + parser.parse_frame.return_value = [ (0, WSMsgType.TEXT, b'line1'), (0, WSMsgType.CLOSE, b''), (1, WSMsgType.CONTINUATION, b'line2'), ] - def parse_frame(buf, cont=False): - yield - return frames.pop(0) - - with mock.patch('aiohttp._ws_impl.parse_frame') as m_parse_frame: - m_parse_frame.side_effect = parse_frame - next(parser) - parser.send(b'') - parser.send(b'') - parser.send(b'') - + parser.feed_data(b'') res = out._buffer[0] assert res, (WSMessage(WSMsgType.CLOSE, 0, ''), 0) res = out._buffer[1] diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index 6cb91662023..43b3330416a 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -7,62 +7,62 @@ @pytest.fixture -def transport(): +def stream(): return mock.Mock() @pytest.fixture -def writer(transport): - return WebSocketWriter(transport, use_mask=False) +def writer(stream): + return WebSocketWriter(stream, use_mask=False) -def test_pong(transport, writer): +def test_pong(stream, writer): writer.pong() - transport.write.assert_called_with(b'\x8a\x00') + stream.transport.write.assert_called_with(b'\x8a\x00') -def test_ping(transport, writer): +def test_ping(stream, writer): writer.ping() - transport.write.assert_called_with(b'\x89\x00') + stream.transport.write.assert_called_with(b'\x89\x00') -def test_send_text(transport, writer): +def test_send_text(stream, writer): writer.send(b'text') - transport.write.assert_called_with(b'\x81\x04text') + stream.transport.write.assert_called_with(b'\x81\x04text') -def test_send_binary(transport, writer): +def test_send_binary(stream, writer): writer.send('binary', True) - transport.write.assert_called_with(b'\x82\x06binary') + stream.transport.write.assert_called_with(b'\x82\x06binary') -def test_send_binary_long(transport, writer): +def test_send_binary_long(stream, writer): writer.send(b'b' * 127, True) - assert transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb') + assert stream.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb') -def test_send_binary_very_long(transport, writer): +def test_send_binary_very_long(stream, writer): writer.send(b'b' * 65537, True) - assert (transport.write.call_args_list[0][0][0] == + assert (stream.transport.write.call_args_list[0][0][0] == b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01') - assert transport.write.call_args_list[1][0][0] == b'b' * 65537 + assert stream.transport.write.call_args_list[1][0][0] == b'b' * 65537 -def test_close(transport, writer): +def test_close(stream, writer): writer.close(1001, 'msg') - transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + stream.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') writer.close(1001, b'msg') - transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + stream.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') # Test that Service Restart close code is also supported writer.close(1012, b'msg') - transport.write.assert_called_with(b'\x88\x05\x03\xf4msg') + stream.transport.write.assert_called_with(b'\x88\x05\x03\xf4msg') -def test_send_text_masked(transport, writer): - writer = WebSocketWriter(transport, +def test_send_text_masked(stream, writer): + writer = WebSocketWriter(stream, use_mask=True, random=random.Random(123)) writer.send(b'text') - transport.write.assert_called_with(b'\x81\x84\rg\xb3fy\x02\xcb\x12') + stream.transport.write.assert_called_with(b'\x81\x84\rg\xb3fy\x02\xcb\x12') diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py index a8b74c999bd..d6c0621e36a 100644 --- a/tests/test_wsgi.py +++ b/tests/test_wsgi.py @@ -172,7 +172,7 @@ def wsgi_app(env, start): srv.handle_request(self.message, self.payload)) content = b''.join( - [c[1][0] for c in self.writer.write.mock_calls]) + [c[1][0] for c in self.writer.transport.write.mock_calls]) self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) self.assertTrue(content.endswith(b'data')) @@ -195,7 +195,7 @@ def wsgi_app(env, start): srv.handle_request(self.message, self.payload)) content = b''.join( - [c[1][0] for c in self.writer.write.mock_calls]) + [c[1][0] for c in self.writer.transport.write.mock_calls]) self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) self.assertFalse(srv._keepalive) @@ -212,7 +212,7 @@ def wsgi_app(env, start): srv.handle_request(self.message, self.payload)) content = b''.join( - [c[1][0] for c in self.writer.write.mock_calls]) + [c[1][0] for c in self.writer.transport.write.mock_calls]) self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) self.assertTrue(content.endswith(b'data')) @@ -236,7 +236,7 @@ def wsgi_app(env, start): srv.handle_request(self.message, self.payload)) content = b''.join( - [c[1][0] for c in self.writer.write.mock_calls]) + [c[1][0] for c in self.writer.transport.write.mock_calls]) self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) self.assertTrue(srv._keepalive) @@ -253,7 +253,7 @@ def wsgi_app(env, start): srv.handle_request(self.message, self.payload)) content = b''.join( - [c[1][0] for c in self.writer.write.mock_calls]) + [c[1][0] for c in self.writer.transport.write.mock_calls]) self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) self.assertTrue(content.endswith(b'data'))