diff --git a/Makefile b/Makefile index fd6335ad93b..3baf907fb9d 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ vtest: flake develop cov cover coverage: tox -cov-dev: flake develop +cov-dev: develop @coverage erase @coverage run -m pytest -s tests @mv .coverage .coverage.accel diff --git a/aiohttp/signals.py b/aiohttp/signals.py new file mode 100644 index 00000000000..5093bb8498c --- /dev/null +++ b/aiohttp/signals.py @@ -0,0 +1,71 @@ +import asyncio +from itertools import count + + +class BaseSignal(list): + + @asyncio.coroutine + def _send(self, *args, **kwargs): + for receiver in self: + res = receiver(*args, **kwargs) + if asyncio.iscoroutine(res) or isinstance(res, asyncio.Future): + yield from res + + def copy(self): + raise NotImplementedError("copy() is forbidden") + + def sort(self): + raise NotImplementedError("sort() is forbidden") + + +class Signal(BaseSignal): + """Coroutine-based signal implementation. + + To connect a callback to a signal, use any list method. + + Signals are fired using the :meth:`send` coroutine, which takes named + arguments. + """ + + def __init__(self, app): + super().__init__() + self._app = app + klass = self.__class__ + self._name = klass.__module__ + ':' + klass.__qualname__ + self._pre = app.on_pre_signal + self._post = app.on_post_signal + + @asyncio.coroutine + def send(self, *args, **kwargs): + """ + Sends data to all registered receivers. + """ + ordinal = None + debug = self._app._debug + if debug: + ordinal = self._pre.ordinal() + yield from self._pre.send(ordinal, self._name, *args, **kwargs) + yield from self._send(*args, **kwargs) + if debug: + yield from self._post.send(ordinal, self._name, *args, **kwargs) + + +class DebugSignal(BaseSignal): + + @asyncio.coroutine + def send(self, ordinal, name, *args, **kwargs): + yield from self._send(ordinal, name, *args, **kwargs) + + +class PreSignal(DebugSignal): + + def __init__(self): + super().__init__() + self._counter = count(1) + + def ordinal(self): + return next(self._counter) + + +class PostSignal(DebugSignal): + pass diff --git a/aiohttp/web.py b/aiohttp/web.py index 706fbb04ed6..973ec3d322a 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -7,6 +7,7 @@ from .web_urldispatcher import * # noqa from .web_ws import * # noqa from .protocol import HttpVersion # noqa +from .signals import Signal, PreSignal, PostSignal import asyncio @@ -179,13 +180,14 @@ class Application(dict): def __init__(self, *, logger=web_logger, loop=None, router=None, handler_factory=RequestHandlerFactory, - middlewares=()): + middlewares=(), debug=False): if loop is None: loop = asyncio.get_event_loop() if router is None: router = UrlDispatcher() assert isinstance(router, AbstractRouter), router + self._debug = debug self._router = router self._handler_factory = handler_factory self._finish_callbacks = [] @@ -196,6 +198,26 @@ def __init__(self, *, logger=web_logger, loop=None, assert asyncio.iscoroutinefunction(factory), factory self._middlewares = list(middlewares) + self._on_pre_signal = PreSignal() + self._on_post_signal = PostSignal() + self._on_response_prepare = Signal(self) + + @property + def debug(self): + return self._debug + + @property + def on_response_prepare(self): + return self._on_response_prepare + + @property + def on_pre_signal(self): + return self._on_pre_signal + + @property + def on_post_signal(self): + return self._on_post_signal + @property def router(self): return self._router diff --git a/aiohttp/web_reqrep.py b/aiohttp/web_reqrep.py index 3c07ef89621..9794658f5f9 100644 --- a/aiohttp/web_reqrep.py +++ b/aiohttp/web_reqrep.py @@ -646,6 +646,8 @@ def prepare(self, request): resp_impl = self._start_pre_check(request) if resp_impl is not None: return resp_impl + yield from request.app.on_response_prepare.send(request=request, + response=self) return self._start(request) diff --git a/docs/api.rst b/docs/api.rst index 04656ec965a..4dadf958f1d 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -49,6 +49,14 @@ aiohttp.protocol module :undoc-members: :show-inheritance: +aiohttp.signals module +---------------------- + +.. automodule:: aiohttp.signals + :members: + :undoc-members: + :show-inheritance: + aiohttp.streams module ---------------------- diff --git a/docs/web_reference.rst b/docs/web_reference.rst index c9369717b6c..41d8a7443f5 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -560,6 +560,10 @@ StreamResponse Use :meth:`prepare` instead. + .. warning:: The method doesn't call + :attr:`web.Application.on_response_prepare` signal, use + :meth:`prepare` instead. + .. coroutinemethod:: prepare(request) :param aiohttp.web.Request request: HTTP request object, that the @@ -568,6 +572,9 @@ StreamResponse Send *HTTP header*. You should not change any header data after calling this method. + The coroutine calls :attr:`web.Application.on_response_prepare` + signal handlers. + .. versionadded:: 0.18 .. method:: write(data) @@ -920,6 +927,13 @@ arbitrary properties for later access from :ref:`event loop` used for processing HTTP requests. + .. attribute:: on_response_prepare + + A :class:`~aiohttp.signals.Signal` that is fired at the beginning + of :meth:`StreamResponse.prepare` with parameters *request* and + *response*. It can be used, for example, to add custom headers to each + response before sending. + .. method:: make_handler(**kwargs) Creates HTTP protocol factory for handling requests. diff --git a/tests/test_signals.py b/tests/test_signals.py new file mode 100644 index 00000000000..ac8cad65b6b --- /dev/null +++ b/tests/test_signals.py @@ -0,0 +1,145 @@ +import asyncio +from unittest import mock +from aiohttp.multidict import CIMultiDict +from aiohttp.signals import Signal +from aiohttp.web import Application +from aiohttp.web import Request, Response +from aiohttp.protocol import HttpVersion11 +from aiohttp.protocol import RawRequestMessage + +import pytest + + +@pytest.fixture +def app(loop): + return Application(loop=loop) + + +@pytest.fixture +def debug_app(loop): + return Application(loop=loop, debug=True) + + +def make_request(app, method, path, headers=CIMultiDict()): + message = RawRequestMessage(method, path, HttpVersion11, headers, + False, False) + return request_from_message(message, app) + + +def request_from_message(message, app): + payload = mock.Mock() + transport = mock.Mock() + reader = mock.Mock() + writer = mock.Mock() + req = Request(app, message, payload, + transport, reader, writer) + return req + + +def test_add_response_prepare_signal_handler(loop, app): + callback = asyncio.coroutine(lambda request, response: None) + app.on_response_prepare.append(callback) + + +def test_add_signal_handler_not_a_callable(loop, app): + callback = True + app.on_response_prepare.append(callback) + with pytest.raises(TypeError): + app.on_response_prepare(None, None) + + +def test_function_signal_dispatch(loop, app): + signal = Signal(app) + kwargs = {'foo': 1, 'bar': 2} + + callback_mock = mock.Mock() + + @asyncio.coroutine + def callback(**kwargs): + callback_mock(**kwargs) + + signal.append(callback) + + loop.run_until_complete(signal.send(**kwargs)) + callback_mock.assert_called_once_with(**kwargs) + + +def test_function_signal_dispatch2(loop, app): + signal = Signal(app) + args = {'a', 'b'} + kwargs = {'foo': 1, 'bar': 2} + + callback_mock = mock.Mock() + + @asyncio.coroutine + def callback(*args, **kwargs): + callback_mock(*args, **kwargs) + + signal.append(callback) + + loop.run_until_complete(signal.send(*args, **kwargs)) + callback_mock.assert_called_once_with(*args, **kwargs) + + +def test_response_prepare(loop, app): + callback = mock.Mock() + + @asyncio.coroutine + def cb(*args, **kwargs): + callback(*args, **kwargs) + + app.on_response_prepare.append(cb) + + request = make_request(app, 'GET', '/') + response = Response(body=b'') + loop.run_until_complete(response.prepare(request)) + + callback.assert_called_once_with(request=request, + response=response) + + +def test_non_coroutine(loop, app): + signal = Signal(app) + kwargs = {'foo': 1, 'bar': 2} + + callback = mock.Mock() + + signal.append(callback) + + loop.run_until_complete(signal.send(**kwargs)) + callback.assert_called_once_with(**kwargs) + + +def test_copy_forbidden(app): + signal = Signal(app) + with pytest.raises(NotImplementedError): + signal.copy() + + +def test_sort_forbidden(app): + l1 = lambda: None + l2 = lambda: None + l3 = lambda: None + signal = Signal(app) + signal.extend([l1, l2, l3]) + with pytest.raises(NotImplementedError): + signal.sort() + assert signal == [l1, l2, l3] + + +def test_debug_signal(loop, debug_app): + assert debug_app.debug, "Should be True" + signal = Signal(debug_app) + + callback = mock.Mock() + pre = mock.Mock() + post = mock.Mock() + + signal.append(callback) + debug_app.on_pre_signal.append(pre) + debug_app.on_post_signal.append(post) + + loop.run_until_complete(signal.send(1, a=2)) + callback.assert_called_once_with(1, a=2) + pre.assert_called_once_with(1, 'aiohttp.signals:Signal', 1, a=2) + post.assert_called_once_with(1, 'aiohttp.signals:Signal', 1, a=2) diff --git a/tests/test_web_exceptions.py b/tests/test_web_exceptions.py index 611562619fd..2107ee1714b 100644 --- a/tests/test_web_exceptions.py +++ b/tests/test_web_exceptions.py @@ -6,7 +6,7 @@ from aiohttp.web import Request from aiohttp.protocol import RawRequestMessage, HttpVersion11 -from aiohttp import web +from aiohttp import signals, web class TestHTTPExceptions(unittest.TestCase): @@ -32,6 +32,8 @@ def append(self, data): def make_request(self, method='GET', path='/', headers=CIMultiDict()): self.app = mock.Mock() + self.app._debug = False + self.app.on_response_prepare = signals.Signal(self.app) message = RawRequestMessage(method, path, HttpVersion11, headers, False, False) req = Request(self.app, message, self.payload, diff --git a/tests/test_web_request.py b/tests/test_web_request.py index 60c77bf3af5..40cf6730c1e 100644 --- a/tests/test_web_request.py +++ b/tests/test_web_request.py @@ -1,6 +1,7 @@ import asyncio import unittest from unittest import mock +from aiohttp.signals import Signal from aiohttp.web import Request from aiohttp.multidict import MultiDict, CIMultiDict from aiohttp.protocol import HttpVersion @@ -23,6 +24,8 @@ def make_request(self, method, path, headers=CIMultiDict(), *, if version < HttpVersion(1, 1): closing = True self.app = mock.Mock() + self.app._debug = False + self.app.on_response_prepare = Signal(self.app) message = RawRequestMessage(method, path, version, headers, closing, False) self.payload = mock.Mock() diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 11fb5454871..c50f6d224de 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -2,7 +2,7 @@ import datetime import unittest from unittest import mock -from aiohttp import hdrs +from aiohttp import hdrs, signals from aiohttp.multidict import CIMultiDict from aiohttp.web import ContentCoding, Request, StreamResponse, Response from aiohttp.protocol import HttpVersion, HttpVersion11, HttpVersion10 @@ -26,6 +26,8 @@ def make_request(self, method, path, headers=CIMultiDict(), def request_from_message(self, message): self.app = mock.Mock() + self.app._debug = False + self.app.on_response_prepare = signals.Signal(self.app) self.payload = mock.Mock() self.transport = mock.Mock() self.reader = mock.Mock() @@ -514,6 +516,16 @@ def test_start_twice(self, ResponseImpl): impl2 = resp.start(req) self.assertIs(impl1, impl2) + def test_prepare_calls_signal(self): + req = self.make_request('GET', '/') + resp = StreamResponse() + + sig = mock.Mock() + self.app.on_response_prepare.append(sig) + self.loop.run_until_complete(resp.prepare(req)) + + sig.assert_called_with(request=req, response=resp) + class TestResponse(unittest.TestCase): @@ -526,6 +538,8 @@ def tearDown(self): def make_request(self, method, path, headers=CIMultiDict()): self.app = mock.Mock() + self.app._debug = False + self.app.on_response_prepare = signals.Signal(self.app) message = RawRequestMessage(method, path, HttpVersion11, headers, False, False) self.payload = mock.Mock() diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 88a8557e7fa..6af33546341 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -5,7 +5,7 @@ from aiohttp.web import ( MsgType, Request, WebSocketResponse, HTTPMethodNotAllowed, HTTPBadRequest) from aiohttp.protocol import RawRequestMessage, HttpVersion11 -from aiohttp import errors, websocket +from aiohttp import errors, signals, websocket class TestWebWebSocket(unittest.TestCase): @@ -19,6 +19,7 @@ def tearDown(self): def make_request(self, method, path, headers=None, protocols=False): self.app = mock.Mock() + self.app._debug = False if headers is None: headers = CIMultiDict( {'HOST': 'server.example.com', @@ -37,6 +38,7 @@ def make_request(self, method, path, headers=None, protocols=False): self.reader = mock.Mock() self.writer = mock.Mock() self.app.loop = self.loop + self.app.on_response_prepare = signals.Signal(self.app) req = Request(self.app, message, self.payload, self.transport, self.reader, self.writer) return req