Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for csrf tokens in html forms #18

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,48 @@ app = FastAPI()
app.add_middleware(CSRFMiddleware, secret="__CHANGE_ME__")
```

## Usage with FastAPI and HTML forms

Add the starlette_csrf middleware and utilize the following template processor in your FastAPI code:

```py
import typing
from fastapi.templating import Jinja2Templates
from fastapi import Request
from app.core.config import settings

def csrf_token_processor(request: Request) -> typing.Dict[str, typing.Any]:
csrf_token = request.cookies.get(settings.CSRF_COOKIE_NAME)
csrf_input = f'<input type="hidden" name="X-CSRF-Token" value="{csrf_token}">'
csrf_header = {settings.CSRF_HEADER_NAME: csrf_token}
return {
'csrf_token': csrf_token,
'csrf_input': csrf_input,
'csrf_header': csrf_header
}

templates = Jinja2Templates(directory="templates", context_processors=[csrf_token_processor])
```

Simply using {{ csrf_input | safe }} in each form is now sufficient to ensure a more secure web application. For example:

```html
<form method="post">
{{ csrf_input | safe }}
<!-- Other form fields here -->
<button type="submit">Submit</button>
</form>
```

Furthermore, we can use {{ csrf_header }} in HTMX requests. For example:

```html
<form hx-patch="/route/edit" hx-headers='{{ csrf_header | tojson | safe }}' hx-trigger="submit" hx-target="#yourtarget" hx-swap="outerHTML" >
<!-- Other form fields here -->
<button type="submit">Submit</button>
</form>
```

## Arguments

* `secret` (`str`): Secret to sign the CSRF token value. **Be sure to choose a strong passphrase and keep it SECRET**.
Expand Down
43 changes: 32 additions & 11 deletions starlette_csrf/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
import secrets
from re import Pattern
from typing import Dict, List, Optional, Set, cast

from itsdangerous import BadSignature
from itsdangerous.url_safe import URLSafeSerializer
from starlette.datastructures import URL, MutableHeaders
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send


class CSRFMiddleware:
def __init__(
self,
Expand Down Expand Up @@ -46,19 +44,21 @@ def __init__(
self.header_name = header_name

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"): # pragma: no cover
if scope["type"] not in ("http", "websocket"): # pragma: no cover
await self.app(scope, receive, send)
return

request = Request(scope)
csrf_cookie = request.cookies.get(self.cookie_name)
request = Request(scope, receive)
body = await self._get_request_body(request)
csrf_cookie = request.cookies.get(self.cookie_name)

if self._url_is_required(request.url) or (
request.method not in self.safe_methods
and not self._url_is_exempt(request.url)
and self._has_sensitive_cookies(request.cookies)
):
submitted_csrf_token = await self._get_submitted_csrf_token(request)

if (
not csrf_cookie
or not submitted_csrf_token
Expand All @@ -67,9 +67,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
response = self._get_error_response(request)
await response(scope, receive, send)
return


request._receive = self._receive_with_body(request._receive, body)
send = functools.partial(self.send, send=send, scope=scope)
await self.app(scope, receive, send)
await self.app(scope, request.receive, send)

async def send(self, message: Message, send: Send, scope: Scope) -> None:
request = Request(scope)
Expand All @@ -89,9 +90,9 @@ async def send(self, message: Message, send: Send, scope: Scope) -> None:
if self.cookie_domain is not None:
cookie[cookie_name]["domain"] = self.cookie_domain # pragma: no cover
headers.append("set-cookie", cookie.output(header="").strip())

await send(message)

def _has_sensitive_cookies(self, cookies: Dict[str, str]) -> bool:
if not self.sensitive_cookies:
return True
Expand All @@ -115,17 +116,32 @@ def _url_is_exempt(self, url: URL) -> bool:
if exempt_url.match(url.path):
return True
return False

async def _get_request_body(self, request: Request):
if request.method in ("POST", "PUT", "PATCH", "DELETE"):
return await request.body()
return b""

async def _get_submitted_csrf_token(self, request: Request) -> Optional[str]:
return request.headers.get(self.header_name)
csrf_token_header = request.headers.get(self.header_name)
if csrf_token_header:
return csrf_token_header

csrftoken_form = await self._get_csrf_token_form(request)
return csrftoken_form

async def _get_csrf_token_form(self, request: Request) -> str:
form = await request.form()
csrf_token = form.get(self.cookie_name, "")
return csrf_token

def _generate_csrf_token(self) -> str:
return cast(str, self.serializer.dumps(secrets.token_urlsafe(128)))

def _csrf_tokens_match(self, token1: str, token2: str) -> bool:
try:
decoded1: str = self.serializer.loads(token1)
decoded2: str = self.serializer.loads(token2)
decoded2: str = self.serializer.loads(token2)
return secrets.compare_digest(decoded1, decoded2)
except BadSignature:
return False
Expand All @@ -134,3 +150,8 @@ def _get_error_response(self, request: Request) -> Response:
return PlainTextResponse(
content="CSRF token verification failed", status_code=403
)

def _receive_with_body(self, receive, body):
async def inner():
return {"type": "http.request", "body": body, "more_body": False}
return inner