diff --git a/aiohttp/signals.py b/aiohttp/signals.py new file mode 100644 index 00000000000..24042aff7b4 --- /dev/null +++ b/aiohttp/signals.py @@ -0,0 +1,62 @@ +import abc +import functools +from inspect import signature + +import asyncio + +class Signal(list): + """ + Coroutine-based signal implementation + + To connect a callback to a signal, use any list method. If wish to pass + additional arguments to your callback, use :meth:`functools.partial`. + + Signals are fired using the :meth:`send` coroutine, which takes named + arguments. + """ + def __init__(self, *args, parameters=None): + self._parameters = parameters + if args: + self.extend(args[0]) + + def _check_signature(self, receiver): + if self._parameters is not None: + signature(receiver).bind(**{p: None for p in self._parameters}) + return True + + # Only override these methods to check signatures if not optimised. + if __debug__: + def __iadd__(self, other): + assert all(map(self._check_signature, other)) + super().__iadd__(other) + + def __setitem__(self, key, value): + if isinstance(key, slice): + value = list(value) + assert all(map(self._check_signature, value)) + else: + assert self._check_signature(value) + super().__setitem__(key, value) + + def insert(self, index, obj): + assert self._check_signature(obj) + super().insert(index, obj) + + def append(self, obj): + assert self._check_signature(obj) + super().append(obj) + + def extend(self, other): + other = list(other) + assert all(map(self._check_signature, other)) + super().extend(other) + + @asyncio.coroutine + def send(self, **kwargs): + """ + Sends data to all registered receivers. + """ + for receiver in self: + res = receiver(**kwargs) + if asyncio.iscoroutine(res) or isinstance(res, asyncio.Future): + yield from res diff --git a/aiohttp/web.py b/aiohttp/web.py index 706fbb04ed6..ac5208826b3 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -7,6 +7,7 @@ from .web_urldispatcher import * # noqa from .web_ws import * # noqa from .protocol import HttpVersion # noqa +from .signals import Signal import asyncio @@ -196,6 +197,8 @@ def __init__(self, *, logger=web_logger, loop=None, assert asyncio.iscoroutinefunction(factory), factory self._middlewares = list(middlewares) + self.on_response_prepare = Signal(parameters={'request', 'response'}) + @property def router(self): return self._router diff --git a/aiohttp/web_reqrep.py b/aiohttp/web_reqrep.py index 3c07ef89621..9794658f5f9 100644 --- a/aiohttp/web_reqrep.py +++ b/aiohttp/web_reqrep.py @@ -646,6 +646,8 @@ def prepare(self, request): resp_impl = self._start_pre_check(request) if resp_impl is not None: return resp_impl + yield from request.app.on_response_prepare.send(request=request, + response=self) return self._start(request) diff --git a/docs/api.rst b/docs/api.rst index 04656ec965a..4dadf958f1d 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -49,6 +49,14 @@ aiohttp.protocol module :undoc-members: :show-inheritance: +aiohttp.signals module +---------------------- + +.. automodule:: aiohttp.signals + :members: + :undoc-members: + :show-inheritance: + aiohttp.streams module ---------------------- diff --git a/docs/web_reference.rst b/docs/web_reference.rst index ffad1ad6b96..bd7b56bdc87 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -555,7 +555,8 @@ StreamResponse response answers. Send *HTTP header*. You should not change any header data after - calling this method. + calling this method, except through + :attr:`Application.on_response_start` signal callbacks. .. deprecated:: 0.18 @@ -922,6 +923,13 @@ arbitrary properties for later access from :ref:`event loop` used for processing HTTP requests. + .. attribute:: on_response_start + + A :class:`~aiohttp.signals.FunctionSignal` that is fired at the beginning + of :meth:`StreamResponse.start` with parameters ``request`` and + ``response``. It can be used, for example, to add custom headers to each + response before sending. + .. method:: make_handler(**kwargs) Creates HTTP protocol factory for handling requests. diff --git a/tests/test_signals.py b/tests/test_signals.py new file mode 100644 index 00000000000..03d4dd9dfab --- /dev/null +++ b/tests/test_signals.py @@ -0,0 +1,98 @@ +import asyncio +import unittest +from unittest import mock +from aiohttp.multidict import CIMultiDict +from aiohttp.signals import Signal +from aiohttp.web import Application +from aiohttp.web import Request, StreamResponse, Response +from aiohttp.protocol import HttpVersion, HttpVersion11, HttpVersion10 +from aiohttp.protocol import RawRequestMessage + +class TestSignals(unittest.TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def make_request(self, method, path, headers=CIMultiDict(), app=None): + message = RawRequestMessage(method, path, HttpVersion11, headers, + False, False) + return self.request_from_message(message, app) + + def request_from_message(self, message, app=None): + self.app = app if app is not None else mock.Mock() + self.payload = mock.Mock() + self.transport = mock.Mock() + self.reader = mock.Mock() + self.writer = mock.Mock() + req = Request(self.app, message, self.payload, + self.transport, self.reader, self.writer) + return req + + def test_callback_valid(self): + signal = Signal(parameters={'foo', 'bar'}) + + # All these are suitable + good_callbacks = map(asyncio.coroutine, [ + (lambda foo, bar: None), + (lambda *, foo, bar: None), + (lambda foo, bar, **kwargs: None), + (lambda foo, bar, baz=None: None), + (lambda baz=None, *, foo, bar: None), + (lambda foo=None, bar=None: None), + (lambda foo, bar=None, *, baz=None: None), + (lambda **kwargs: None), + ]) + for callback in good_callbacks: + signal.append(callback) + + def test_callback_invalid(self): + signal = Signal(parameters={'foo', 'bar'}) + + # All these are unsuitable + bad_callbacks = map(asyncio.coroutine, [ + (lambda foo: None), + (lambda foo, bar, baz: None), + ]) + for callback in bad_callbacks: + with self.assertRaises(TypeError): + signal.send(callback) + + def test_add_response_prepare_signal_handler(self): + callback = asyncio.coroutine(lambda request, response: None) + app = Application(loop=self.loop) + app.on_response_prepare.append(callback) + + def test_add_signal_handler_not_a_callable(self): + callback = True + app = Application(loop=self.loop) + with self.assertRaises(TypeError): + app.on_response_prepare.append(callback) + + def test_function_signal_dispatch(self): + signal = Signal(parameters={'foo', 'bar'}) + kwargs = {'foo': 1, 'bar': 2} + + callback_mock = mock.Mock() + callback = asyncio.coroutine(callback_mock) + + signal.append(callback) + + self.loop.run_until_complete(signal.send(**kwargs)) + callback_mock.assert_called_once_with(**kwargs) + + def test_response_prepare(self): + callback = mock.Mock() + + app = Application(loop=self.loop) + app.on_response_prepare.append(asyncio.coroutine(callback)) + + request = self.make_request('GET', '/', app=app) + response = Response(body=b'') + self.loop.run_until_complete(response.prepare(request)) + + callback.assert_called_once_with(request=request, + response=response) + diff --git a/tests/test_web_exceptions.py b/tests/test_web_exceptions.py index 611562619fd..eb2b7e93514 100644 --- a/tests/test_web_exceptions.py +++ b/tests/test_web_exceptions.py @@ -6,7 +6,7 @@ from aiohttp.web import Request from aiohttp.protocol import RawRequestMessage, HttpVersion11 -from aiohttp import web +from aiohttp import signals, web class TestHTTPExceptions(unittest.TestCase): @@ -32,6 +32,7 @@ def append(self, data): def make_request(self, method='GET', path='/', headers=CIMultiDict()): self.app = mock.Mock() + self.app.on_response_prepare = signals.Signal(parameters={'request', 'response'}) message = RawRequestMessage(method, path, HttpVersion11, headers, False, False) req = Request(self.app, message, self.payload, diff --git a/tests/test_web_request.py b/tests/test_web_request.py index 60c77bf3af5..f0307b02fbd 100644 --- a/tests/test_web_request.py +++ b/tests/test_web_request.py @@ -1,6 +1,7 @@ import asyncio import unittest from unittest import mock +from aiohttp.signals import Signal from aiohttp.web import Request from aiohttp.multidict import MultiDict, CIMultiDict from aiohttp.protocol import HttpVersion @@ -23,6 +24,7 @@ def make_request(self, method, path, headers=CIMultiDict(), *, if version < HttpVersion(1, 1): closing = True self.app = mock.Mock() + self.app.on_response_prepare = Signal(parameters={'request', 'response'}) message = RawRequestMessage(method, path, version, headers, closing, False) self.payload = mock.Mock() diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 11fb5454871..203da442497 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -2,7 +2,7 @@ import datetime import unittest from unittest import mock -from aiohttp import hdrs +from aiohttp import hdrs, signals from aiohttp.multidict import CIMultiDict from aiohttp.web import ContentCoding, Request, StreamResponse, Response from aiohttp.protocol import HttpVersion, HttpVersion11, HttpVersion10 @@ -26,6 +26,7 @@ def make_request(self, method, path, headers=CIMultiDict(), def request_from_message(self, message): self.app = mock.Mock() + self.app.on_response_prepare = signals.Signal(parameters={'request', 'response'}) self.payload = mock.Mock() self.transport = mock.Mock() self.reader = mock.Mock() @@ -526,6 +527,7 @@ def tearDown(self): def make_request(self, method, path, headers=CIMultiDict()): self.app = mock.Mock() + self.app.on_response_prepare = signals.Signal(parameters={'request', 'response'}) message = RawRequestMessage(method, path, HttpVersion11, headers, False, False) self.payload = mock.Mock() diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 88a8557e7fa..800ef949ae6 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -5,7 +5,7 @@ from aiohttp.web import ( MsgType, Request, WebSocketResponse, HTTPMethodNotAllowed, HTTPBadRequest) from aiohttp.protocol import RawRequestMessage, HttpVersion11 -from aiohttp import errors, websocket +from aiohttp import errors, signals, websocket class TestWebWebSocket(unittest.TestCase): @@ -37,6 +37,7 @@ def make_request(self, method, path, headers=None, protocols=False): self.reader = mock.Mock() self.writer = mock.Mock() self.app.loop = self.loop + self.app.on_response_prepare = signals.Signal(parameters={'request', 'response'}) req = Request(self.app, message, self.payload, self.transport, self.reader, self.writer) return req