diff --git a/aiohttp/abc.py b/aiohttp/abc.py index fb426f14421..cce10724e09 100644 --- a/aiohttp/abc.py +++ b/aiohttp/abc.py @@ -14,10 +14,9 @@ def __init__(self): def post_init(self, app): """Post init stage. - It's not an abstract method for sake of backward compatibility - but if router wans to be aware about application it should - override it. - + Not an abstract method for sake of backward compatibility, + but if the router wants to be aware of the application + it can override this. """ @property diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index 1e9bb4e9e2a..a6f65f49768 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -717,6 +717,7 @@ def __init__(self): super().__init__() self._resources = [] self._named_resources = {} + self._default_allow_head = False @asyncio.coroutine def resolve(self, request): @@ -757,6 +758,9 @@ def routes(self): def named_resources(self): return MappingProxyType(self._named_resources) + def set_defaults(self, *, allow_head): + self._default_allow_head = allow_head + def register_resource(self, resource): assert isinstance(resource, AbstractResource), \ 'Instance of AbstractResource class is required, got {!r}'.format( @@ -857,11 +861,16 @@ def add_head(self, *args, **kwargs): """ return self.add_route(hdrs.METH_HEAD, *args, **kwargs) - def add_get(self, *args, **kwargs): + def add_get(self, *args, allow_head=None, name=None, **kwargs): """ - Shortcut for add_route with method GET + Shortcut for add_route with method GET, if allow_head is true another + route is added allowing head requests to the same endpoint """ - return self.add_route(hdrs.METH_GET, *args, **kwargs) + if allow_head or (allow_head is None and self._default_allow_head): + # the head route can't have "name" set or it would conflict with + # the GET route below + self.add_route(hdrs.METH_HEAD, *args, **kwargs) + return self.add_route(hdrs.METH_GET, *args, name=name, **kwargs) def add_post(self, *args, **kwargs): """ diff --git a/tests/test_web_urldispatcher.py b/tests/test_web_urldispatcher.py index 3cd7ecd10e0..4833749bec4 100644 --- a/tests/test_web_urldispatcher.py +++ b/tests/test_web_urldispatcher.py @@ -310,3 +310,63 @@ def resolve(self, request): resp = yield from client.get('/') assert resp.status == 412 + + +@asyncio.coroutine +def test_allow_head(loop, test_client): + """ + Test allow_head on routes. + """ + app = web.Application(loop=loop) + + def handler(_): + return web.Response() + app.router.add_get('/a', handler, allow_head=True, name='a') + app.router.add_get('/b', handler, allow_head=False, name='b') + client = yield from test_client(app) + + r = yield from client.get('/a') + assert r.status == 200 + yield from r.release() + + r = yield from client.head('/a') + assert r.status == 200 + yield from r.release() + + r = yield from client.get('/b') + assert r.status == 200 + yield from r.release() + + r = yield from client.head('/b') + assert r.status == 405 + yield from r.release() + + +@pytest.mark.parametrize('dft_allow_head,allow_head,status', [ + (True, False, 405), + (True, True, 200), + (True, None, 200), + (False, False, 405), + (False, True, 200), + (False, None, 405), +]) +@asyncio.coroutine +def test_allow_head(loop, test_client, dft_allow_head, allow_head, status): + """ + Test allow_head on routes. + """ + app = web.Application(loop=loop) + + def handler(_): + return web.Response() + app.router.set_defaults(allow_head=dft_allow_head) + app.router.add_get('/', handler, allow_head=allow_head) + client = yield from test_client(app) + + r = yield from client.get('/') + assert r.status == 200 + yield from r.release() + + r = yield from client.head('/') + assert r.status == status + yield from r.release()