From 89abb8983f0a8e1de231ee8f3ddf223bf11433d0 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 15 Mar 2018 18:38:22 +0200 Subject: [PATCH] Websockets refactoring (#2836) --- CHANGES/2836.feature | 2 + aiohttp/client_ws.py | 9 +++-- aiohttp/http_websocket.py | 26 ++++++------- aiohttp/web_ws.py | 9 +++-- tests/test_client_ws.py | 13 +++++-- tests/test_client_ws_functional.py | 2 +- tests/test_websocket_writer.py | 61 ++++++++++++++++-------------- 7 files changed, 69 insertions(+), 53 deletions(-) create mode 100644 CHANGES/2836.feature diff --git a/CHANGES/2836.feature b/CHANGES/2836.feature new file mode 100644 index 00000000000..f9e3c7826cc --- /dev/null +++ b/CHANGES/2836.feature @@ -0,0 +1,2 @@ +Websockets refactoring, all websocket writer methods are converted +into coroutines. \ No newline at end of file diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 3968e5759c5..7fe8182483c 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -33,7 +33,7 @@ def __init__(self, reader, writer, protocol, self._heartbeat = heartbeat self._heartbeat_cb = None if heartbeat is not None: - self._pong_heartbeat = heartbeat/2.0 + self._pong_heartbeat = heartbeat / 2.0 self._pong_response_cb = None self._loop = loop self._waiting = None @@ -61,7 +61,10 @@ def _reset_heartbeat(self): def _send_heartbeat(self): if self._heartbeat is not None and not self._closed: - self._writer.ping() + # fire-and-forget a task is not perfect but maybe ok for + # sending ping. Otherwise we need a long-living heartbeat + # task in the class. + self._loop.create_task(self._writer.ping()) if self._pong_response_cb is not None: self._pong_response_cb.cancel() @@ -137,7 +140,7 @@ async def close(self, *, code=1000, message=b''): self._cancel_heartbeat() self._closed = True try: - self._writer.close(code, message) + await self._writer.close(code, message) except asyncio.CancelledError: self._close_code = 1006 self._response.close() diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index a5ca686f64e..155a945e76d 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -9,7 +9,7 @@ from enum import IntEnum from struct import Struct -from .helpers import NO_EXTENSIONS, noop +from .helpers import NO_EXTENSIONS from .log import ws_logger @@ -527,7 +527,7 @@ def __init__(self, protocol, transport, *, self._output_size = 0 self._compressobj = None - def _send_frame(self, message, opcode, compress=None): + async def _send_frame(self, message, opcode, compress=None): """Send a frame over the websocket with message as its payload.""" if self._closing: ws_logger.warning('websocket connection is closing.') @@ -585,37 +585,35 @@ def _send_frame(self, message, opcode, compress=None): if self._output_size > self._limit: self._output_size = 0 - return self.protocol._drain_helper() + await self.protocol._drain_helper() - return noop() - - def pong(self, message=b''): + async def pong(self, message=b''): """Send pong message.""" if isinstance(message, str): message = message.encode('utf-8') - return self._send_frame(message, WSMsgType.PONG) + return await self._send_frame(message, WSMsgType.PONG) - def ping(self, message=b''): + async def ping(self, message=b''): """Send ping message.""" if isinstance(message, str): message = message.encode('utf-8') - return self._send_frame(message, WSMsgType.PING) + return await self._send_frame(message, WSMsgType.PING) - def send(self, message, binary=False, compress=None): + async def send(self, message, binary=False, compress=None): """Send a frame over the websocket with message as its payload.""" if isinstance(message, str): message = message.encode('utf-8') if binary: - return self._send_frame(message, WSMsgType.BINARY, compress) + return await self._send_frame(message, WSMsgType.BINARY, compress) else: - return self._send_frame(message, WSMsgType.TEXT, compress) + return await self._send_frame(message, WSMsgType.TEXT, compress) - def close(self, code=1000, message=b''): + async def close(self, code=1000, message=b''): """Close the websocket, sending the specified code and message.""" if isinstance(message, str): message = message.encode('utf-8') try: - return self._send_frame( + return await self._send_frame( PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE) finally: self._closing = True diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 3753dcb6942..d9c4e3bce8e 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -58,7 +58,7 @@ def __init__(self, *, self._heartbeat = heartbeat self._heartbeat_cb = None if heartbeat is not None: - self._pong_heartbeat = heartbeat/2.0 + self._pong_heartbeat = heartbeat / 2.0 self._pong_response_cb = None self._compress = compress @@ -80,7 +80,10 @@ def _reset_heartbeat(self): def _send_heartbeat(self): if self._heartbeat is not None and not self._closed: - self._writer.ping() + # fire-and-forget a task is not perfect but maybe ok for + # sending ping. Otherwise we need a long-living heartbeat + # task in the class. + self._loop.create_task(self._writer.ping()) if self._pong_response_cb is not None: self._pong_response_cb.cancel() @@ -286,7 +289,7 @@ async def close(self, *, code=1000, message=b''): if not self._closed: self._closed = True try: - self._writer.close(code, message) + await self._writer.close(code, message) await self._payload_writer.drain() except (asyncio.CancelledError, asyncio.TimeoutError): self._close_code = 1006 diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index cd093cfc7f1..a27f1c46f08 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -242,7 +242,9 @@ async def test_close(loop, ws_key, key_data): m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) - writer = WebSocketWriter.return_value = mock.Mock() + writer = mock.Mock() + WebSocketWriter.return_value = writer + writer.close = make_mocked_coro() session = aiohttp.ClientSession(loop=loop) resp = await session.ws_connect( @@ -280,7 +282,9 @@ async def test_close_exc(loop, ws_key, key_data): m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) - WebSocketWriter.return_value = mock.Mock() + writer = mock.Mock() + WebSocketWriter.return_value = writer + writer.close = make_mocked_coro() session = aiohttp.ClientSession(loop=loop) resp = await session.ws_connect('http://test.org') @@ -400,7 +404,10 @@ async def test_reader_read_exception(ws_key, key_data, loop): m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(hresp) - WebSocketWriter.return_value = mock.Mock() + + writer = mock.Mock() + WebSocketWriter.return_value = writer + writer.close = make_mocked_coro() session = aiohttp.ClientSession(loop=loop) resp = await session.ws_connect('http://test.org') diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 255d8423385..aee2af24364 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -544,7 +544,7 @@ async def handler(request): client = await aiohttp_client(app) resp = await client.ws_connect('/', heartbeat=0.01) - + await asyncio.sleep(0.1) await resp.receive() await resp.close() diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index af30b1e3910..946cf7ab267 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -4,11 +4,14 @@ import pytest from aiohttp.http import WebSocketWriter +from aiohttp.test_utils import make_mocked_coro @pytest.fixture def protocol(): - return mock.Mock() + ret = mock.Mock() + ret._drain_helper = make_mocked_coro() + return ret @pytest.fixture @@ -21,83 +24,83 @@ def writer(protocol, transport): return WebSocketWriter(protocol, transport, use_mask=False) -def test_pong(writer): - writer.pong() +async def test_pong(writer): + await writer.pong() writer.transport.write.assert_called_with(b'\x8a\x00') -def test_ping(writer): - writer.ping() +async def test_ping(writer): + await writer.ping() writer.transport.write.assert_called_with(b'\x89\x00') -def test_send_text(writer): - writer.send(b'text') +async def test_send_text(writer): + await writer.send(b'text') writer.transport.write.assert_called_with(b'\x81\x04text') -def test_send_binary(writer): - writer.send('binary', True) +async def test_send_binary(writer): + await writer.send('binary', True) writer.transport.write.assert_called_with(b'\x82\x06binary') -def test_send_binary_long(writer): - writer.send(b'b' * 127, True) +async def test_send_binary_long(writer): + await writer.send(b'b' * 127, True) assert writer.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb') -def test_send_binary_very_long(writer): - writer.send(b'b' * 65537, True) +async def test_send_binary_very_long(writer): + await writer.send(b'b' * 65537, True) assert (writer.transport.write.call_args_list[0][0][0] == b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01') assert writer.transport.write.call_args_list[1][0][0] == b'b' * 65537 -def test_close(writer): - writer.close(1001, 'msg') +async def test_close(writer): + await writer.close(1001, 'msg') writer.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') - writer.close(1001, b'msg') + await writer.close(1001, b'msg') writer.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') # Test that Service Restart close code is also supported - writer.close(1012, b'msg') + await writer.close(1012, b'msg') writer.transport.write.assert_called_with(b'\x88\x05\x03\xf4msg') -def test_send_text_masked(protocol, transport): +async def test_send_text_masked(protocol, transport): writer = WebSocketWriter(protocol, transport, use_mask=True, random=random.Random(123)) - writer.send(b'text') + await writer.send(b'text') writer.transport.write.assert_called_with(b'\x81\x84\rg\xb3fy\x02\xcb\x12') -def test_send_compress_text(protocol, transport): +async def test_send_compress_text(protocol, transport): writer = WebSocketWriter(protocol, transport, compress=15) - writer.send(b'text') + await writer.send(b'text') writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') - writer.send(b'text') + await writer.send(b'text') writer.transport.write.assert_called_with(b'\xc1\x05*\x01b\x00\x00') -def test_send_compress_text_notakeover(protocol, transport): +async def test_send_compress_text_notakeover(protocol, transport): writer = WebSocketWriter(protocol, transport, compress=15, notakeover=True) - writer.send(b'text') + await writer.send(b'text') writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') - writer.send(b'text') + await writer.send(b'text') writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') -def test_send_compress_text_per_message(protocol, transport): +async def test_send_compress_text_per_message(protocol, transport): writer = WebSocketWriter(protocol, transport) - writer.send(b'text', compress=15) + await writer.send(b'text', compress=15) writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00') - writer.send(b'text') + await writer.send(b'text') writer.transport.write.assert_called_with(b'\x81\x04text') - writer.send(b'text', compress=15) + await writer.send(b'text', compress=15) writer.transport.write.assert_called_with(b'\xc1\x06*I\xad(\x01\x00')