Skip to content

Commit

Permalink
refactor: CORS pre-flight response generated by middleware.
Browse files Browse the repository at this point in the history
Move pre-flight response generation from the generated OPTIONS handlers to the CORS middleware.

Fixes case where pre-flight response not applied to mounted ASGI apps.

Fixes the case where a user defined OPTIONS handler may not include CORS logic. In fact, this is they only way we can be sure that we honor CORS pre-flight if the config object is provided.

Adds a stack of CORS e2e tests.
  • Loading branch information
peterschutt committed Apr 19, 2024
1 parent fb5f744 commit a30484d
Show file tree
Hide file tree
Showing 12 changed files with 455 additions and 66 deletions.
74 changes: 62 additions & 12 deletions litestar/middleware/cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@

from typing import TYPE_CHECKING

from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS
from litestar.datastructures import Headers, MutableScopeHeaders
from litestar.enums import ScopeType
from litestar.enums import HttpMethod, MediaType, ScopeType
from litestar.middleware.base import AbstractMiddleware

__all__ = ("CORSMiddleware",)

from litestar.response import Response
from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST

if TYPE_CHECKING:
from litestar.config.cors import CORSConfig
from litestar.types import ASGIApp, Message, Receive, Scope, Send

__all__ = ("CORSMiddleware",)


class CORSMiddleware(AbstractMiddleware):
"""CORS Middleware."""
Expand All @@ -39,7 +41,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
None
"""
headers = Headers.from_scope(scope=scope)
if origin := headers.get("origin"):
origin = headers.get("origin")

if scope["type"] == ScopeType.HTTP and scope["method"] == HttpMethod.OPTIONS and origin:
request = scope["app"].request_class(scope=scope, receive=receive, send=send)
asgi_response = self._create_preflight_response(origin=origin, request_headers=headers).to_asgi_response(
app=None, request=request
)
await asgi_response(scope, receive, send)
elif origin:
await self.app(scope, receive, self.send_wrapper(send=send, origin=origin, has_cookie="cookie" in headers))
else:
await self.app(scope, receive, send)
Expand Down Expand Up @@ -68,15 +78,55 @@ async def wrapped_send(message: Message) -> None:
headers["Access-Control-Allow-Origin"] = origin
headers["Vary"] = "Origin"

# We don't want to overwrite this for preflight requests.
allow_headers = headers.get("Access-Control-Allow-Headers")
if not allow_headers and self.config.allow_headers:
headers["Access-Control-Allow-Headers"] = ", ".join(sorted(set(self.config.allow_headers)))
headers["Access-Control-Allow-Headers"] = ", ".join(sorted(set(self.config.allow_headers)))

allow_methods = headers.get("Access-Control-Allow-Methods")
if not allow_methods and self.config.allow_methods:
headers["Access-Control-Allow-Methods"] = ", ".join(sorted(set(self.config.allow_methods)))
headers["Access-Control-Allow-Methods"] = ", ".join(sorted(set(self.config.allow_methods)))

await send(message)

return wrapped_send

def _create_preflight_response(self, origin: str, request_headers: Headers) -> Response[str | None]:
pre_flight_method = request_headers.get("Access-Control-Request-Method")
failures = []

if not self.config.is_allow_all_methods and (
pre_flight_method and pre_flight_method not in self.config.allow_methods
):
failures.append("method")

response_headers = self.config.preflight_headers.copy()

if not self.config.is_origin_allowed(origin):
failures.append("Origin")
elif response_headers.get("Access-Control-Allow-Origin") != "*":
response_headers["Access-Control-Allow-Origin"] = origin

pre_flight_requested_headers = [
header.strip()
for header in request_headers.get("Access-Control-Request-Headers", "").split(",")
if header.strip()
]

if pre_flight_requested_headers:
if self.config.is_allow_all_headers:
response_headers["Access-Control-Allow-Headers"] = ", ".join(
sorted(set(pre_flight_requested_headers) | DEFAULT_ALLOWED_CORS_HEADERS) # pyright: ignore
)
elif any(header.lower() not in self.config.allow_headers for header in pre_flight_requested_headers):
failures.append("headers")

return (
Response(
content=f"Disallowed CORS {', '.join(failures)}",
status_code=HTTP_400_BAD_REQUEST,
media_type=MediaType.TEXT,
)
if failures
else Response(
content=None,
status_code=HTTP_204_NO_CONTENT,
media_type=MediaType.TEXT,
headers=response_headers,
)
)
55 changes: 1 addition & 54 deletions litestar/routes/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@

from msgspec.msgpack import decode as _decode_msgpack_plain

from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS
from litestar.datastructures.headers import Headers
from litestar.datastructures.upload_file import UploadFile
from litestar.enums import HttpMethod, MediaType, ScopeType
from litestar.exceptions import ClientException, ImproperlyConfiguredException, SerializationException
from litestar.handlers.http_handlers import HTTPRouteHandler
from litestar.response import Response
from litestar.routes.base import BaseRoute
from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
from litestar.status_codes import HTTP_204_NO_CONTENT
from litestar.types.empty import Empty
from litestar.utils.scope.state import ScopeState

Expand Down Expand Up @@ -255,57 +253,6 @@ def options_handler(scope: Scope) -> Response:
Returns:
Response
"""
cors_config = scope["app"].cors_config
request_headers = Headers.from_scope(scope=scope)
origin = request_headers.get("origin")

if cors_config and origin:
pre_flight_method = request_headers.get("Access-Control-Request-Method")
failures = []

if not cors_config.is_allow_all_methods and (
pre_flight_method and pre_flight_method not in cors_config.allow_methods
):
failures.append("method")

response_headers = cors_config.preflight_headers.copy()

if not cors_config.is_origin_allowed(origin):
failures.append("Origin")
elif response_headers.get("Access-Control-Allow-Origin") != "*":
response_headers["Access-Control-Allow-Origin"] = origin

pre_flight_requested_headers = [
header.strip()
for header in request_headers.get("Access-Control-Request-Headers", "").split(",")
if header.strip()
]

if pre_flight_requested_headers:
if cors_config.is_allow_all_headers:
response_headers["Access-Control-Allow-Headers"] = ", ".join(
sorted(set(pre_flight_requested_headers) | DEFAULT_ALLOWED_CORS_HEADERS) # pyright: ignore
)
elif any(
header.lower() not in cors_config.allow_headers for header in pre_flight_requested_headers
):
failures.append("headers")

return (
Response(
content=f"Disallowed CORS {', '.join(failures)}",
status_code=HTTP_400_BAD_REQUEST,
media_type=MediaType.TEXT,
)
if failures
else Response(
content=None,
status_code=HTTP_204_NO_CONTENT,
media_type=MediaType.TEXT,
headers=response_headers,
)
)

return Response(
content=None,
status_code=HTTP_204_NO_CONTENT,
Expand Down
Empty file added tests/e2e/test_cors/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions tests/e2e/test_cors/test_cors_allowed_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from litestar import Litestar, get
from litestar.config.cors import CORSConfig
from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
from litestar.testing import TestClient


@get("/headers-test")
async def headers_handler() -> str:
return "Test Successful!"


cors_config = CORSConfig(
allow_methods=["GET"],
allow_origins=["https://allowed-origin.com"],
allow_headers=["X-Custom-Header", "Content-Type"],
)
app = Litestar(route_handlers=[headers_handler], cors_config=cors_config)


def test_cors_with_specific_allowed_headers() -> None:
with TestClient(app) as client:
response = client.options(
"/endpoint",
headers={
"Origin": "https://allowed-origin.com",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Custom-Header, Content-Type",
},
)
assert response.status_code == HTTP_204_NO_CONTENT
assert "x-custom-header" in response.headers["access-control-allow-headers"]
assert "content-type" in response.headers["access-control-allow-headers"]


def test_cors_with_unauthorized_headers() -> None:
with TestClient(app) as client:
response = client.options(
"/endpoint",
headers={
"Origin": "https://allowed-origin.com",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Not-Allowed-Header",
},
)
assert response.status_code == HTTP_400_BAD_REQUEST
assert (
"access-control-allow-headers" not in response.headers
or "x-not-allowed-header" not in response.headers.get("access-control-allow-headers", "")
)
39 changes: 39 additions & 0 deletions tests/e2e/test_cors/test_cors_allowed_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from http import HTTPStatus

from litestar import Litestar, get
from litestar.config.cors import CORSConfig
from litestar.testing import TestClient


@get("/method-test")
async def method_handler() -> str:
return "Method Test Successful!"


cors_config = CORSConfig(allow_methods=["GET", "POST"], allow_origins=["https://allowed-origin.com"])
app = Litestar(route_handlers=[method_handler], cors_config=cors_config)


def test_cors_allowed_methods() -> None:
with TestClient(app) as client:
response = client.options(
"/method-test", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "GET"}
)
assert response.status_code == HTTPStatus.NO_CONTENT
assert response.headers["access-control-allow-origin"] == "https://allowed-origin.com"
assert "GET" in response.headers["access-control-allow-methods"]

response = client.options(
"/method-test", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "POST"}
)
assert response.status_code == HTTPStatus.NO_CONTENT
assert "POST" in response.headers["access-control-allow-methods"]


def test_cors_disallowed_methods() -> None:
with TestClient(app) as client:
response = client.options(
"/method-test", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "PUT"}
)
assert response.status_code == HTTPStatus.BAD_REQUEST
assert "PUT" not in response.headers.get("access-control-allow-methods", "")
39 changes: 39 additions & 0 deletions tests/e2e/test_cors/test_cors_credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from litestar import Litestar, get
from litestar.config.cors import CORSConfig
from litestar.status_codes import HTTP_204_NO_CONTENT
from litestar.testing import TestClient


@get("/credentials-test")
async def credentials_handler() -> str:
return "Test Successful!"


def test_cors_with_credentials_allowed() -> None:
cors_config = CORSConfig(
allow_methods=["GET"], allow_origins=["https://allowed-origin.com"], allow_credentials=True
)
app = Litestar(route_handlers=[credentials_handler], cors_config=cors_config)

with TestClient(app) as client:
response = client.options(
"/endpoint", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "GET"}
)
assert response.status_code == HTTP_204_NO_CONTENT
assert response.headers["access-control-allow-credentials"] == "true"


def test_cors_with_credentials_disallowed() -> None:
cors_config = CORSConfig(
allow_methods=["GET"],
allow_origins=["https://allowed-origin.com"],
allow_credentials=False, # Credentials should not be allowed
)
app = Litestar(route_handlers=[credentials_handler], cors_config=cors_config)

with TestClient(app) as client:
response = client.options(
"/endpoint", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "GET"}
)
assert response.status_code == HTTP_204_NO_CONTENT
assert "access-control-allow-credentials" not in response.headers
30 changes: 30 additions & 0 deletions tests/e2e/test_cors/test_cors_for_middleware_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from http import HTTPStatus

from litestar import Litestar, get
from litestar.config.cors import CORSConfig
from litestar.exceptions import HTTPException
from litestar.middleware import AbstractMiddleware
from litestar.testing import TestClient
from litestar.types.asgi_types import Receive, Scope, Send


class ExceptionMiddleware(AbstractMiddleware):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Intentional Error")


@get("/test")
async def handler() -> str:
return "Should not reach this"


cors_config = CORSConfig(allow_methods=["GET"], allow_origins=["https://allowed-origin.com"], allow_credentials=True)
app = Litestar(route_handlers=[handler], cors_config=cors_config, middleware=[ExceptionMiddleware])


def test_cors_on_middleware_exception() -> None:
with TestClient(app) as client:
response = client.get("/test", headers={"Origin": "https://allowed-origin.com"})
assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
assert response.headers["access-control-allow-origin"] == "https://allowed-origin.com"
assert response.headers["access-control-allow-credentials"] == "true"
Loading

0 comments on commit a30484d

Please sign in to comment.