Skip to content

Commit

Permalink
Allow CORS requests to api/workflow_landings
Browse files Browse the repository at this point in the history
  • Loading branch information
mvdbeek committed Oct 10, 2024
1 parent 4c75803 commit 46ba18c
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 20 deletions.
10 changes: 10 additions & 0 deletions lib/galaxy/webapps/base/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,13 @@ def include_all_package_routers(app: FastAPI, package_name: str):
router = getattr(module, "router", None)
if router:
app.include_router(router, responses=responses)

# handle CORS preflight requests - synchronize with wsgi behavior.
# this needs to happen last so it doesn't clobber routes with explicit cors handling
# it doesn't affect the CORS middleware since the middleware terminates the request handling before routing
@app.options("/api/{rest_of_path:path}")
async def preflight_handler(request: Request, rest_of_path: str) -> Response:
response = Response()
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Max-Age"] = "600"
return response
54 changes: 53 additions & 1 deletion lib/galaxy/webapps/galaxy/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import (
Any,
AsyncGenerator,
Callable,
cast,
NamedTuple,
Optional,
Expand Down Expand Up @@ -379,6 +380,18 @@ def get_admin_user(trans: SessionRequestContext = DependsOnTrans):
AdminUserRequired = Depends(get_admin_user)


def cors_preflight(response: Response):
response.headers["Access-Control-Allow-Origin"] = "*"
# Only allow CORS safe-listed headers for now (https://developer.mozilla.org/en-US/docs/Glossary/CORS-safelisted_request_header)
response.headers["Access-Control-Allow-Headers"] = "Accept,Accept-Language,Content-Language,Content-Type,Range"
response.headers["Access-Control-Max-Age"] = "600"
response.status_code = 200
return response


CORSPreflightRequired = Depends(cors_preflight)


class BaseGalaxyAPIController(BaseAPIController):
def __init__(self, app: StructuredApp):
super().__init__(app)
Expand All @@ -401,14 +414,21 @@ class FrameworkRouter(APIRouter):

def wrap_with_alias(self, verb: RestVerb, *args, alias: Optional[str] = None, **kwd):
"""
Wraps FastAPI methods with additional alias keyword and require_admin handling.
Wraps FastAPI methods with additional alias keyword, require_admin and CORS handling.
@router.get("/api/thing", alias="/api/deprecated_thing") will then create
routes for /api/thing and /api/deprecated_thing.
"""
kwd = self._handle_galaxy_kwd(kwd)
include_in_schema = kwd.pop("include_in_schema", True)

allow_cors = kwd.pop("allow_cors", False)
if allow_cors:
assert (
"route_class_override" not in kwd
), "Cannot use allow_cors=True on route and specify `route_class_override`"
kwd["route_class_override"] = APICorsRoute

def decorate_route(route, include_in_schema=include_in_schema):
# Decorator solely exists to allow passing `route_class_override` to add_api_route
def decorated_route(func):
Expand All @@ -419,6 +439,21 @@ def decorated_route(func):
include_in_schema=include_in_schema,
**kwd,
)

if allow_cors:

dependencies = kwd.pop("dependencies", [])
dependencies.append(CORSPreflightRequired)

self.add_api_route(
route,
endpoint=lambda: None,
methods=[RestVerb.options],
include_in_schema=False,
dependencies=dependencies,
**kwd,
)

return func

return decorated_route
Expand Down Expand Up @@ -504,6 +539,23 @@ class Router(FrameworkRouter):
user_dependency = DependsOnUser


class APICorsRoute(APIRoute):
"""
Sends CORS headers
"""

def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler()

async def custom_route_handler(request: Request) -> Response:
response: Response = await original_route_handler(request)
response.headers["Access-Control-Allow-Origin"] = request.headers.get("Origin", "*")
response.headers["Access-Control-Max-Age"] = "600"
return response

return custom_route_handler


class APIContentTypeRoute(APIRoute):
"""
Determines endpoint to match using content-type.
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/webapps/galaxy/api/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,7 @@ def show_workflow(
) -> StoredWorkflowDetailed:
return self.service.show_workflow(trans, workflow_id, instance, legacy, version)

@router.post("/api/workflow_landings", public=True)
@router.post("/api/workflow_landings", public=True, allow_cors=True)
def create_landing(
self,
trans: ProvidesUserContext = DependsOnTrans,
Expand Down
9 changes: 0 additions & 9 deletions lib/galaxy/webapps/galaxy/fast_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)
from fastapi.openapi.constants import REF_TEMPLATE
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import Response

from galaxy.schema.generics import CustomJsonSchema
from galaxy.version import VERSION
Expand Down Expand Up @@ -121,14 +120,6 @@ async def add_x_frame_options(request: Request, call_next):
allow_methods=["*"],
max_age=600,
)
else:
# handle CORS preflight requests - synchronize with wsgi behavior.
@app.options("/api/{rest_of_path:path}")
async def preflight_handler(request: Request, rest_of_path: str) -> Response:
response = Response()
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Max-Age"] = "600"
return response


def include_legacy_openapi(app, gx_app):
Expand Down
1 change: 1 addition & 0 deletions lib/galaxy_test/base/populators.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ def create_workflow_landing(self, payload: CreateWorkflowLandingRequestPayload)
json = payload.model_dump(mode="json")
create_response = self._post(create_url, json, json=True, anon=True)
api_asserts.assert_status_code_is(create_response, 200)
assert create_response.headers["access-control-allow-origin"]
create_response.raise_for_status()
return WorkflowLandingRequest.model_validate(create_response.json())

Expand Down
42 changes: 33 additions & 9 deletions test/integration/test_web_framework_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@

from galaxy_test.driver import integration_util

ENDPOINT_WITH_CORS = "workflow_landings"
ENDPOINT_WITHOUT_EXPLICIT_CORS = "licenses"
WSGI_ENDPOINT = "tools"


class BaseWebFrameworkTestCase(integration_util.IntegrationTestCase):
def _options(self, headers=None):
url = self._api_url("licenses")
def _options(self, headers=None, endpoint=ENDPOINT_WITH_CORS):
url = self._api_url(endpoint)
options_response = options(url, headers=headers or {})
return options_response

Expand All @@ -18,8 +22,18 @@ def test_options(self):
"Access-Control-Request-Method": "GET",
"origin": "http://192.168.0.101:8083",
}
options_response = self._options(headers)
assert options_response.status_code == 200
options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS)
options_response.raise_for_status()
assert "access-control-allow-origin" not in options_response.headers

def test_options_wsgi(self):
# Tests legacy handling
headers = {
"Access-Control-Request-Method": "GET",
"origin": "http://192.168.0.101:8083",
}
options_response = self._options(headers, WSGI_ENDPOINT)
options_response.raise_for_status()
assert "access-control-allow-origin" not in options_response.headers

def test_origin_not_allowed_default(self):
Expand All @@ -28,10 +42,20 @@ def test_origin_not_allowed_default(self):
"Access-Control-Request-Headers": "Authorization",
"origin": "http://192.168.0.101:8083",
}
options_response = self._options(headers)
assert options_response.status_code == 200
options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS)
options_response.raise_for_status()
assert "access-control-allow-origin" not in options_response.headers

def test_origin_explicitly_allowed(self):
headers = {
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "Authorization",
"Origin": "http://192.168.0.101:8083",
}
options_response = self._options(headers, ENDPOINT_WITH_CORS)
options_response.raise_for_status()
assert options_response.headers["access-control-allow-origin"] == "http://192.168.0.101:8083"


class TestAllowOriginIntegration(BaseWebFrameworkTestCase):
@classmethod
Expand All @@ -45,7 +69,7 @@ def test_origin_allowed_if_configured(self):
"origin": "http://192.168.0.101:8083",
"Access-Control-Request-Headers": "Authorization",
}
options_response = self._options(headers)
options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS)
options_response.raise_for_status()
assert "access-control-allow-origin" in options_response.headers
assert options_response.headers["access-control-allow-origin"] == "http://192.168.0.101:8083"
Expand All @@ -57,7 +81,7 @@ def test_origin_allowed_if_configured_via_regex(self):
"origin": "http://rna.galaxyproject.org",
"Access-Control-Request-Headers": "Authorization",
}
options_response = self._options(headers)
options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS)
options_response.raise_for_status()
assert "access-control-allow-origin" in options_response.headers
assert options_response.headers["access-control-allow-origin"] == "http://rna.galaxyproject.org"
Expand All @@ -69,5 +93,5 @@ def test_origin_not_allowed_if_not_in_configured_list(self):
"origin": "http://192.168.0.102:8083", # swapped ip by one
"Access-Control-Request-Headers": "Authorization",
}
options_response = self._options(headers)
options_response = self._options(headers, ENDPOINT_WITHOUT_EXPLICIT_CORS)
assert options_response.status_code == 400

0 comments on commit 46ba18c

Please sign in to comment.