Skip to content

Commit

Permalink
Merge pull request nirum-lang#4 from admire93/http-resource-cors
Browse files Browse the repository at this point in the history
Support CORS for @http-resource
  • Loading branch information
kanghyojun authored Dec 9, 2017
2 parents e9a6867 + 9ca3c2c commit 0fc600d
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 99 deletions.
233 changes: 143 additions & 90 deletions nirum_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,32 @@
from werkzeug.wrappers import Request, Response

__version__ = '0.2.0'
__all__ = ('AnnotationError', 'InvalidJsonError', 'ServiceMethodError',
'WsgiApp')
__all__ = ('AnnotationError', 'InvalidJsonError', 'MethodDispatch',
'PathMatch', 'ServiceMethodError', 'WsgiApp')
MethodDispatch = collections.namedtuple('MethodDispatch', [
'request', 'routed', 'service_method',
'payload', 'cors_headers'
])
PathMatch = collections.namedtuple('PathMatch', [
'match_group', 'verb', 'method_name'
])


def match_request(rules, path_info, request_method):
matched_verb = []
match = None
for _, pattern, verb, method_name in rules:
if isinstance(path_info, bytes):
# FIXME Decode properly; URI is not unicode
path_info = path_info.decode()
path_matched = pattern.match(path_info)
verb = verb.upper()
if path_matched:
matched_verb.append(verb)
if request_method == verb or request_method == 'OPTIONS':
match = PathMatch(match_group=path_matched, verb=verb,
method_name=method_name)
return match, matched_verb


def parse_json_payload(request):
Expand Down Expand Up @@ -51,6 +75,17 @@ class ServiceMethodError(LookupError):
"""Exception raised when a method is not found."""


class MethodDispatchError(ValueError):
"""Exception raised when failed to dispatch method."""

def __init__(self, request, status_code, message=None,
*args, **kwargs):
self.request = request
self.status_code = status_code
self.message = message
super(MethodDispatchError, self).__init__(*args, **kwargs)


class WsgiApp:
"""Create a WSGI application which adapts the given Nirum service.
Expand Down Expand Up @@ -137,106 +172,124 @@ def __call__(self, environ, start_response):
"""
return self.route(environ, start_response)

def route(self, environ, start_response):
"""Route
:param environ:
:param start_response:
"""
def dispatch_method(self, environ):
payload = None
request = Request(environ)
service_methods = self.service.__nirum_service_methods__
error_raised = None
for _, pattern, verb, method_name in self.rules:
path_info = environ['PATH_INFO']
if isinstance(path_info, bytes):
# FIXME Decode properly; URI is not unicode
path_info = path_info.decode()
match = pattern.match(path_info)
if match and environ['REQUEST_METHOD'] == verb.upper():
routed = True
service_method = method_name
if verb in ('GET', 'DELETE'):
method_parameters = {
k: v
for k, v in service_methods[method_name].items()
if not k.startswith('_')
}
# TODO Parsing query string
payload = {p: match.group(p) for p in method_parameters}
# CORS
cors_headers = [('Vary', 'Origin')]
request_match, matched_verb = match_request(
self.rules, request.path, request.method
)
if request_match:
service_method = request_match.method_name
cors_headers.append(
(
'Access-Control-Allow-Methods',
', '.join(matched_verb + ['OPTIONS'])
)
)
method_parameters = {
k: v
for k, v in service_methods[request_match.method_name].items()
if not k.startswith('_')
}
payload = {
p: request_match.match_group.group(p)
for p in method_parameters
}
# TODO Parsing query string
if request_match.verb not in ('GET', 'DELETE'):
try:
json_payload = parse_json_payload(request)
except InvalidJsonError as e:
raise MethodDispatchError(
request, 400,
"Invalid JSON payload: '{!s}'.".format(e)
)
else:
try:
payload = parse_json_payload(request)
except InvalidJsonError as e:
error_raised = self.error(
400, request,
message="Invalid JSON payload: '{!s}'.".format(e)
)
cors_headers = [] # TODO
break
payload.update(**json_payload)
else:
routed = False
if request.method not in ('POST', 'OPTIONS'):
error_raised = self.error(405, request)

# CORS
cors_headers = [
('Access-Control-Allow-Methods', 'POST, OPTIONS'),
('Vary', 'Origin'),
]
if self.allowed_headers:
cors_headers.append(
(
'Access-Control-Allow-Headers',
', '.join(sorted(self.allowed_headers))
)
)
try:
origin = request.headers['Origin']
except KeyError:
pass
else:
parsed_origin = urlparse.urlparse(origin)
if parsed_origin.scheme in ('http', 'https') and \
parsed_origin.hostname in self.allowed_origins:
cors_headers.append(
('Access-Control-Allow-Origin', origin)
)

if request.method == 'OPTIONS':
start_response('200 OK', cors_headers)
return []
raise MethodDispatchError(request, 405)
cors_headers.append(
('Access-Control-Allow-Methods', 'POST, OPTIONS')
)
service_method = request.args.get('method')
try:
payload = parse_json_payload(request)
except InvalidJsonError as e:
error_raised = self.error(
400, request,
message="Invalid JSON payload: '{!s}'.".format(e)
raise MethodDispatchError(
request,
400,
"Invalid JSON payload: '{!s}'.".format(e)
)
if error_raised:
response = error_raised
elif service_method:
try:
response = self.rpc(request, service_method, payload)
except ServiceMethodError:
response = self.error(
404 if routed else 400, request,
message='No service method `{}` found.'.format(
service_method
)
if self.allowed_headers:
cors_headers.append(
(
'Access-Control-Allow-Headers',
', '.join(sorted(self.allowed_headers))
)
else:
for k, v in cors_headers:
if k in response.headers:
response.headers[k] += ', ' + v # FIXME: is it proper?
else:
response.headers[k] = v
else:
response = self.error(
400, request,
message="`method` is missing."
)
try:
origin = request.headers['Origin']
except KeyError:
pass
else:
parsed_origin = urlparse.urlparse(origin)
if parsed_origin.scheme in ('http', 'https') and \
parsed_origin.hostname in self.allowed_origins:
cors_headers.append(
('Access-Control-Allow-Origin', origin)
)
return MethodDispatch(
request=request,
routed=bool(request_match),
service_method=service_method,
payload=payload,
cors_headers=cors_headers
)

def route(self, environ, start_response):
"""Route
:param environ:
:param start_response:
"""
try:
match = self.dispatch_method(environ)
except MethodDispatchError as e:
response = self.error(e.status_code, e.request, e.message)
else:
if environ['REQUEST_METHOD'] == 'OPTIONS':
start_response('200 OK', match.cors_headers)
return []
if match.service_method:
try:
response = self.rpc(
match.request, match.service_method, match.payload
)
except ServiceMethodError:
response = self.error(
404 if match.routed else 400,
match.request,
message='No service method `{}` found.'.format(
match.service_method
)
)
else:
for k, v in match.cors_headers:
if k in response.headers:
# FIXME: is it proper?
response.headers[k] += ', ' + v
else:
response.headers[k] = v
else:
response = self.error(
400, match.request,
message="`method` is missing."
)
return response(environ, start_response)

def rpc(self, request, service_method, request_json):
Expand Down
9 changes: 9 additions & 0 deletions schema-fixture/fixture.nrm
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ service satisfied-parameters-service (
text python-keyword(text from, text to),
);

service cors-verb-service (
@http-resource(method="GET", path="/foo/{foo}/")
bool get-foo(text foo),
@http-resource(method="PUT", path="/foo/{foo}/")
bool update-foo(text foo),
@http-resource(method="DELETE", path="/bar/{bar}/")
bool delete-bar(text bar),
);

unboxed token (uuid);

record complex-key-map ({point: point} value);
77 changes: 68 additions & 9 deletions tests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import collections
import json

from fixture import (BadRequest, MusicService, SatisfiedParametersService,
from fixture import (BadRequest, CorsVerbService, MusicService,
SatisfiedParametersService,
Unknown, UnsatisfiedParametersService)
from pytest import fixture, mark, raises
from six import text_type
Expand Down Expand Up @@ -40,6 +41,18 @@ def raise_application_error_request(self):
raise ValueError('hello world')


class CorsVerbServiceImpl(CorsVerbService):

def get_foo(self, foo):
return True

def update_foo(self, foo):
return True

def delete_bar(self, bar):
return True


@fixture
def fx_music_wsgi():
return WsgiApp(MusicServiceImpl())
Expand Down Expand Up @@ -324,19 +337,19 @@ def test_http_resource_route(fx_test_client):
)


def split(header, lower=False):
vs = [h.strip() for h in header.split(',')]
if lower:
vs = [v.lower() for v in vs]
return frozenset(vs)


def test_cors():
app = WsgiApp(
MusicServiceImpl(),
allowed_origins=frozenset(['example.com'])
)
client = Client(app, Response)

def split(header, lower=False):
vs = map(str.strip, header.split(','))
if lower:
vs = map(str.lower, vs)
return frozenset(vs)

resp = client.options('/?method=get_music_by_artist_name', headers={
'Origin': 'https://example.com',
'Access-Control-Request-Method': 'POST',
Expand Down Expand Up @@ -372,4 +385,50 @@ def split(header, lower=False):
assert resp3.status_code == 200
allow_origin = resp3.headers.get('Access-Control-Allow-Origin', '')
assert 'disallowed.com' not in allow_origin
# TODO: URIs mapped through @http-resource also should be implemented


@mark.parametrize(
'url, allow_methods, request_method',
[
(u'/foo/abc/', {u'GET', u'PUT', u'OPTIONS'}, u'GET'),
(u'/foo/abc/', {u'GET', u'PUT', u'OPTIONS'}, u'PUT'),
(u'/bar/abc/', {u'DELETE', u'OPTIONS'}, u'DELETE'),
],
)
def test_cors_http_resouce(url, allow_methods, request_method):
app = WsgiApp(
CorsVerbServiceImpl(),
allowed_origins=frozenset(['example.com'])
)
client = Client(app, Response)
origin = u'https://example.com'
resp = client.options(url, headers={
'Origin': origin,
'Access-Control-Request-Method': request_method,
})
assert resp.status_code == 200
assert resp.headers['Access-Control-Allow-Origin'] == origin
assert split(resp.headers['Access-Control-Allow-Methods']) == allow_methods
assert u'origin' in split(resp.headers['Vary'], lower=True)
resp2 = getattr(client, request_method.lower())(
url,
headers={
'Origin': u'https://example.com',
'Access-Control-Request-Method': request_method,
'Content-Type': u'application/json',
},
)
assert resp2.status_code == 200, resp2.get_data(as_text=True)
assert resp2.headers['Access-Control-Allow-Origin'] == origin
assert allow_methods == split(
resp2.headers['Access-Control-Allow-Methods']
)
assert 'origin' in split(resp2.headers['Vary'], lower=True)

resp3 = client.options(url, headers={
'Origin': u'https://disallowed.com',
'Access-Control-Request-Method': request_method,
})
assert resp3.status_code == 200
allow_origin = resp3.headers.get('Access-Control-Allow-Origin', u'')
assert u'disallowed.com' not in allow_origin

0 comments on commit 0fc600d

Please sign in to comment.