-
-
Notifications
You must be signed in to change notification settings - Fork 389
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: CORS pre-flight response generated by middleware.
Move pre-flight response generation from the generated OPTIONS handlers to the CORS middleware. Fixes case where pre-flight response not applied to mounted ASGI apps. Fixes the case where a user defined OPTIONS handler may not include CORS logic. In fact, this is they only way we can be sure that we honor CORS pre-flight if the config object is provided. Adds a stack of CORS e2e tests.
- Loading branch information
1 parent
fb5f744
commit a30484d
Showing
12 changed files
with
455 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from litestar import Litestar, get | ||
from litestar.config.cors import CORSConfig | ||
from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST | ||
from litestar.testing import TestClient | ||
|
||
|
||
@get("/headers-test") | ||
async def headers_handler() -> str: | ||
return "Test Successful!" | ||
|
||
|
||
cors_config = CORSConfig( | ||
allow_methods=["GET"], | ||
allow_origins=["https://allowed-origin.com"], | ||
allow_headers=["X-Custom-Header", "Content-Type"], | ||
) | ||
app = Litestar(route_handlers=[headers_handler], cors_config=cors_config) | ||
|
||
|
||
def test_cors_with_specific_allowed_headers() -> None: | ||
with TestClient(app) as client: | ||
response = client.options( | ||
"/endpoint", | ||
headers={ | ||
"Origin": "https://allowed-origin.com", | ||
"Access-Control-Request-Method": "GET", | ||
"Access-Control-Request-Headers": "X-Custom-Header, Content-Type", | ||
}, | ||
) | ||
assert response.status_code == HTTP_204_NO_CONTENT | ||
assert "x-custom-header" in response.headers["access-control-allow-headers"] | ||
assert "content-type" in response.headers["access-control-allow-headers"] | ||
|
||
|
||
def test_cors_with_unauthorized_headers() -> None: | ||
with TestClient(app) as client: | ||
response = client.options( | ||
"/endpoint", | ||
headers={ | ||
"Origin": "https://allowed-origin.com", | ||
"Access-Control-Request-Method": "GET", | ||
"Access-Control-Request-Headers": "X-Not-Allowed-Header", | ||
}, | ||
) | ||
assert response.status_code == HTTP_400_BAD_REQUEST | ||
assert ( | ||
"access-control-allow-headers" not in response.headers | ||
or "x-not-allowed-header" not in response.headers.get("access-control-allow-headers", "") | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from http import HTTPStatus | ||
|
||
from litestar import Litestar, get | ||
from litestar.config.cors import CORSConfig | ||
from litestar.testing import TestClient | ||
|
||
|
||
@get("/method-test") | ||
async def method_handler() -> str: | ||
return "Method Test Successful!" | ||
|
||
|
||
cors_config = CORSConfig(allow_methods=["GET", "POST"], allow_origins=["https://allowed-origin.com"]) | ||
app = Litestar(route_handlers=[method_handler], cors_config=cors_config) | ||
|
||
|
||
def test_cors_allowed_methods() -> None: | ||
with TestClient(app) as client: | ||
response = client.options( | ||
"/method-test", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "GET"} | ||
) | ||
assert response.status_code == HTTPStatus.NO_CONTENT | ||
assert response.headers["access-control-allow-origin"] == "https://allowed-origin.com" | ||
assert "GET" in response.headers["access-control-allow-methods"] | ||
|
||
response = client.options( | ||
"/method-test", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "POST"} | ||
) | ||
assert response.status_code == HTTPStatus.NO_CONTENT | ||
assert "POST" in response.headers["access-control-allow-methods"] | ||
|
||
|
||
def test_cors_disallowed_methods() -> None: | ||
with TestClient(app) as client: | ||
response = client.options( | ||
"/method-test", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "PUT"} | ||
) | ||
assert response.status_code == HTTPStatus.BAD_REQUEST | ||
assert "PUT" not in response.headers.get("access-control-allow-methods", "") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from litestar import Litestar, get | ||
from litestar.config.cors import CORSConfig | ||
from litestar.status_codes import HTTP_204_NO_CONTENT | ||
from litestar.testing import TestClient | ||
|
||
|
||
@get("/credentials-test") | ||
async def credentials_handler() -> str: | ||
return "Test Successful!" | ||
|
||
|
||
def test_cors_with_credentials_allowed() -> None: | ||
cors_config = CORSConfig( | ||
allow_methods=["GET"], allow_origins=["https://allowed-origin.com"], allow_credentials=True | ||
) | ||
app = Litestar(route_handlers=[credentials_handler], cors_config=cors_config) | ||
|
||
with TestClient(app) as client: | ||
response = client.options( | ||
"/endpoint", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "GET"} | ||
) | ||
assert response.status_code == HTTP_204_NO_CONTENT | ||
assert response.headers["access-control-allow-credentials"] == "true" | ||
|
||
|
||
def test_cors_with_credentials_disallowed() -> None: | ||
cors_config = CORSConfig( | ||
allow_methods=["GET"], | ||
allow_origins=["https://allowed-origin.com"], | ||
allow_credentials=False, # Credentials should not be allowed | ||
) | ||
app = Litestar(route_handlers=[credentials_handler], cors_config=cors_config) | ||
|
||
with TestClient(app) as client: | ||
response = client.options( | ||
"/endpoint", headers={"Origin": "https://allowed-origin.com", "Access-Control-Request-Method": "GET"} | ||
) | ||
assert response.status_code == HTTP_204_NO_CONTENT | ||
assert "access-control-allow-credentials" not in response.headers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from http import HTTPStatus | ||
|
||
from litestar import Litestar, get | ||
from litestar.config.cors import CORSConfig | ||
from litestar.exceptions import HTTPException | ||
from litestar.middleware import AbstractMiddleware | ||
from litestar.testing import TestClient | ||
from litestar.types.asgi_types import Receive, Scope, Send | ||
|
||
|
||
class ExceptionMiddleware(AbstractMiddleware): | ||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Intentional Error") | ||
|
||
|
||
@get("/test") | ||
async def handler() -> str: | ||
return "Should not reach this" | ||
|
||
|
||
cors_config = CORSConfig(allow_methods=["GET"], allow_origins=["https://allowed-origin.com"], allow_credentials=True) | ||
app = Litestar(route_handlers=[handler], cors_config=cors_config, middleware=[ExceptionMiddleware]) | ||
|
||
|
||
def test_cors_on_middleware_exception() -> None: | ||
with TestClient(app) as client: | ||
response = client.get("/test", headers={"Origin": "https://allowed-origin.com"}) | ||
assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR | ||
assert response.headers["access-control-allow-origin"] == "https://allowed-origin.com" | ||
assert response.headers["access-control-allow-credentials"] == "true" |
Oops, something went wrong.