From 9ca3c2c5979036601ff8c4182b6df556f51c65dc Mon Sep 17 00:00:00 2001 From: Kang Hyojun Date: Sat, 9 Dec 2017 00:56:32 +0900 Subject: [PATCH] Support CORS for @http-resource [changelog skip] --- nirum_wsgi.py | 233 +++++++++++++++++++++++-------------- schema-fixture/fixture.nrm | 9 ++ tests.py | 77 ++++++++++-- 3 files changed, 220 insertions(+), 99 deletions(-) diff --git a/nirum_wsgi.py b/nirum_wsgi.py index e083934..0e5c539 100644 --- a/nirum_wsgi.py +++ b/nirum_wsgi.py @@ -22,8 +22,32 @@ from werkzeug.wrappers import Request, Response __version__ = '0.2.0' -__all__ = ('AnnotationError', 'InvalidJsonError', 'ServiceMethodError', - 'WsgiApp') +__all__ = ('AnnotationError', 'InvalidJsonError', 'MethodDispatch', + 'PathMatch', 'ServiceMethodError', 'WsgiApp') +MethodDispatch = collections.namedtuple('MethodDispatch', [ + 'request', 'routed', 'service_method', + 'payload', 'cors_headers' +]) +PathMatch = collections.namedtuple('PathMatch', [ + 'match_group', 'verb', 'method_name' +]) + + +def match_request(rules, path_info, request_method): + matched_verb = [] + match = None + for _, pattern, verb, method_name in rules: + if isinstance(path_info, bytes): + # FIXME Decode properly; URI is not unicode + path_info = path_info.decode() + path_matched = pattern.match(path_info) + verb = verb.upper() + if path_matched: + matched_verb.append(verb) + if request_method == verb or request_method == 'OPTIONS': + match = PathMatch(match_group=path_matched, verb=verb, + method_name=method_name) + return match, matched_verb def parse_json_payload(request): @@ -51,6 +75,17 @@ class ServiceMethodError(LookupError): """Exception raised when a method is not found.""" +class MethodDispatchError(ValueError): + """Exception raised when failed to dispatch method.""" + + def __init__(self, request, status_code, message=None, + *args, **kwargs): + self.request = request + self.status_code = status_code + self.message = message + super(MethodDispatchError, self).__init__(*args, **kwargs) + + class WsgiApp: """Create a WSGI application which adapts the given Nirum service. @@ -137,106 +172,124 @@ def __call__(self, environ, start_response): """ return self.route(environ, start_response) - def route(self, environ, start_response): - """Route - - :param environ: - :param start_response: - - """ + def dispatch_method(self, environ): + payload = None request = Request(environ) service_methods = self.service.__nirum_service_methods__ - error_raised = None - for _, pattern, verb, method_name in self.rules: - path_info = environ['PATH_INFO'] - if isinstance(path_info, bytes): - # FIXME Decode properly; URI is not unicode - path_info = path_info.decode() - match = pattern.match(path_info) - if match and environ['REQUEST_METHOD'] == verb.upper(): - routed = True - service_method = method_name - if verb in ('GET', 'DELETE'): - method_parameters = { - k: v - for k, v in service_methods[method_name].items() - if not k.startswith('_') - } - # TODO Parsing query string - payload = {p: match.group(p) for p in method_parameters} + # CORS + cors_headers = [('Vary', 'Origin')] + request_match, matched_verb = match_request( + self.rules, request.path, request.method + ) + if request_match: + service_method = request_match.method_name + cors_headers.append( + ( + 'Access-Control-Allow-Methods', + ', '.join(matched_verb + ['OPTIONS']) + ) + ) + method_parameters = { + k: v + for k, v in service_methods[request_match.method_name].items() + if not k.startswith('_') + } + payload = { + p: request_match.match_group.group(p) + for p in method_parameters + } + # TODO Parsing query string + if request_match.verb not in ('GET', 'DELETE'): + try: + json_payload = parse_json_payload(request) + except InvalidJsonError as e: + raise MethodDispatchError( + request, 400, + "Invalid JSON payload: '{!s}'.".format(e) + ) else: - try: - payload = parse_json_payload(request) - except InvalidJsonError as e: - error_raised = self.error( - 400, request, - message="Invalid JSON payload: '{!s}'.".format(e) - ) - cors_headers = [] # TODO - break + payload.update(**json_payload) else: - routed = False if request.method not in ('POST', 'OPTIONS'): - error_raised = self.error(405, request) - - # CORS - cors_headers = [ - ('Access-Control-Allow-Methods', 'POST, OPTIONS'), - ('Vary', 'Origin'), - ] - if self.allowed_headers: - cors_headers.append( - ( - 'Access-Control-Allow-Headers', - ', '.join(sorted(self.allowed_headers)) - ) - ) - try: - origin = request.headers['Origin'] - except KeyError: - pass - else: - parsed_origin = urlparse.urlparse(origin) - if parsed_origin.scheme in ('http', 'https') and \ - parsed_origin.hostname in self.allowed_origins: - cors_headers.append( - ('Access-Control-Allow-Origin', origin) - ) - - if request.method == 'OPTIONS': - start_response('200 OK', cors_headers) - return [] + raise MethodDispatchError(request, 405) + cors_headers.append( + ('Access-Control-Allow-Methods', 'POST, OPTIONS') + ) service_method = request.args.get('method') try: payload = parse_json_payload(request) except InvalidJsonError as e: - error_raised = self.error( - 400, request, - message="Invalid JSON payload: '{!s}'.".format(e) + raise MethodDispatchError( + request, + 400, + "Invalid JSON payload: '{!s}'.".format(e) ) - if error_raised: - response = error_raised - elif service_method: - try: - response = self.rpc(request, service_method, payload) - except ServiceMethodError: - response = self.error( - 404 if routed else 400, request, - message='No service method `{}` found.'.format( - service_method - ) + if self.allowed_headers: + cors_headers.append( + ( + 'Access-Control-Allow-Headers', + ', '.join(sorted(self.allowed_headers)) ) - else: - for k, v in cors_headers: - if k in response.headers: - response.headers[k] += ', ' + v # FIXME: is it proper? - else: - response.headers[k] = v - else: - response = self.error( - 400, request, - message="`method` is missing." ) + try: + origin = request.headers['Origin'] + except KeyError: + pass + else: + parsed_origin = urlparse.urlparse(origin) + if parsed_origin.scheme in ('http', 'https') and \ + parsed_origin.hostname in self.allowed_origins: + cors_headers.append( + ('Access-Control-Allow-Origin', origin) + ) + return MethodDispatch( + request=request, + routed=bool(request_match), + service_method=service_method, + payload=payload, + cors_headers=cors_headers + ) + + def route(self, environ, start_response): + """Route + + :param environ: + :param start_response: + + """ + try: + match = self.dispatch_method(environ) + except MethodDispatchError as e: + response = self.error(e.status_code, e.request, e.message) + else: + if environ['REQUEST_METHOD'] == 'OPTIONS': + start_response('200 OK', match.cors_headers) + return [] + if match.service_method: + try: + response = self.rpc( + match.request, match.service_method, match.payload + ) + except ServiceMethodError: + response = self.error( + 404 if match.routed else 400, + match.request, + message='No service method `{}` found.'.format( + match.service_method + ) + ) + else: + for k, v in match.cors_headers: + if k in response.headers: + # FIXME: is it proper? + response.headers[k] += ', ' + v + else: + response.headers[k] = v + else: + response = self.error( + 400, match.request, + message="`method` is missing." + ) return response(environ, start_response) def rpc(self, request, service_method, request_json): diff --git a/schema-fixture/fixture.nrm b/schema-fixture/fixture.nrm index 87fbb18..903efb6 100644 --- a/schema-fixture/fixture.nrm +++ b/schema-fixture/fixture.nrm @@ -43,6 +43,15 @@ service satisfied-parameters-service ( text python-keyword(text from, text to), ); +service cors-verb-service ( + @http-resource(method="GET", path="/foo/{foo}/") + bool get-foo(text foo), + @http-resource(method="PUT", path="/foo/{foo}/") + bool update-foo(text foo), + @http-resource(method="DELETE", path="/bar/{bar}/") + bool delete-bar(text bar), +); + unboxed token (uuid); record complex-key-map ({point: point} value); diff --git a/tests.py b/tests.py index 88801d1..5e97a58 100644 --- a/tests.py +++ b/tests.py @@ -1,7 +1,8 @@ import collections import json -from fixture import (BadRequest, MusicService, SatisfiedParametersService, +from fixture import (BadRequest, CorsVerbService, MusicService, + SatisfiedParametersService, Unknown, UnsatisfiedParametersService) from pytest import fixture, mark, raises from six import text_type @@ -40,6 +41,18 @@ def raise_application_error_request(self): raise ValueError('hello world') +class CorsVerbServiceImpl(CorsVerbService): + + def get_foo(self, foo): + return True + + def update_foo(self, foo): + return True + + def delete_bar(self, bar): + return True + + @fixture def fx_music_wsgi(): return WsgiApp(MusicServiceImpl()) @@ -324,19 +337,19 @@ def test_http_resource_route(fx_test_client): ) +def split(header, lower=False): + vs = [h.strip() for h in header.split(',')] + if lower: + vs = [v.lower() for v in vs] + return frozenset(vs) + + def test_cors(): app = WsgiApp( MusicServiceImpl(), allowed_origins=frozenset(['example.com']) ) client = Client(app, Response) - - def split(header, lower=False): - vs = map(str.strip, header.split(',')) - if lower: - vs = map(str.lower, vs) - return frozenset(vs) - resp = client.options('/?method=get_music_by_artist_name', headers={ 'Origin': 'https://example.com', 'Access-Control-Request-Method': 'POST', @@ -372,4 +385,50 @@ def split(header, lower=False): assert resp3.status_code == 200 allow_origin = resp3.headers.get('Access-Control-Allow-Origin', '') assert 'disallowed.com' not in allow_origin - # TODO: URIs mapped through @http-resource also should be implemented + + +@mark.parametrize( + 'url, allow_methods, request_method', + [ + (u'/foo/abc/', {u'GET', u'PUT', u'OPTIONS'}, u'GET'), + (u'/foo/abc/', {u'GET', u'PUT', u'OPTIONS'}, u'PUT'), + (u'/bar/abc/', {u'DELETE', u'OPTIONS'}, u'DELETE'), + ], +) +def test_cors_http_resouce(url, allow_methods, request_method): + app = WsgiApp( + CorsVerbServiceImpl(), + allowed_origins=frozenset(['example.com']) + ) + client = Client(app, Response) + origin = u'https://example.com' + resp = client.options(url, headers={ + 'Origin': origin, + 'Access-Control-Request-Method': request_method, + }) + assert resp.status_code == 200 + assert resp.headers['Access-Control-Allow-Origin'] == origin + assert split(resp.headers['Access-Control-Allow-Methods']) == allow_methods + assert u'origin' in split(resp.headers['Vary'], lower=True) + resp2 = getattr(client, request_method.lower())( + url, + headers={ + 'Origin': u'https://example.com', + 'Access-Control-Request-Method': request_method, + 'Content-Type': u'application/json', + }, + ) + assert resp2.status_code == 200, resp2.get_data(as_text=True) + assert resp2.headers['Access-Control-Allow-Origin'] == origin + assert allow_methods == split( + resp2.headers['Access-Control-Allow-Methods'] + ) + assert 'origin' in split(resp2.headers['Vary'], lower=True) + + resp3 = client.options(url, headers={ + 'Origin': u'https://disallowed.com', + 'Access-Control-Request-Method': request_method, + }) + assert resp3.status_code == 200 + allow_origin = resp3.headers.get('Access-Control-Allow-Origin', u'') + assert u'disallowed.com' not in allow_origin