From 46ba18cbc81a72b7a9098e39869452e92205bd95 Mon Sep 17 00:00:00 2001 From: mvdbeek Date: Wed, 9 Oct 2024 12:05:51 +0200 Subject: [PATCH] Allow CORS requests to api/workflow_landings --- lib/galaxy/webapps/base/api.py | 10 ++++ lib/galaxy/webapps/galaxy/api/__init__.py | 54 ++++++++++++++++++- lib/galaxy/webapps/galaxy/api/workflows.py | 2 +- lib/galaxy/webapps/galaxy/fast_app.py | 9 ---- lib/galaxy_test/base/populators.py | 1 + test/integration/test_web_framework_config.py | 42 +++++++++++---- 6 files changed, 98 insertions(+), 20 deletions(-) diff --git a/lib/galaxy/webapps/base/api.py b/lib/galaxy/webapps/base/api.py index 9df5b838eea2..a0faf2b11371 100644 --- a/lib/galaxy/webapps/base/api.py +++ b/lib/galaxy/webapps/base/api.py @@ -260,3 +260,13 @@ def include_all_package_routers(app: FastAPI, package_name: str): router = getattr(module, "router", None) if router: app.include_router(router, responses=responses) + + # handle CORS preflight requests - synchronize with wsgi behavior. + # this needs to happen last so it doesn't clobber routes with explicit cors handling + # it doesn't affect the CORS middleware since the middleware terminates the request handling before routing + @app.options("/api/{rest_of_path:path}") + async def preflight_handler(request: Request, rest_of_path: str) -> Response: + response = Response() + response.headers["Access-Control-Allow-Headers"] = "*" + response.headers["Access-Control-Max-Age"] = "600" + return response diff --git a/lib/galaxy/webapps/galaxy/api/__init__.py b/lib/galaxy/webapps/galaxy/api/__init__.py index 260f90cebced..95ec6cf4069a 100644 --- a/lib/galaxy/webapps/galaxy/api/__init__.py +++ b/lib/galaxy/webapps/galaxy/api/__init__.py @@ -8,6 +8,7 @@ from typing import ( Any, AsyncGenerator, + Callable, cast, NamedTuple, Optional, @@ -379,6 +380,18 @@ def get_admin_user(trans: SessionRequestContext = DependsOnTrans): AdminUserRequired = Depends(get_admin_user) +def cors_preflight(response: Response): + response.headers["Access-Control-Allow-Origin"] = "*" + # Only allow CORS safe-listed headers for now (https://developer.mozilla.org/en-US/docs/Glossary/CORS-safelisted_request_header) + response.headers["Access-Control-Allow-Headers"] = "Accept,Accept-Language,Content-Language,Content-Type,Range" + response.headers["Access-Control-Max-Age"] = "600" + response.status_code = 200 + return response + + +CORSPreflightRequired = Depends(cors_preflight) + + class BaseGalaxyAPIController(BaseAPIController): def __init__(self, app: StructuredApp): super().__init__(app) @@ -401,7 +414,7 @@ class FrameworkRouter(APIRouter): def wrap_with_alias(self, verb: RestVerb, *args, alias: Optional[str] = None, **kwd): """ - Wraps FastAPI methods with additional alias keyword and require_admin handling. + Wraps FastAPI methods with additional alias keyword, require_admin and CORS handling. @router.get("/api/thing", alias="/api/deprecated_thing") will then create routes for /api/thing and /api/deprecated_thing. @@ -409,6 +422,13 @@ def wrap_with_alias(self, verb: RestVerb, *args, alias: Optional[str] = None, ** kwd = self._handle_galaxy_kwd(kwd) include_in_schema = kwd.pop("include_in_schema", True) + allow_cors = kwd.pop("allow_cors", False) + if allow_cors: + assert ( + "route_class_override" not in kwd + ), "Cannot use allow_cors=True on route and specify `route_class_override`" + kwd["route_class_override"] = APICorsRoute + def decorate_route(route, include_in_schema=include_in_schema): # Decorator solely exists to allow passing `route_class_override` to add_api_route def decorated_route(func): @@ -419,6 +439,21 @@ def decorated_route(func): include_in_schema=include_in_schema, **kwd, ) + + if allow_cors: + + dependencies = kwd.pop("dependencies", []) + dependencies.append(CORSPreflightRequired) + + self.add_api_route( + route, + endpoint=lambda: None, + methods=[RestVerb.options], + include_in_schema=False, + dependencies=dependencies, + **kwd, + ) + return func return decorated_route @@ -504,6 +539,23 @@ class Router(FrameworkRouter): user_dependency = DependsOnUser +class APICorsRoute(APIRoute): + """ + Sends CORS headers + """ + + def get_route_handler(self) -> Callable: + original_route_handler = super().get_route_handler() + + async def custom_route_handler(request: Request) -> Response: + response: Response = await original_route_handler(request) + response.headers["Access-Control-Allow-Origin"] = request.headers.get("Origin", "*") + response.headers["Access-Control-Max-Age"] = "600" + return response + + return custom_route_handler + + class APIContentTypeRoute(APIRoute): """ Determines endpoint to match using content-type. diff --git a/lib/galaxy/webapps/galaxy/api/workflows.py b/lib/galaxy/webapps/galaxy/api/workflows.py index 002e705c8d98..719029e818d6 100644 --- a/lib/galaxy/webapps/galaxy/api/workflows.py +++ b/lib/galaxy/webapps/galaxy/api/workflows.py @@ -1165,7 +1165,7 @@ def show_workflow( ) -> StoredWorkflowDetailed: return self.service.show_workflow(trans, workflow_id, instance, legacy, version) - @router.post("/api/workflow_landings", public=True) + @router.post("/api/workflow_landings", public=True, allow_cors=True) def create_landing( self, trans: ProvidesUserContext = DependsOnTrans, diff --git a/lib/galaxy/webapps/galaxy/fast_app.py b/lib/galaxy/webapps/galaxy/fast_app.py index 02a191efe9a0..143586156e8f 100644 --- a/lib/galaxy/webapps/galaxy/fast_app.py +++ b/lib/galaxy/webapps/galaxy/fast_app.py @@ -10,7 +10,6 @@ ) from fastapi.openapi.constants import REF_TEMPLATE from starlette.middleware.cors import CORSMiddleware -from starlette.responses import Response from galaxy.schema.generics import CustomJsonSchema from galaxy.version import VERSION @@ -121,14 +120,6 @@ async def add_x_frame_options(request: Request, call_next): allow_methods=["*"], max_age=600, ) - else: - # handle CORS preflight requests - synchronize with wsgi behavior. - @app.options("/api/{rest_of_path:path}") - async def preflight_handler(request: Request, rest_of_path: str) -> Response: - response = Response() - response.headers["Access-Control-Allow-Headers"] = "*" - response.headers["Access-Control-Max-Age"] = "600" - return response def include_legacy_openapi(app, gx_app): diff --git a/lib/galaxy_test/base/populators.py b/lib/galaxy_test/base/populators.py index 08c7e08dcdb8..459f6a6187a9 100644 --- a/lib/galaxy_test/base/populators.py +++ b/lib/galaxy_test/base/populators.py @@ -781,6 +781,7 @@ def create_workflow_landing(self, payload: CreateWorkflowLandingRequestPayload) json = payload.model_dump(mode="json") create_response = self._post(create_url, json, json=True, anon=True) api_asserts.assert_status_code_is(create_response, 200) + assert create_response.headers["access-control-allow-origin"] create_response.raise_for_status() return WorkflowLandingRequest.model_validate(create_response.json()) diff --git a/test/integration/test_web_framework_config.py b/test/integration/test_web_framework_config.py index d617eed03ae4..23e5091e3f86 100644 --- a/test/integration/test_web_framework_config.py +++ b/test/integration/test_web_framework_config.py @@ -4,10 +4,14 @@ from galaxy_test.driver import integration_util +ENDPOINT_WITH_CORS = "workflow_landings" +ENDPOINT_WITHOUT_EXPLICIT_CORS = "licenses" +WSGI_ENDPOINT = "tools" + class BaseWebFrameworkTestCase(integration_util.IntegrationTestCase): - def _options(self, headers=None): - url = self._api_url("licenses") + def _options(self, headers=None, endpoint=ENDPOINT_WITH_CORS): + url = self._api_url(endpoint) options_response = options(url, headers=headers or {}) return options_response @@ -18,8 +22,18 @@ def test_options(self): "Access-Control-Request-Method": "GET", "origin": "http://192.168.0.101:8083", } - options_response = self._options(headers) - assert options_response.status_code == 200 + options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS) + options_response.raise_for_status() + assert "access-control-allow-origin" not in options_response.headers + + def test_options_wsgi(self): + # Tests legacy handling + headers = { + "Access-Control-Request-Method": "GET", + "origin": "http://192.168.0.101:8083", + } + options_response = self._options(headers, WSGI_ENDPOINT) + options_response.raise_for_status() assert "access-control-allow-origin" not in options_response.headers def test_origin_not_allowed_default(self): @@ -28,10 +42,20 @@ def test_origin_not_allowed_default(self): "Access-Control-Request-Headers": "Authorization", "origin": "http://192.168.0.101:8083", } - options_response = self._options(headers) - assert options_response.status_code == 200 + options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS) + options_response.raise_for_status() assert "access-control-allow-origin" not in options_response.headers + def test_origin_explicitly_allowed(self): + headers = { + "Access-Control-Request-Method": "GET", + "Access-Control-Request-Headers": "Authorization", + "Origin": "http://192.168.0.101:8083", + } + options_response = self._options(headers, ENDPOINT_WITH_CORS) + options_response.raise_for_status() + assert options_response.headers["access-control-allow-origin"] == "http://192.168.0.101:8083" + class TestAllowOriginIntegration(BaseWebFrameworkTestCase): @classmethod @@ -45,7 +69,7 @@ def test_origin_allowed_if_configured(self): "origin": "http://192.168.0.101:8083", "Access-Control-Request-Headers": "Authorization", } - options_response = self._options(headers) + options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS) options_response.raise_for_status() assert "access-control-allow-origin" in options_response.headers assert options_response.headers["access-control-allow-origin"] == "http://192.168.0.101:8083" @@ -57,7 +81,7 @@ def test_origin_allowed_if_configured_via_regex(self): "origin": "http://rna.galaxyproject.org", "Access-Control-Request-Headers": "Authorization", } - options_response = self._options(headers) + options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS) options_response.raise_for_status() assert "access-control-allow-origin" in options_response.headers assert options_response.headers["access-control-allow-origin"] == "http://rna.galaxyproject.org" @@ -69,5 +93,5 @@ def test_origin_not_allowed_if_not_in_configured_list(self): "origin": "http://192.168.0.102:8083", # swapped ip by one "Access-Control-Request-Headers": "Authorization", } - options_response = self._options(headers) + options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS) assert options_response.status_code == 400