Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Signals implementation #439

Merged
merged 16 commits into from
Oct 16, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions aiohttp/signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import abc
import functools
from inspect import signature

import asyncio

class Signal(list):
"""
Coroutine-based signal implementation

To connect a callback to a signal, use any list method. If wish to pass
additional arguments to your callback, use :meth:`functools.partial`.

Signals are fired using the :meth:`send` coroutine, which takes named
arguments.
"""
def __init__(self, *args, parameters=None):
self._parameters = parameters
if args:
self.extend(args[0])

def _check_signature(self, receiver):
if self._parameters is not None:
signature(receiver).bind(**{p: None for p in self._parameters})
return True

# Only override these methods to check signatures if not optimised.
if __debug__:
def __iadd__(self, other):
assert all(map(self._check_signature, other))
super().__iadd__(other)

def __setitem__(self, key, value):
if isinstance(key, slice):
value = list(value)
assert all(map(self._check_signature, value))
else:
assert self._check_signature(value)
super().__setitem__(key, value)

def insert(self, index, obj):
assert self._check_signature(obj)
super().insert(index, obj)

def append(self, obj):
assert self._check_signature(obj)
super().append(obj)

def extend(self, other):
other = list(other)
assert all(map(self._check_signature, other))
super().extend(other)

@asyncio.coroutine
def send(self, **kwargs):
"""
Sends data to all registered receivers.
"""
for receiver in self:
res = receiver(**kwargs)
if asyncio.iscoroutine(res) or isinstance(res, asyncio.Future):
yield from res
3 changes: 3 additions & 0 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .web_urldispatcher import * # noqa
from .web_ws import * # noqa
from .protocol import HttpVersion # noqa
from .signals import Signal


import asyncio
Expand Down Expand Up @@ -196,6 +197,8 @@ def __init__(self, *, logger=web_logger, loop=None,
assert asyncio.iscoroutinefunction(factory), factory
self._middlewares = list(middlewares)

self.on_response_prepare = Signal(parameters={'request', 'response'})

@property
def router(self):
return self._router
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/web_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,8 @@ def prepare(self, request):
resp_impl = self._start_pre_check(request)
if resp_impl is not None:
return resp_impl
yield from request.app.on_response_prepare.send(request=request,
response=self)

return self._start(request)

Expand Down
8 changes: 8 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ aiohttp.protocol module
:undoc-members:
:show-inheritance:

aiohttp.signals module
----------------------

.. automodule:: aiohttp.signals
:members:
:undoc-members:
:show-inheritance:

aiohttp.streams module
----------------------

Expand Down
10 changes: 9 additions & 1 deletion docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,8 @@ StreamResponse
response answers.

Send *HTTP header*. You should not change any header data after
calling this method.
calling this method, except through
:attr:`Application.on_response_start` signal callbacks.

.. deprecated:: 0.18

Expand Down Expand Up @@ -922,6 +923,13 @@ arbitrary properties for later access from

:ref:`event loop<asyncio-event-loop>` used for processing HTTP requests.

.. attribute:: on_response_start

A :class:`~aiohttp.signals.FunctionSignal` that is fired at the beginning
of :meth:`StreamResponse.start` with parameters ``request`` and
``response``. It can be used, for example, to add custom headers to each
response before sending.

.. method:: make_handler(**kwargs)

Creates HTTP protocol factory for handling requests.
Expand Down
98 changes: 98 additions & 0 deletions tests/test_signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import asyncio
import unittest
from unittest import mock
from aiohttp.multidict import CIMultiDict
from aiohttp.signals import Signal
from aiohttp.web import Application
from aiohttp.web import Request, StreamResponse, Response
from aiohttp.protocol import HttpVersion, HttpVersion11, HttpVersion10
from aiohttp.protocol import RawRequestMessage

class TestSignals(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()

def make_request(self, method, path, headers=CIMultiDict(), app=None):
message = RawRequestMessage(method, path, HttpVersion11, headers,
False, False)
return self.request_from_message(message, app)

def request_from_message(self, message, app=None):
self.app = app if app is not None else mock.Mock()
self.payload = mock.Mock()
self.transport = mock.Mock()
self.reader = mock.Mock()
self.writer = mock.Mock()
req = Request(self.app, message, self.payload,
self.transport, self.reader, self.writer)
return req

def test_callback_valid(self):
signal = Signal(parameters={'foo', 'bar'})

# All these are suitable
good_callbacks = map(asyncio.coroutine, [
(lambda foo, bar: None),
(lambda *, foo, bar: None),
(lambda foo, bar, **kwargs: None),
(lambda foo, bar, baz=None: None),
(lambda baz=None, *, foo, bar: None),
(lambda foo=None, bar=None: None),
(lambda foo, bar=None, *, baz=None: None),
(lambda **kwargs: None),
])
for callback in good_callbacks:
signal.append(callback)

def test_callback_invalid(self):
signal = Signal(parameters={'foo', 'bar'})

# All these are unsuitable
bad_callbacks = map(asyncio.coroutine, [
(lambda foo: None),
(lambda foo, bar, baz: None),
])
for callback in bad_callbacks:
with self.assertRaises(TypeError):
signal.send(callback)

def test_add_response_prepare_signal_handler(self):
callback = asyncio.coroutine(lambda request, response: None)
app = Application(loop=self.loop)
app.on_response_prepare.append(callback)

def test_add_signal_handler_not_a_callable(self):
callback = True
app = Application(loop=self.loop)
with self.assertRaises(TypeError):
app.on_response_prepare.append(callback)

def test_function_signal_dispatch(self):
signal = Signal(parameters={'foo', 'bar'})
kwargs = {'foo': 1, 'bar': 2}

callback_mock = mock.Mock()
callback = asyncio.coroutine(callback_mock)

signal.append(callback)

self.loop.run_until_complete(signal.send(**kwargs))
callback_mock.assert_called_once_with(**kwargs)

def test_response_prepare(self):
callback = mock.Mock()

app = Application(loop=self.loop)
app.on_response_prepare.append(asyncio.coroutine(callback))

request = self.make_request('GET', '/', app=app)
response = Response(body=b'')
self.loop.run_until_complete(response.prepare(request))

callback.assert_called_once_with(request=request,
response=response)

3 changes: 2 additions & 1 deletion tests/test_web_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from aiohttp.web import Request
from aiohttp.protocol import RawRequestMessage, HttpVersion11

from aiohttp import web
from aiohttp import signals, web


class TestHTTPExceptions(unittest.TestCase):
Expand All @@ -32,6 +32,7 @@ def append(self, data):

def make_request(self, method='GET', path='/', headers=CIMultiDict()):
self.app = mock.Mock()
self.app.on_response_prepare = signals.Signal(parameters={'request', 'response'})
message = RawRequestMessage(method, path, HttpVersion11, headers,
False, False)
req = Request(self.app, message, self.payload,
Expand Down
2 changes: 2 additions & 0 deletions tests/test_web_request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import unittest
from unittest import mock
from aiohttp.signals import Signal
from aiohttp.web import Request
from aiohttp.multidict import MultiDict, CIMultiDict
from aiohttp.protocol import HttpVersion
Expand All @@ -23,6 +24,7 @@ def make_request(self, method, path, headers=CIMultiDict(), *,
if version < HttpVersion(1, 1):
closing = True
self.app = mock.Mock()
self.app.on_response_prepare = Signal(parameters={'request', 'response'})
message = RawRequestMessage(method, path, version, headers, closing,
False)
self.payload = mock.Mock()
Expand Down
4 changes: 3 additions & 1 deletion tests/test_web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime
import unittest
from unittest import mock
from aiohttp import hdrs
from aiohttp import hdrs, signals
from aiohttp.multidict import CIMultiDict
from aiohttp.web import ContentCoding, Request, StreamResponse, Response
from aiohttp.protocol import HttpVersion, HttpVersion11, HttpVersion10
Expand All @@ -26,6 +26,7 @@ def make_request(self, method, path, headers=CIMultiDict(),

def request_from_message(self, message):
self.app = mock.Mock()
self.app.on_response_prepare = signals.Signal(parameters={'request', 'response'})
self.payload = mock.Mock()
self.transport = mock.Mock()
self.reader = mock.Mock()
Expand Down Expand Up @@ -526,6 +527,7 @@ def tearDown(self):

def make_request(self, method, path, headers=CIMultiDict()):
self.app = mock.Mock()
self.app.on_response_prepare = signals.Signal(parameters={'request', 'response'})
message = RawRequestMessage(method, path, HttpVersion11, headers,
False, False)
self.payload = mock.Mock()
Expand Down
3 changes: 2 additions & 1 deletion tests/test_web_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aiohttp.web import (
MsgType, Request, WebSocketResponse, HTTPMethodNotAllowed, HTTPBadRequest)
from aiohttp.protocol import RawRequestMessage, HttpVersion11
from aiohttp import errors, websocket
from aiohttp import errors, signals, websocket


class TestWebWebSocket(unittest.TestCase):
Expand Down Expand Up @@ -37,6 +37,7 @@ def make_request(self, method, path, headers=None, protocols=False):
self.reader = mock.Mock()
self.writer = mock.Mock()
self.app.loop = self.loop
self.app.on_response_prepare = signals.Signal(parameters={'request', 'response'})
req = Request(self.app, message, self.payload,
self.transport, self.reader, self.writer)
return req
Expand Down