Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: CORS handling #3395

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions litestar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,15 +839,16 @@ def _create_asgi_handler(self) -> ASGIApp:

If CORS or TrustedHost configs are provided to the constructor, they will wrap the router as well.
"""
asgi_handler: ASGIApp = self.asgi_router
if self.cors_config:
asgi_handler = CORSMiddleware(app=asgi_handler, config=self.cors_config)

return wrap_in_exception_handler(
app=asgi_handler,
asgi_handler = wrap_in_exception_handler(
app=self.asgi_router,
exception_handlers=self.exception_handlers or {}, # pyright: ignore
)

if self.cors_config:
return CORSMiddleware(app=asgi_handler, config=self.cors_config)

return asgi_handler

def _wrap_send(self, send: Send, scope: Scope) -> Send:
"""Wrap the ASGI send and handles any 'before send' hooks.

Expand Down
74 changes: 62 additions & 12 deletions litestar/middleware/cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@

from typing import TYPE_CHECKING

from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS
from litestar.datastructures import Headers, MutableScopeHeaders
from litestar.enums import ScopeType
from litestar.enums import HttpMethod, MediaType, ScopeType
from litestar.middleware.base import AbstractMiddleware

__all__ = ("CORSMiddleware",)

from litestar.response import Response
from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST

if TYPE_CHECKING:
from litestar.config.cors import CORSConfig
from litestar.types import ASGIApp, Message, Receive, Scope, Send

__all__ = ("CORSMiddleware",)


class CORSMiddleware(AbstractMiddleware):
"""CORS Middleware."""
Expand All @@ -39,7 +41,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
None
"""
headers = Headers.from_scope(scope=scope)
if origin := headers.get("origin"):
origin = headers.get("origin")

if scope["type"] == ScopeType.HTTP and scope["method"] == HttpMethod.OPTIONS and origin:
request = scope["app"].request_class(scope=scope, receive=receive, send=send)
asgi_response = self._create_preflight_response(origin=origin, request_headers=headers).to_asgi_response(
app=None, request=request
)
await asgi_response(scope, receive, send)
elif origin:
await self.app(scope, receive, self.send_wrapper(send=send, origin=origin, has_cookie="cookie" in headers))
else:
await self.app(scope, receive, send)
Expand Down Expand Up @@ -68,15 +78,55 @@ async def wrapped_send(message: Message) -> None:
headers["Access-Control-Allow-Origin"] = origin
headers["Vary"] = "Origin"

# We don't want to overwrite this for preflight requests.
allow_headers = headers.get("Access-Control-Allow-Headers")
if not allow_headers and self.config.allow_headers:
headers["Access-Control-Allow-Headers"] = ", ".join(sorted(set(self.config.allow_headers)))
headers["Access-Control-Allow-Headers"] = ", ".join(sorted(set(self.config.allow_headers)))

allow_methods = headers.get("Access-Control-Allow-Methods")
if not allow_methods and self.config.allow_methods:
headers["Access-Control-Allow-Methods"] = ", ".join(sorted(set(self.config.allow_methods)))
headers["Access-Control-Allow-Methods"] = ", ".join(sorted(set(self.config.allow_methods)))

await send(message)

return wrapped_send

def _create_preflight_response(self, origin: str, request_headers: Headers) -> Response[str | None]:
pre_flight_method = request_headers.get("Access-Control-Request-Method")
failures = []

if not self.config.is_allow_all_methods and (
pre_flight_method and pre_flight_method not in self.config.allow_methods
):
failures.append("method")

response_headers = self.config.preflight_headers.copy()

if not self.config.is_origin_allowed(origin):
failures.append("Origin")
elif response_headers.get("Access-Control-Allow-Origin") != "*":
response_headers["Access-Control-Allow-Origin"] = origin

pre_flight_requested_headers = [
header.strip()
for header in request_headers.get("Access-Control-Request-Headers", "").split(",")
if header.strip()
]

if pre_flight_requested_headers:
if self.config.is_allow_all_headers:
response_headers["Access-Control-Allow-Headers"] = ", ".join(
sorted(set(pre_flight_requested_headers) | DEFAULT_ALLOWED_CORS_HEADERS) # pyright: ignore
)
elif any(header.lower() not in self.config.allow_headers for header in pre_flight_requested_headers):
failures.append("headers")

return (
Response(
content=f"Disallowed CORS {', '.join(failures)}",
status_code=HTTP_400_BAD_REQUEST,
media_type=MediaType.TEXT,
)
if failures
else Response(
content=None,
status_code=HTTP_204_NO_CONTENT,
media_type=MediaType.TEXT,
headers=response_headers,
)
)
12 changes: 2 additions & 10 deletions litestar/middleware/exceptions/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,13 @@
from traceback import format_exception
from typing import TYPE_CHECKING, Any, Type, cast

from litestar.datastructures import Headers
from litestar.enums import MediaType, ScopeType
from litestar.exceptions import HTTPException, LitestarException, WebSocketException
from litestar.middleware.cors import CORSMiddleware
from litestar.middleware.exceptions._debug_response import _get_type_encoders_for_request, create_debug_response
from litestar.serialization import encode_json, get_serializer
from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR
from litestar.utils.deprecation import warn_deprecation

__all__ = ("ExceptionHandlerMiddleware", "ExceptionResponseContent", "create_exception_response")


if TYPE_CHECKING:
from starlette.exceptions import HTTPException as StarletteHTTPException

Expand All @@ -37,6 +32,8 @@
)
from litestar.types.asgi_types import WebSocketCloseEvent

__all__ = ("ExceptionHandlerMiddleware", "ExceptionResponseContent", "create_exception_response")


def get_exception_handler(exception_handlers: ExceptionHandlersMap, exc: Exception) -> ExceptionHandler | None:
"""Given a dictionary that maps exceptions and status codes to handler functions, and an exception, returns the
Expand Down Expand Up @@ -252,11 +249,6 @@ async def handle_request_exception(
None.
"""

headers = Headers.from_scope(scope=scope)
if litestar_app.cors_config and (origin := headers.get("origin")):
cors_middleware = CORSMiddleware(app=self.app, config=litestar_app.cors_config)
send = cors_middleware.send_wrapper(send=send, origin=origin, has_cookie="cookie" in headers)

exception_handler = get_exception_handler(self.exception_handlers, exc) or self.default_http_exception_handler
request: Request[Any, Any, Any] = litestar_app.request_class(scope=scope, receive=receive, send=send)
response = exception_handler(request, exc)
Expand Down
55 changes: 1 addition & 54 deletions litestar/routes/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@

from msgspec.msgpack import decode as _decode_msgpack_plain

from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS
from litestar.datastructures.headers import Headers
from litestar.datastructures.upload_file import UploadFile
from litestar.enums import HttpMethod, MediaType, ScopeType
from litestar.exceptions import ClientException, ImproperlyConfiguredException, SerializationException
from litestar.handlers.http_handlers import HTTPRouteHandler
from litestar.response import Response
from litestar.routes.base import BaseRoute
from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
from litestar.status_codes import HTTP_204_NO_CONTENT
from litestar.types.empty import Empty
from litestar.utils.scope.state import ScopeState

Expand Down Expand Up @@ -255,57 +253,6 @@ def options_handler(scope: Scope) -> Response:
Returns:
Response
"""
cors_config = scope["app"].cors_config
request_headers = Headers.from_scope(scope=scope)
origin = request_headers.get("origin")

if cors_config and origin:
pre_flight_method = request_headers.get("Access-Control-Request-Method")
failures = []

if not cors_config.is_allow_all_methods and (
pre_flight_method and pre_flight_method not in cors_config.allow_methods
):
failures.append("method")

response_headers = cors_config.preflight_headers.copy()

if not cors_config.is_origin_allowed(origin):
failures.append("Origin")
elif response_headers.get("Access-Control-Allow-Origin") != "*":
response_headers["Access-Control-Allow-Origin"] = origin

pre_flight_requested_headers = [
header.strip()
for header in request_headers.get("Access-Control-Request-Headers", "").split(",")
if header.strip()
]

if pre_flight_requested_headers:
if cors_config.is_allow_all_headers:
response_headers["Access-Control-Allow-Headers"] = ", ".join(
sorted(set(pre_flight_requested_headers) | DEFAULT_ALLOWED_CORS_HEADERS) # pyright: ignore
)
elif any(
header.lower() not in cors_config.allow_headers for header in pre_flight_requested_headers
):
failures.append("headers")

return (
Response(
content=f"Disallowed CORS {', '.join(failures)}",
status_code=HTTP_400_BAD_REQUEST,
media_type=MediaType.TEXT,
)
if failures
else Response(
content=None,
status_code=HTTP_204_NO_CONTENT,
media_type=MediaType.TEXT,
headers=response_headers,
)
)

return Response(
content=None,
status_code=HTTP_204_NO_CONTENT,
Expand Down
Empty file.
49 changes: 49 additions & 0 deletions tests/e2e/test_cors/test_cors_allowed_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from litestar import Litestar, get
from litestar.config.cors import CORSConfig
from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
from litestar.testing import TestClient


@get("/headers-test")
async def headers_handler() -> str:
return "Test Successful!"


cors_config = CORSConfig(
allow_methods=["GET"],
allow_origins=["https://allowed-origin.com"],
allow_headers=["X-Custom-Header", "Content-Type"],
)
app = Litestar(route_handlers=[headers_handler], cors_config=cors_config)


def test_cors_with_specific_allowed_headers() -> None:
with TestClient(app) as client:
response = client.options(
"/endpoint",
headers={
"Origin": "https://allowed-origin.com",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Custom-Header, Content-Type",
},
)
assert response.status_code == HTTP_204_NO_CONTENT
assert "x-custom-header" in response.headers["access-control-allow-headers"]
assert "content-type" in response.headers["access-control-allow-headers"]


def test_cors_with_unauthorized_headers() -> None:
with TestClient(app) as client:
response = client.options(
"/endpoint",
headers={
"Origin": "https://allowed-origin.com",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Not-Allowed-Header",
},
)
assert response.status_code == HTTP_400_BAD_REQUEST
assert (
"access-control-allow-headers" not in response.headers
or "x-not-allowed-header" not in response.headers.get("access-control-allow-headers", "")
)
39 changes: 39 additions & 0 deletions tests/e2e/test_cors/test_cors_allowed_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from http import HTTPStatus

from litestar import Litestar, get
from litestar.config.cors import CORSConfig
from litestar.testing import TestClient


@get("/method-test")
async def method_handler() -> str:
return "Method Test Successful!"


cors_config = CORSConfig(allow_methods=["GET", "POST"], allow_origins=["https://allowed-origin.com"])
app = Litestar(route_handlers=[method_handler], cors_config=cors_config)


def test_cors_allowed_methods() -> None:
with TestClient(app) as client:
response = client.options(
"/method-test", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "GET"}
)
assert response.status_code == HTTPStatus.NO_CONTENT
assert response.headers["access-control-allow-origin"] == "https://allowed-origin.com"
assert "GET" in response.headers["access-control-allow-methods"]

response = client.options(
"/method-test", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "POST"}
)
assert response.status_code == HTTPStatus.NO_CONTENT
assert "POST" in response.headers["access-control-allow-methods"]


def test_cors_disallowed_methods() -> None:
with TestClient(app) as client:
response = client.options(
"/method-test", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "PUT"}
)
assert response.status_code == HTTPStatus.BAD_REQUEST
assert "PUT" not in response.headers.get("access-control-allow-methods", "")
39 changes: 39 additions & 0 deletions tests/e2e/test_cors/test_cors_credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from litestar import Litestar, get
from litestar.config.cors import CORSConfig
from litestar.status_codes import HTTP_204_NO_CONTENT
from litestar.testing import TestClient


@get("/credentials-test")
async def credentials_handler() -> str:
return "Test Successful!"


def test_cors_with_credentials_allowed() -> None:
cors_config = CORSConfig(
allow_methods=["GET"], allow_origins=["https://allowed-origin.com"], allow_credentials=True
)
app = Litestar(route_handlers=[credentials_handler], cors_config=cors_config)

with TestClient(app) as client:
response = client.options(
"/endpoint", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "GET"}
)
assert response.status_code == HTTP_204_NO_CONTENT
assert response.headers["access-control-allow-credentials"] == "true"


def test_cors_with_credentials_disallowed() -> None:
cors_config = CORSConfig(
allow_methods=["GET"],
allow_origins=["https://allowed-origin.com"],
allow_credentials=False, # Credentials should not be allowed
)
app = Litestar(route_handlers=[credentials_handler], cors_config=cors_config)

with TestClient(app) as client:
response = client.options(
"/endpoint", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "GET"}
)
assert response.status_code == HTTP_204_NO_CONTENT
assert "access-control-allow-credentials" not in response.headers
Loading
Loading