Skip to content

Commit

Permalink
Implement CorsViewMixin (#145)
Browse files Browse the repository at this point in the history
* Implement CorsViewMixin

* fix real browser test

* Add more tests to CorsViewMixin

* Make python 3.4 compatible

* Extract preflight handler to a class
  • Loading branch information
pedrokiefer authored and asvetlov committed Dec 21, 2017
1 parent 566c48e commit 9e0c757
Show file tree
Hide file tree
Showing 13 changed files with 746 additions and 168 deletions.
42 changes: 41 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,46 @@ in the router:
for route in list(app.router.routes()):
cors.add(route)
You can also use ``CorsViewMixin`` on ``web.View``:

.. code-block:: python
class CorsView(web.View, CorsViewMixin):
cors_config = {
"*": ResourceOption(
allow_credentials=True,
allow_headers="X-Request-ID",
)
}
@asyncio.coroutine
def get(self):
return web.Response(text="Done")
@custom_cors({
"*": ResourceOption(
allow_credentials=True,
allow_headers="*",
)
})
@asyncio.coroutine
def post(self):
return web.Response(text="Done")
cors = aiohttp_cors.setup(app, defaults={
"*": aiohttp_cors.ResourceOptions(
allow_credentials=True,
expose_headers="*",
allow_headers="*",
)
})
cors.add(
app.router.add_route("*", "/resource", CorsView),
webview=True)
Security
========

Expand Down Expand Up @@ -460,7 +500,7 @@ Post release steps:
Bugs
====

Please report bugs, issues, feature requests, etc. on
Please report bugs, issues, feature requests, etc. on
`GitHub <https://github.com/aio-libs/aiohttp_cors/issues>`__.


Expand Down
3 changes: 2 additions & 1 deletion aiohttp_cors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@
)
from .resource_options import ResourceOptions
from .cors_config import CorsConfig
from .mixin import CorsViewMixin, custom_cors

__all__ = (
"__title__", "__version__", "__author__", "__email__", "__summary__",
"__uri__", "__license__", "__copyright__",
"setup", "CorsConfig", "ResourceOptions",
"setup", "CorsConfig", "ResourceOptions", "CorsViewMixin", "custom_cors"
)


Expand Down
5 changes: 4 additions & 1 deletion aiohttp_cors/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ class AbstractRouterAdapter(metaclass=ABCMeta):
"""

@abstractmethod
def add_preflight_handler(self, routing_entity, handler):
def add_preflight_handler(self,
routing_entity,
handler,
webview: bool=False):
"""Add OPTIONS handler for all routes defined by `routing_entity`.
Does nothing if CORS handler already handles routing entity.
Expand Down
147 changes: 17 additions & 130 deletions aiohttp_cors/cors_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
from .urldispatcher_router_adapter import ResourcesUrlDispatcherRouterAdapter
from .abc import AbstractRouterAdapter
from .resource_options import ResourceOptions
from .preflight_handler import _PreflightHandler

__all__ = (
"CorsConfig",
)


# Positive response to Access-Control-Allow-Credentials
_TRUE = "true"
# CORS simple response headers:
Expand Down Expand Up @@ -103,7 +103,7 @@ def _parse_config_options(
_ConfigType = Mapping[str, Union[ResourceOptions, Mapping[str, Any]]]


class _CorsConfigImpl:
class _CorsConfigImpl(_PreflightHandler):

def __init__(self,
app: web.Application,
Expand All @@ -118,7 +118,8 @@ def __init__(self,

def add(self,
routing_entity,
config: _ConfigType=None):
config: _ConfigType=None,
webview: bool=False):
"""Enable CORS for specific route or resource.
If route is passed CORS is enabled for route's resource.
Expand All @@ -133,9 +134,9 @@ def add(self,
parsed_config = _parse_config_options(config)

self._router_adapter.add_preflight_handler(
routing_entity, self._preflight_handler)
routing_entity, self._preflight_handler, webview=webview)
self._router_adapter.set_config_for_routing_entity(
routing_entity, parsed_config)
routing_entity, parsed_config, webview=webview)

return routing_entity

Expand Down Expand Up @@ -196,127 +197,12 @@ def _on_response_prepare(self,
# Set allowed credentials.
response.headers[hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS] = _TRUE

@staticmethod
def _parse_request_method(request: web.Request):
"""Parse Access-Control-Request-Method header of the preflight request
"""
method = request.headers.get(hdrs.ACCESS_CONTROL_REQUEST_METHOD)
if method is None:
raise web.HTTPForbidden(
text="CORS preflight request failed: "
"'Access-Control-Request-Method' header is not specified")

# FIXME: validate method string (ABNF: method = token), if parsing
# fails, raise HTTPForbidden.

return method

@staticmethod
def _parse_request_headers(request: web.Request):
"""Parse Access-Control-Request-Headers header or the preflight request
Returns set of headers in upper case.
"""
headers = request.headers.get(hdrs.ACCESS_CONTROL_REQUEST_HEADERS)
if headers is None:
return frozenset()

# FIXME: validate each header string, if parsing fails, raise
# HTTPForbidden.
# FIXME: check, that headers split and stripped correctly (according
# to ABNF).
headers = (h.strip(" \t").upper() for h in headers.split(","))
# pylint: disable=bad-builtin
return frozenset(filter(None, headers))

@asyncio.coroutine
def _preflight_handler(self, request: web.Request):
"""CORS preflight request handler"""

# Handle according to part 6.2 of the CORS specification.

origin = request.headers.get(hdrs.ORIGIN)
if origin is None:
# Terminate CORS according to CORS 6.2.1.
raise web.HTTPForbidden(
text="CORS preflight request failed: "
"origin header is not specified in the request")

# CORS 6.2.3. Doing it out of order is not an error.
request_method = self._parse_request_method(request)

# CORS 6.2.5. Doing it out of order is not an error.

try:
config = \
yield from self._router_adapter.get_preflight_request_config(
request, origin, request_method)
except KeyError:
raise web.HTTPForbidden(
text="CORS preflight request failed: "
"request method {!r} is not allowed "
"for {!r} origin".format(request_method, origin))

if not config:
# No allowed origins for the route.
# Terminate CORS according to CORS 6.2.1.
raise web.HTTPForbidden(
text="CORS preflight request failed: "
"no origins are allowed")

options = config.get(origin, config.get("*"))
if options is None:
# No configuration for the origin - deny.
# Terminate CORS according to CORS 6.2.2.
raise web.HTTPForbidden(
text="CORS preflight request failed: "
"origin '{}' is not allowed".format(origin))

# CORS 6.2.4
request_headers = self._parse_request_headers(request)

# CORS 6.2.6
if options.allow_headers == "*":
pass
else:
disallowed_headers = request_headers - options.allow_headers
if disallowed_headers:
raise web.HTTPForbidden(
text="CORS preflight request failed: "
"headers are not allowed: {}".format(
", ".join(disallowed_headers)))

# Ok, CORS actual request with specified in the preflight request
# parameters is allowed.
# Set appropriate headers and return 200 response.

response = web.Response()

# CORS 6.2.7
response.headers[hdrs.ACCESS_CONTROL_ALLOW_ORIGIN] = origin
if options.allow_credentials:
# Set allowed credentials.
response.headers[hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS] = _TRUE

# CORS 6.2.8
if options.max_age is not None:
response.headers[hdrs.ACCESS_CONTROL_MAX_AGE] = \
str(options.max_age)

# CORS 6.2.9
# TODO: more optimal for client preflight request cache would be to
# respond with ALL allowed methods.
response.headers[hdrs.ACCESS_CONTROL_ALLOW_METHODS] = request_method

# CORS 6.2.10
if request_headers:
# Note: case of the headers in the request is changed, but this
# shouldn't be a problem, since the headers should be compared in
# the case-insensitive way.
response.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS] = \
",".join(request_headers)

return response
def _get_config(self, request, origin, request_method):
config = \
yield from self._router_adapter.get_preflight_request_config(
request, origin, request_method)
return config


class CorsConfig:
Expand All @@ -341,7 +227,7 @@ def __init__(self, app: web.Application, *,
Router adapter. Required if application uses non-default router.
"""

defaults = _parse_config_options(defaults)
self.defaults = _parse_config_options(defaults)

self._cors_impl = None

Expand All @@ -355,13 +241,13 @@ def __init__(self, app: web.Application, *,

elif isinstance(app.router, web.UrlDispatcher):
self._resources_router_adapter = \
ResourcesUrlDispatcherRouterAdapter(app.router, defaults)
ResourcesUrlDispatcherRouterAdapter(app.router, self.defaults)
self._resources_cors_impl = _CorsConfigImpl(
app,
self._resources_router_adapter)
self._old_routes_cors_impl = _CorsConfigImpl(
app,
OldRoutesUrlDispatcherRouterAdapter(app.router, defaults))
OldRoutesUrlDispatcherRouterAdapter(app.router, self.defaults))
else:
raise RuntimeError(
"Router adapter is not specified. "
Expand All @@ -370,7 +256,8 @@ def __init__(self, app: web.Application, *,

def add(self,
routing_entity,
config: _ConfigType = None):
config: _ConfigType = None,
webview: bool=False):
"""Enable CORS for specific route or resource.
If route is passed CORS is enabled for route's resource.
Expand Down Expand Up @@ -404,7 +291,7 @@ def add(self,
# Route which resource has no CORS configuration, i.e.
# old-style route.
return self._old_routes_cors_impl.add(
routing_entity, config)
routing_entity, config, webview=webview)

else:
raise ValueError(
Expand Down
50 changes: 50 additions & 0 deletions aiohttp_cors/mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import asyncio
import collections

from aiohttp import hdrs, web
from .preflight_handler import _PreflightHandler

def custom_cors(config):
def wrapper(function):
name = "{}_cors_config".format(function.__name__)
setattr(function, name, config)
return function
return wrapper


class CorsViewMixin(_PreflightHandler):
cors_config = None

@classmethod
def get_request_config(cls, request, request_method):
try:
from . import APP_CONFIG_KEY
cors = request.app[APP_CONFIG_KEY]
except KeyError:
raise ValueError("aiohttp-cors is not configured.")

method = getattr(cls, request_method.lower(), None)

if not method:
raise KeyError()

config_property_key = "{}_cors_config".format(request_method.lower())

custom_config = getattr(method, config_property_key, None)
if not custom_config:
custom_config = {}

class_config = cls.cors_config
if not class_config:
class_config = {}

return collections.ChainMap(custom_config, class_config, cors.defaults)

@asyncio.coroutine
def _get_config(self, request, origin, request_method):
return self.get_request_config(request, request_method)

@asyncio.coroutine
def options(self):
response = yield from self._preflight_handler(self.request)
return response
Loading

0 comments on commit 9e0c757

Please sign in to comment.