From 654a8060bdb8373d32b72b4b2dd62f48a2fb0f1d Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Thu, 9 Jul 2015 18:15:40 +0100 Subject: [PATCH 01/14] Initial signal implementation. Tests and documentation to follow. --- aiohttp/signals.py | 29 +++++++++++++++++++++++++++++ aiohttp/web.py | 3 +++ aiohttp/web_reqrep.py | 3 +++ 3 files changed, 35 insertions(+) create mode 100644 aiohttp/signals.py diff --git a/aiohttp/signals.py b/aiohttp/signals.py new file mode 100644 index 00000000000..181bdcbe82e --- /dev/null +++ b/aiohttp/signals.py @@ -0,0 +1,29 @@ +import asyncio + +class Signal(object): + def __init__(self, parameters): + self._parameters = frozenset(parameters) + self._receivers = set() + + def connect(self, receiver): + # Check that the callback can be called with the given parameter names + signature(receiver).bind(**{p: None for p in self._parameters}) + self._receivers.add(receiver) + + def disconnect(self, receiver): + self._receivers.remove(receiver) + + def send(self, **kwargs): + for receiver in self._receivers: + receiver(**kwargs) + +class AsyncSignal(Signal): + def connect(self, receiver): + assert asyncio.iscoroutinefunction(receiver), receiver + super().connect(receiver) + + @asyncio.coroutine + def send(self, **kwargs): + for receiver in self._receivers: + yield from receiver(**kwargs) + diff --git a/aiohttp/web.py b/aiohttp/web.py index cc86ff9230e..811b10e04a6 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -7,6 +7,7 @@ from .web_urldispatcher import * # noqa from .web_ws import * # noqa from .protocol import HttpVersion # noqa +from .signal import Signal __all__ = (web_reqrep.__all__ + web_exceptions.__all__ + @@ -195,6 +196,8 @@ def __init__(self, *, logger=web_logger, loop=None, assert asyncio.iscoroutinefunction(factory), factory self._middlewares = list(middlewares) + self.on_response_start = Signal({'request', 'response'}) + @property def router(self): return self._router diff --git a/aiohttp/web_reqrep.py b/aiohttp/web_reqrep.py index 981598fc7cd..25c69da2018 100644 --- a/aiohttp/web_reqrep.py +++ b/aiohttp/web_reqrep.py @@ -622,6 +622,9 @@ def start(coding): return def start(self, request): + request.app.on_response_start.send(request=request, + response=self) + resp_impl = self._start_pre_check(request) if resp_impl is not None: return resp_impl From ac7536199f4bda693a72b2992eb27142e9f3176e Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Mon, 20 Jul 2015 11:07:43 +0100 Subject: [PATCH 02/14] Wrap iscoroutinefunction check in 'if __debug__', so people can optimise it out. --- aiohttp/signals.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/aiohttp/signals.py b/aiohttp/signals.py index 181bdcbe82e..a6008531f5e 100644 --- a/aiohttp/signals.py +++ b/aiohttp/signals.py @@ -1,3 +1,5 @@ +from inspect import signature + import asyncio class Signal(object): @@ -7,7 +9,8 @@ def __init__(self, parameters): def connect(self, receiver): # Check that the callback can be called with the given parameter names - signature(receiver).bind(**{p: None for p in self._parameters}) + if __debug__: + signature(receiver).bind(**{p: None for p in self._parameters}) self._receivers.add(receiver) def disconnect(self, receiver): From fefd2ed137c59830fb3f8872beec8cac8c8e5cc7 Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Mon, 20 Jul 2015 11:08:14 +0100 Subject: [PATCH 03/14] Rename AsyncSignal to CoroutineSignal for clarity of purpose --- aiohttp/signals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiohttp/signals.py b/aiohttp/signals.py index a6008531f5e..efbdee2ef06 100644 --- a/aiohttp/signals.py +++ b/aiohttp/signals.py @@ -20,7 +20,7 @@ def send(self, **kwargs): for receiver in self._receivers: receiver(**kwargs) -class AsyncSignal(Signal): +class CoroutineSignal(Signal): def connect(self, receiver): assert asyncio.iscoroutinefunction(receiver), receiver super().connect(receiver) From d45ff675efb7887679aa6f0530f772e01d35dd25 Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Mon, 20 Jul 2015 14:22:14 +0100 Subject: [PATCH 04/14] Add base class for signals --- aiohttp/signals.py | 12 +++++++++++- aiohttp/web.py | 4 ++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/aiohttp/signals.py b/aiohttp/signals.py index efbdee2ef06..8fbcd0f64fb 100644 --- a/aiohttp/signals.py +++ b/aiohttp/signals.py @@ -1,8 +1,9 @@ +import abc from inspect import signature import asyncio -class Signal(object): +class Signal(metaclass=abc.ABCMeta): def __init__(self, parameters): self._parameters = frozenset(parameters) self._receivers = set() @@ -16,6 +17,15 @@ def connect(self, receiver): def disconnect(self, receiver): self._receivers.remove(receiver) + @abc.abstractmethod + def send(self, **kwargs): + pass + +class FunctionSignal(Signal): + def connect(self, receiver): + assert not asyncio.iscoroutinefunction(receiver), receiver + super().connect(receiver) + def send(self, **kwargs): for receiver in self._receivers: receiver(**kwargs) diff --git a/aiohttp/web.py b/aiohttp/web.py index 811b10e04a6..9a9b75a1381 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -7,7 +7,7 @@ from .web_urldispatcher import * # noqa from .web_ws import * # noqa from .protocol import HttpVersion # noqa -from .signal import Signal +from .signals import FunctionSignal __all__ = (web_reqrep.__all__ + web_exceptions.__all__ + @@ -196,7 +196,7 @@ def __init__(self, *, logger=web_logger, loop=None, assert asyncio.iscoroutinefunction(factory), factory self._middlewares = list(middlewares) - self.on_response_start = Signal({'request', 'response'}) + self.on_response_start = FunctionSignal({'request', 'response'}) @property def router(self): From cf2966007f937671ab780c3b797cc7c8c75d5966 Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Mon, 20 Jul 2015 14:22:45 +0100 Subject: [PATCH 05/14] Add signal tests --- tests/test_signals.py | 99 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 tests/test_signals.py diff --git a/tests/test_signals.py b/tests/test_signals.py new file mode 100644 index 00000000000..732908e30b0 --- /dev/null +++ b/tests/test_signals.py @@ -0,0 +1,99 @@ +import asyncio +import unittest +from unittest import mock +from aiohttp.multidict import CIMultiDict +from aiohttp.signals import Signal, FunctionSignal, CoroutineSignal +from aiohttp.web import Application +from aiohttp.web import Request, StreamResponse, Response +from aiohttp.protocol import HttpVersion, HttpVersion11, HttpVersion10 +from aiohttp.protocol import RawRequestMessage + +class TestSignals(unittest.TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def make_request(self, method, path, headers=CIMultiDict(), app=None): + message = RawRequestMessage(method, path, HttpVersion11, headers, + False, False) + return self.request_from_message(message, app) + + def request_from_message(self, message, app=None): + self.app = app if app is not None else mock.Mock() + self.payload = mock.Mock() + self.transport = mock.Mock() + self.reader = mock.Mock() + self.writer = mock.Mock() + req = Request(self.app, message, self.payload, + self.transport, self.reader, self.writer) + return req + + def test_callback_valid(self): + signal = FunctionSignal({'foo', 'bar'}) + + # All these are suitable + good_callbacks = [ + (lambda foo, bar: None), + (lambda *, foo, bar: None), + (lambda foo, bar, **kwargs: None), + (lambda foo, bar, baz=None: None), + (lambda baz=None, *, foo, bar: None), + (lambda foo=None, bar=None: None), + (lambda foo, bar=None, *, baz=None: None), + (lambda **kwargs: None), + ] + for callback in good_callbacks: + signal.connect(callback) + + def test_callback_invalid(self): + signal = FunctionSignal({'foo', 'bar'}) + + # All these are unsuitable + bad_callbacks = [ + (lambda foo: None), + (lambda foo, bar, baz: None), + ] + for callback in bad_callbacks: + with self.assertRaises(TypeError): + signal.connect(callback) + + def test_add_response_start_signal_handler(self): + callback = lambda request, response: None + app = Application(loop=self.loop) + app.on_response_start.connect(callback) + + def test_add_signal_handler_not_a_callable(self): + callback = True + app = Application(loop=self.loop) + with self.assertRaises(TypeError): + app.on_response_start.connect(callback) + + def test_function_signal_dispatch(self): + signal = CoroutineSignal({'foo', 'bar'}) + kwargs = {'foo': 1, 'bar': 2} + + callback_mock = mock.Mock() + callback = asyncio.coroutine(callback_mock) + + signal.connect(callback) + + self.loop.run_until_complete(signal.send(**kwargs)) + callback_mock.assert_called_once_with(**kwargs) + + def test_response_start(self): + callback = mock.Mock() + callback._is_coroutine = False + + app = Application(loop=self.loop) + app.on_response_start.connect(callback) + + request = self.make_request('GET', '/', app=app) + response = Response(body=b'') + response.start(request) + + callback.assert_called_once_with(request=request, + response=response) + From 98c418ad37549e0b7dcf2eb2f1f74fe0eeb497b0 Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Tue, 21 Jul 2015 09:28:19 +0100 Subject: [PATCH 06/14] Documentation! --- aiohttp/signals.py | 45 ++++++++++++++++++++++++++++++++++++++++++ docs/api.rst | 8 ++++++++ docs/web_reference.rst | 10 +++++++++- 3 files changed, 62 insertions(+), 1 deletion(-) diff --git a/aiohttp/signals.py b/aiohttp/signals.py index 8fbcd0f64fb..4e5e67f8efe 100644 --- a/aiohttp/signals.py +++ b/aiohttp/signals.py @@ -4,24 +4,64 @@ import asyncio class Signal(metaclass=abc.ABCMeta): + """ + Abstract base class for signals. + + To connect a callback to a signal, use the :meth:`callback` method. If you + wish to pass additional arguments to your callback, + use :meth:`functools.partial`. Signals can be disconnected again using + :meth:`disconnect`. Callbacks are executed in an arbitrary order. + + There are two declared concrete subclasses, :class:`FunctionSignal`, which + dispatches to plain function callbacks, and :class:`CoroutineSignal`, + which accepts coroutine functions as callbacks. + + Signals are fired using :meth:`send`, which takes named arguments. The + :meth:`send` method for :class:`CoroutineSignal` is itself a coroutine + function. + """ def __init__(self, parameters): self._parameters = frozenset(parameters) self._receivers = set() def connect(self, receiver): + """ + Connect a receiver. + + :param collections.abc.Callable receiver: A function to be called + whenever the signal is fired. + :raises TypeError: if ``receiver`` isn't a callable, or doesn't have + a call signature that supports the signals parameters. + """ # Check that the callback can be called with the given parameter names if __debug__: signature(receiver).bind(**{p: None for p in self._parameters}) self._receivers.add(receiver) def disconnect(self, receiver): + """ + Disconnect a receiver. + + :param collections.abc.Callable receiver: A function to no longer + be called whenever the signal is fired. + + :raises KeyError: if the receiver wasn't already registered. + """ self._receivers.remove(receiver) @abc.abstractmethod def send(self, **kwargs): + """ + Sends data to all registered receivers. + """ pass class FunctionSignal(Signal): + """ + A signal type that dispatches to plain functions. + + See :class:`Signal` for documentation. + """ def connect(self, receiver): assert not asyncio.iscoroutinefunction(receiver), receiver super().connect(receiver) @@ -31,6 +71,11 @@ def send(self, **kwargs): receiver(**kwargs) class CoroutineSignal(Signal): + """ + A signal type that dispatches to coroutine functions. + + See :class:`Signal` for documentation. + """ def connect(self, receiver): assert asyncio.iscoroutinefunction(receiver), receiver super().connect(receiver) diff --git a/docs/api.rst b/docs/api.rst index ed8183a721f..033565a33a0 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -49,6 +49,14 @@ aiohttp.protocol module :undoc-members: :show-inheritance: +aiohttp.signals module +---------------------- + +.. automodule:: aiohttp.signals + :members: + :undoc-members: + :show-inheritance: + aiohttp.streams module ---------------------- diff --git a/docs/web_reference.rst b/docs/web_reference.rst index db618867953..e3167400dac 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -543,7 +543,8 @@ StreamResponse response answers. Send *HTTP header*. You should not change any header data after - calling this method. + calling this method, except through + :attr:`Application.on_response_start` signal callbacks. .. method:: write(data) @@ -873,6 +874,13 @@ arbitrary properties for later access from :ref:`event loop` used for processing HTTP requests. + .. attribute:: on_response_start + + A :class:`~aiohttp.signals.Signal` that is fired at the beginning of + :meth:`StreamResponse.start` with parameters ``request`` and + ``response``. It can be used, for example, to add custom headers to each + response before sending. + .. method:: make_handler(**kwargs) Creates HTTP protocol factory for handling requests. From 4d8b509e13cdc5a7c3bb1251ffe3915a7e349a83 Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Tue, 21 Jul 2015 09:48:40 +0100 Subject: [PATCH 07/14] Point at FunctionSignal, not Signal in `on_response_start` docs --- docs/web_reference.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/web_reference.rst b/docs/web_reference.rst index e3167400dac..5adc63f7206 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -876,8 +876,8 @@ arbitrary properties for later access from .. attribute:: on_response_start - A :class:`~aiohttp.signals.Signal` that is fired at the beginning of - :meth:`StreamResponse.start` with parameters ``request`` and + A :class:`~aiohttp.signals.FunctionSignal` that is fired at the beginning + of :meth:`StreamResponse.start` with parameters ``request`` and ``response``. It can be used, for example, to add custom headers to each response before sending. From f04fbb37800b813beb96b435f780dbc68880b4b2 Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Fri, 25 Sep 2015 22:32:54 +0100 Subject: [PATCH 08/14] Remove FunctionSignals in light of #525. --- aiohttp/signals.py | 52 +++++++++++++--------------------------------- 1 file changed, 14 insertions(+), 38 deletions(-) diff --git a/aiohttp/signals.py b/aiohttp/signals.py index 4e5e67f8efe..678aabd47bd 100644 --- a/aiohttp/signals.py +++ b/aiohttp/signals.py @@ -1,24 +1,21 @@ import abc +import functools from inspect import signature import asyncio class Signal(metaclass=abc.ABCMeta): """ - Abstract base class for signals. + Coroutine-based signal implementation To connect a callback to a signal, use the :meth:`callback` method. If you wish to pass additional arguments to your callback, use :meth:`functools.partial`. Signals can be disconnected again using - :meth:`disconnect`. Callbacks are executed in an arbitrary order. + :meth:`disconnect`. Callbacks are executed in an arbitrary order and must + be coroutines. - There are two declared concrete subclasses, :class:`FunctionSignal`, which - dispatches to plain function callbacks, and :class:`CoroutineSignal`, - which accepts coroutine functions as callbacks. - - Signals are fired using :meth:`send`, which takes named arguments. The - :meth:`send` method for :class:`CoroutineSignal` is itself a coroutine - function. + Signals are fired using the :meth:`send` coroutine, which takes named + arguments. """ def __init__(self, parameters): self._parameters = frozenset(parameters) @@ -35,6 +32,13 @@ def connect(self, receiver): """ # Check that the callback can be called with the given parameter names if __debug__: + # We suggest using functools.partial, but that hides the fact that + # they are coroutine functions. So, let's check the underlying + # function instead of the receiver itself. + func = receiver + while isinstance(func, functools.partial): + func = func.func + assert asyncio.iscoroutinefunction(func), receiver signature(receiver).bind(**{p: None for p in self._parameters}) self._receivers.add(receiver) @@ -49,39 +53,11 @@ def disconnect(self, receiver): """ self._receivers.remove(receiver) - @abc.abstractmethod + @asyncio.coroutine def send(self, **kwargs): """ Sends data to all registered receivers. """ - pass - -class FunctionSignal(Signal): - """ - A signal type that dispatches to plain functions. - - See :class:`Signal` for documentation. - """ - def connect(self, receiver): - assert not asyncio.iscoroutinefunction(receiver), receiver - super().connect(receiver) - - def send(self, **kwargs): - for receiver in self._receivers: - receiver(**kwargs) - -class CoroutineSignal(Signal): - """ - A signal type that dispatches to coroutine functions. - - See :class:`Signal` for documentation. - """ - def connect(self, receiver): - assert asyncio.iscoroutinefunction(receiver), receiver - super().connect(receiver) - - @asyncio.coroutine - def send(self, **kwargs): for receiver in self._receivers: yield from receiver(**kwargs) From 5fb868bbc1d28a28431f07dd28c34f54fb96c34f Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Fri, 25 Sep 2015 22:41:02 +0100 Subject: [PATCH 09/14] Move on_response_start firing to `prepare()` and treat it as a coroutine --- aiohttp/web.py | 4 ++-- aiohttp/web_reqrep.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index 6828e569481..337a5bdf615 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -7,7 +7,7 @@ from .web_urldispatcher import * # noqa from .web_ws import * # noqa from .protocol import HttpVersion # noqa -from .signals import FunctionSignal +from .signals import Signal import asyncio @@ -197,7 +197,7 @@ def __init__(self, *, logger=web_logger, loop=None, assert asyncio.iscoroutinefunction(factory), factory self._middlewares = list(middlewares) - self.on_response_start = FunctionSignal({'request', 'response'}) + self.on_response_start = Signal({'request', 'response'}) @property def router(self): diff --git a/aiohttp/web_reqrep.py b/aiohttp/web_reqrep.py index 5ea974a44c0..5bac1652a88 100644 --- a/aiohttp/web_reqrep.py +++ b/aiohttp/web_reqrep.py @@ -634,9 +634,6 @@ def _start(coding): return def start(self, request): - request.app.on_response_start.send(request=request, - response=self) - warnings.warn('use .prepare(request) instead', DeprecationWarning) resp_impl = self._start_pre_check(request) if resp_impl is not None: @@ -649,6 +646,8 @@ def prepare(self, request): resp_impl = self._start_pre_check(request) if resp_impl is not None: return resp_impl + yield from request.app.on_response_start.send(request=request, + response=self) return self._start(request) From cc2efbd0a4daa84d14cd1be801ca5e4a411e452d Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Fri, 25 Sep 2015 23:14:42 +0100 Subject: [PATCH 10/14] Raise TypeError on non-coroutine functions, to match signature mismatches --- aiohttp/signals.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aiohttp/signals.py b/aiohttp/signals.py index 678aabd47bd..dfd779d8a9f 100644 --- a/aiohttp/signals.py +++ b/aiohttp/signals.py @@ -38,7 +38,8 @@ def connect(self, receiver): func = receiver while isinstance(func, functools.partial): func = func.func - assert asyncio.iscoroutinefunction(func), receiver + if not asyncio.iscoroutinefunction(func): + raise TypeError("{} is not a coroutine function".format(receiver)) signature(receiver).bind(**{p: None for p in self._parameters}) self._receivers.add(receiver) From 9170057b024d0ecdd2a0ab0f92d488ec2429aa3d Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Fri, 25 Sep 2015 23:15:34 +0100 Subject: [PATCH 11/14] Working tests again. --- tests/test_signals.py | 25 ++++++++++++------------- tests/test_web_exceptions.py | 3 ++- tests/test_web_request.py | 2 ++ tests/test_web_response.py | 4 +++- tests/test_web_websocket.py | 3 ++- 5 files changed, 21 insertions(+), 16 deletions(-) diff --git a/tests/test_signals.py b/tests/test_signals.py index 732908e30b0..e0789420ce7 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -2,7 +2,7 @@ import unittest from unittest import mock from aiohttp.multidict import CIMultiDict -from aiohttp.signals import Signal, FunctionSignal, CoroutineSignal +from aiohttp.signals import Signal from aiohttp.web import Application from aiohttp.web import Request, StreamResponse, Response from aiohttp.protocol import HttpVersion, HttpVersion11, HttpVersion10 @@ -32,10 +32,10 @@ def request_from_message(self, message, app=None): return req def test_callback_valid(self): - signal = FunctionSignal({'foo', 'bar'}) + signal = Signal({'foo', 'bar'}) # All these are suitable - good_callbacks = [ + good_callbacks = map(asyncio.coroutine, [ (lambda foo, bar: None), (lambda *, foo, bar: None), (lambda foo, bar, **kwargs: None), @@ -44,24 +44,24 @@ def test_callback_valid(self): (lambda foo=None, bar=None: None), (lambda foo, bar=None, *, baz=None: None), (lambda **kwargs: None), - ] + ]) for callback in good_callbacks: signal.connect(callback) def test_callback_invalid(self): - signal = FunctionSignal({'foo', 'bar'}) + signal = Signal({'foo', 'bar'}) # All these are unsuitable - bad_callbacks = [ + bad_callbacks = map(asyncio.coroutine, [ (lambda foo: None), (lambda foo, bar, baz: None), - ] + ]) for callback in bad_callbacks: with self.assertRaises(TypeError): signal.connect(callback) def test_add_response_start_signal_handler(self): - callback = lambda request, response: None + callback = asyncio.coroutine(lambda request, response: None) app = Application(loop=self.loop) app.on_response_start.connect(callback) @@ -72,7 +72,7 @@ def test_add_signal_handler_not_a_callable(self): app.on_response_start.connect(callback) def test_function_signal_dispatch(self): - signal = CoroutineSignal({'foo', 'bar'}) + signal = Signal({'foo', 'bar'}) kwargs = {'foo': 1, 'bar': 2} callback_mock = mock.Mock() @@ -83,16 +83,15 @@ def test_function_signal_dispatch(self): self.loop.run_until_complete(signal.send(**kwargs)) callback_mock.assert_called_once_with(**kwargs) - def test_response_start(self): + def test_response_prepare(self): callback = mock.Mock() - callback._is_coroutine = False app = Application(loop=self.loop) - app.on_response_start.connect(callback) + app.on_response_start.connect(asyncio.coroutine(callback)) request = self.make_request('GET', '/', app=app) response = Response(body=b'') - response.start(request) + self.loop.run_until_complete(response.prepare(request)) callback.assert_called_once_with(request=request, response=response) diff --git a/tests/test_web_exceptions.py b/tests/test_web_exceptions.py index 611562619fd..8699d5b45be 100644 --- a/tests/test_web_exceptions.py +++ b/tests/test_web_exceptions.py @@ -6,7 +6,7 @@ from aiohttp.web import Request from aiohttp.protocol import RawRequestMessage, HttpVersion11 -from aiohttp import web +from aiohttp import signals, web class TestHTTPExceptions(unittest.TestCase): @@ -32,6 +32,7 @@ def append(self, data): def make_request(self, method='GET', path='/', headers=CIMultiDict()): self.app = mock.Mock() + self.app.on_response_start = signals.Signal({'request', 'response'}) message = RawRequestMessage(method, path, HttpVersion11, headers, False, False) req = Request(self.app, message, self.payload, diff --git a/tests/test_web_request.py b/tests/test_web_request.py index 60c77bf3af5..378b8c0bf90 100644 --- a/tests/test_web_request.py +++ b/tests/test_web_request.py @@ -1,6 +1,7 @@ import asyncio import unittest from unittest import mock +from aiohttp.signals import Signal from aiohttp.web import Request from aiohttp.multidict import MultiDict, CIMultiDict from aiohttp.protocol import HttpVersion @@ -23,6 +24,7 @@ def make_request(self, method, path, headers=CIMultiDict(), *, if version < HttpVersion(1, 1): closing = True self.app = mock.Mock() + self.app.on_response_start = Signal({'request', 'response'}) message = RawRequestMessage(method, path, version, headers, closing, False) self.payload = mock.Mock() diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 11fb5454871..72089490516 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -2,7 +2,7 @@ import datetime import unittest from unittest import mock -from aiohttp import hdrs +from aiohttp import hdrs, signals from aiohttp.multidict import CIMultiDict from aiohttp.web import ContentCoding, Request, StreamResponse, Response from aiohttp.protocol import HttpVersion, HttpVersion11, HttpVersion10 @@ -26,6 +26,7 @@ def make_request(self, method, path, headers=CIMultiDict(), def request_from_message(self, message): self.app = mock.Mock() + self.app.on_response_start = signals.Signal({'request', 'response'}) self.payload = mock.Mock() self.transport = mock.Mock() self.reader = mock.Mock() @@ -526,6 +527,7 @@ def tearDown(self): def make_request(self, method, path, headers=CIMultiDict()): self.app = mock.Mock() + self.app.on_response_start = signals.Signal({'request', 'response'}) message = RawRequestMessage(method, path, HttpVersion11, headers, False, False) self.payload = mock.Mock() diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index bc8a7fb11a0..eec5d5fede1 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -5,7 +5,7 @@ from aiohttp.web import ( MsgType, Request, WebSocketResponse, HTTPMethodNotAllowed, HTTPBadRequest) from aiohttp.protocol import RawRequestMessage, HttpVersion11 -from aiohttp import errors, websocket +from aiohttp import errors, signals, websocket class TestWebWebSocket(unittest.TestCase): @@ -35,6 +35,7 @@ def make_request(self, method, path, headers=None): self.reader = mock.Mock() self.writer = mock.Mock() self.app.loop = self.loop + self.app.on_response_start = signals.Signal({'request', 'response'}) req = Request(self.app, message, self.payload, self.transport, self.reader, self.writer) return req From 5825da33402943686067a024b716d66cf801999a Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Mon, 28 Sep 2015 21:44:43 +0100 Subject: [PATCH 12/14] Signal now based on list; still does signature checking --- aiohttp/signals.py | 68 +++++++++++++++++++----------------- aiohttp/web.py | 2 +- aiohttp/web_reqrep.py | 4 +-- tests/test_signals.py | 20 +++++------ tests/test_web_exceptions.py | 2 +- tests/test_web_request.py | 2 +- tests/test_web_response.py | 4 +-- tests/test_web_websocket.py | 2 +- 8 files changed, 53 insertions(+), 51 deletions(-) diff --git a/aiohttp/signals.py b/aiohttp/signals.py index dfd779d8a9f..403601dbe97 100644 --- a/aiohttp/signals.py +++ b/aiohttp/signals.py @@ -4,61 +4,63 @@ import asyncio -class Signal(metaclass=abc.ABCMeta): +class Signal(list): """ Coroutine-based signal implementation - To connect a callback to a signal, use the :meth:`callback` method. If you - wish to pass additional arguments to your callback, - use :meth:`functools.partial`. Signals can be disconnected again using - :meth:`disconnect`. Callbacks are executed in an arbitrary order and must - be coroutines. + To connect a callback to a signal, use any list method. If wish to pass + additional arguments to your callback, use :meth:`functools.partial`. Signals are fired using the :meth:`send` coroutine, which takes named arguments. """ - def __init__(self, parameters): - self._parameters = frozenset(parameters) - self._receivers = set() + def __init__(self, *args, parameters=None): + self._parameters = parameters + if args: + self.extend(args[0]) - def connect(self, receiver): - """ - Connect a receiver. - - :param collections.abc.Callable receiver: A function to be called - whenever the signal is fired. - :raises TypeError: if ``receiver`` isn't a callable, or doesn't have - a call signature that supports the signals parameters. - """ - # Check that the callback can be called with the given parameter names - if __debug__: - # We suggest using functools.partial, but that hides the fact that - # they are coroutine functions. So, let's check the underlying - # function instead of the receiver itself. + def _check_signature(self, receiver): + if self._parameters is not None: func = receiver while isinstance(func, functools.partial): func = func.func if not asyncio.iscoroutinefunction(func): raise TypeError("{} is not a coroutine function".format(receiver)) signature(receiver).bind(**{p: None for p in self._parameters}) - self._receivers.add(receiver) + return True - def disconnect(self, receiver): - """ - Disconnect a receiver. + # Only override these methods to check signatures if not optimised. + if __debug__: + def __iadd__(self, other): + assert all(map(self._check_signature, other)) + super().__iadd__(other) - :param collections.abc.Callable receiver: A function to no longer - be called whenever the signal is fired. + def __setitem__(self, key, value): + if isinstance(key, slice): + value = list(value) + assert all(map(self._check_signature, other)) + else: + assert self._check_signature(value) + super().__setitem__(key, value) - :raises KeyError: if the receiver wasn't already registered. - """ - self._receivers.remove(receiver) + def insert(self, index, obj): + assert self._check_signature(obj) + super().insert(index, obj) + + def append(self, obj): + assert self._check_signature(obj) + super().append(obj) + + def extend(self, other): + other = list(other) + assert all(map(self._check_signature, other)) + super().extend(other) @asyncio.coroutine def send(self, **kwargs): """ Sends data to all registered receivers. """ - for receiver in self._receivers: + for receiver in self: yield from receiver(**kwargs) diff --git a/aiohttp/web.py b/aiohttp/web.py index 337a5bdf615..ac5208826b3 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -197,7 +197,7 @@ def __init__(self, *, logger=web_logger, loop=None, assert asyncio.iscoroutinefunction(factory), factory self._middlewares = list(middlewares) - self.on_response_start = Signal({'request', 'response'}) + self.on_response_prepare = Signal(parameters={'request', 'response'}) @property def router(self): diff --git a/aiohttp/web_reqrep.py b/aiohttp/web_reqrep.py index 5bac1652a88..9794658f5f9 100644 --- a/aiohttp/web_reqrep.py +++ b/aiohttp/web_reqrep.py @@ -646,8 +646,8 @@ def prepare(self, request): resp_impl = self._start_pre_check(request) if resp_impl is not None: return resp_impl - yield from request.app.on_response_start.send(request=request, - response=self) + yield from request.app.on_response_prepare.send(request=request, + response=self) return self._start(request) diff --git a/tests/test_signals.py b/tests/test_signals.py index e0789420ce7..03d4dd9dfab 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -32,7 +32,7 @@ def request_from_message(self, message, app=None): return req def test_callback_valid(self): - signal = Signal({'foo', 'bar'}) + signal = Signal(parameters={'foo', 'bar'}) # All these are suitable good_callbacks = map(asyncio.coroutine, [ @@ -46,10 +46,10 @@ def test_callback_valid(self): (lambda **kwargs: None), ]) for callback in good_callbacks: - signal.connect(callback) + signal.append(callback) def test_callback_invalid(self): - signal = Signal({'foo', 'bar'}) + signal = Signal(parameters={'foo', 'bar'}) # All these are unsuitable bad_callbacks = map(asyncio.coroutine, [ @@ -58,27 +58,27 @@ def test_callback_invalid(self): ]) for callback in bad_callbacks: with self.assertRaises(TypeError): - signal.connect(callback) + signal.send(callback) - def test_add_response_start_signal_handler(self): + def test_add_response_prepare_signal_handler(self): callback = asyncio.coroutine(lambda request, response: None) app = Application(loop=self.loop) - app.on_response_start.connect(callback) + app.on_response_prepare.append(callback) def test_add_signal_handler_not_a_callable(self): callback = True app = Application(loop=self.loop) with self.assertRaises(TypeError): - app.on_response_start.connect(callback) + app.on_response_prepare.append(callback) def test_function_signal_dispatch(self): - signal = Signal({'foo', 'bar'}) + signal = Signal(parameters={'foo', 'bar'}) kwargs = {'foo': 1, 'bar': 2} callback_mock = mock.Mock() callback = asyncio.coroutine(callback_mock) - signal.connect(callback) + signal.append(callback) self.loop.run_until_complete(signal.send(**kwargs)) callback_mock.assert_called_once_with(**kwargs) @@ -87,7 +87,7 @@ def test_response_prepare(self): callback = mock.Mock() app = Application(loop=self.loop) - app.on_response_start.connect(asyncio.coroutine(callback)) + app.on_response_prepare.append(asyncio.coroutine(callback)) request = self.make_request('GET', '/', app=app) response = Response(body=b'') diff --git a/tests/test_web_exceptions.py b/tests/test_web_exceptions.py index 8699d5b45be..eb2b7e93514 100644 --- a/tests/test_web_exceptions.py +++ b/tests/test_web_exceptions.py @@ -32,7 +32,7 @@ def append(self, data): def make_request(self, method='GET', path='/', headers=CIMultiDict()): self.app = mock.Mock() - self.app.on_response_start = signals.Signal({'request', 'response'}) + self.app.on_response_prepare = signals.Signal(parameters={'request', 'response'}) message = RawRequestMessage(method, path, HttpVersion11, headers, False, False) req = Request(self.app, message, self.payload, diff --git a/tests/test_web_request.py b/tests/test_web_request.py index 378b8c0bf90..f0307b02fbd 100644 --- a/tests/test_web_request.py +++ b/tests/test_web_request.py @@ -24,7 +24,7 @@ def make_request(self, method, path, headers=CIMultiDict(), *, if version < HttpVersion(1, 1): closing = True self.app = mock.Mock() - self.app.on_response_start = Signal({'request', 'response'}) + self.app.on_response_prepare = Signal(parameters={'request', 'response'}) message = RawRequestMessage(method, path, version, headers, closing, False) self.payload = mock.Mock() diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 72089490516..203da442497 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -26,7 +26,7 @@ def make_request(self, method, path, headers=CIMultiDict(), def request_from_message(self, message): self.app = mock.Mock() - self.app.on_response_start = signals.Signal({'request', 'response'}) + self.app.on_response_prepare = signals.Signal(parameters={'request', 'response'}) self.payload = mock.Mock() self.transport = mock.Mock() self.reader = mock.Mock() @@ -527,7 +527,7 @@ def tearDown(self): def make_request(self, method, path, headers=CIMultiDict()): self.app = mock.Mock() - self.app.on_response_start = signals.Signal({'request', 'response'}) + self.app.on_response_prepare = signals.Signal(parameters={'request', 'response'}) message = RawRequestMessage(method, path, HttpVersion11, headers, False, False) self.payload = mock.Mock() diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index eec5d5fede1..4c5ec1d8e0b 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -35,7 +35,7 @@ def make_request(self, method, path, headers=None): self.reader = mock.Mock() self.writer = mock.Mock() self.app.loop = self.loop - self.app.on_response_start = signals.Signal({'request', 'response'}) + self.app.on_response_prepare = signals.Signal(parameters={'request', 'response'}) req = Request(self.app, message, self.payload, self.transport, self.reader, self.writer) return req From f5b98ac5a313d325782b64dbefdbf660062d13ed Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Mon, 28 Sep 2015 22:09:53 +0100 Subject: [PATCH 13/14] Drop requirement for signal receivers to be coroutines (but they still can) --- aiohttp/signals.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/aiohttp/signals.py b/aiohttp/signals.py index 403601dbe97..b446bbb5c00 100644 --- a/aiohttp/signals.py +++ b/aiohttp/signals.py @@ -21,11 +21,6 @@ def __init__(self, *args, parameters=None): def _check_signature(self, receiver): if self._parameters is not None: - func = receiver - while isinstance(func, functools.partial): - func = func.func - if not asyncio.iscoroutinefunction(func): - raise TypeError("{} is not a coroutine function".format(receiver)) signature(receiver).bind(**{p: None for p in self._parameters}) return True @@ -62,5 +57,6 @@ def send(self, **kwargs): Sends data to all registered receivers. """ for receiver in self: - yield from receiver(**kwargs) - + res = receiver(**kwargs) + if asyncio.iscoroutine(res) or isinstance(res, asyncio.Future): + yield from res From 322f650cb39affe40a4fe528e6a791dfd069240e Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Mon, 28 Sep 2015 22:10:54 +0100 Subject: [PATCH 14/14] Fix variable name in signature check call --- aiohttp/signals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiohttp/signals.py b/aiohttp/signals.py index b446bbb5c00..24042aff7b4 100644 --- a/aiohttp/signals.py +++ b/aiohttp/signals.py @@ -33,7 +33,7 @@ def __iadd__(self, other): def __setitem__(self, key, value): if isinstance(key, slice): value = list(value) - assert all(map(self._check_signature, other)) + assert all(map(self._check_signature, value)) else: assert self._check_signature(value) super().__setitem__(key, value)