Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cookie-based auth #1521

Merged
merged 19 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,13 @@ Make sure to run `poetry install` again whenever you've updated the frontend!

1. Find the folder containing the e2e test that you're looking for in `cypress/e2e`.
2. Run `SINGLE_TEST=FOLDER pnpm test` and change FOLDER with the folder from the previous step (example: `SINGLE_TEST=scoped_elements pnpm run test`).

### Headed/debugging

Causes the Electron browser to be shown on screen and keeps it open after tests are done.
Extremely useful for debugging!

```sh
SINGLE_TEST=password_auth CYPRESS_OPTIONS='--headed --no-exit' pnpm test
```

1 change: 1 addition & 0 deletions .github/workflows/lint-backend.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
- name: Lint with ruff
uses: astral-sh/ruff-action@v1
with:
version: '0.8.0'
src: ${{ env.BACKEND_DIR }}
changed-files: "true"
- name: Check formatting with ruff
Expand Down
50 changes: 18 additions & 32 deletions backend/chainlit/auth.py → backend/chainlit/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
import os
from datetime import datetime, timedelta
from typing import Any, Dict

import jwt
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer

from chainlit.config import config
from chainlit.data import get_data_layer
from chainlit.logger import logger
from chainlit.oauth_providers import get_configured_oauth_providers
from chainlit.user import User

reuseable_oauth = OAuth2PasswordBearer(tokenUrl="/login", auto_error=False)
from .cookie import OAuth2PasswordBearerWithCookie
from .jwt import create_jwt, decode_jwt, get_jwt_secret


def get_jwt_secret():
return os.environ.get("CHAINLIT_AUTH_SECRET")
reuseable_oauth = OAuth2PasswordBearerWithCookie(tokenUrl="/login", auto_error=False)


def ensure_jwt_secret():
Expand Down Expand Up @@ -43,55 +38,46 @@ def get_configuration():
"requireLogin": require_login(),
"passwordAuth": config.code.password_auth_callback is not None,
"headerAuth": config.code.header_auth_callback is not None,
"cookieAuth": config.project.cookie_auth,
"oauthProviders": (
get_configured_oauth_providers() if is_oauth_enabled() else []
),
}


def create_jwt(data: User) -> str:
to_encode: Dict[str, Any] = data.to_dict()
to_encode.update(
{
"exp": datetime.utcnow()
+ timedelta(seconds=config.project.user_session_timeout),
}
)
encoded_jwt = jwt.encode(to_encode, get_jwt_secret(), algorithm="HS256")
return encoded_jwt


async def authenticate_user(token: str = Depends(reuseable_oauth)):
try:
dict = jwt.decode(
token,
get_jwt_secret(),
algorithms=["HS256"],
options={"verify_signature": True},
)
del dict["exp"]
user = User(**dict)
user = decode_jwt(token)
except Exception as e:
raise HTTPException(
status_code=401, detail="Invalid authentication token"
) from e

if data_layer := get_data_layer():
# Get or create persistent user if we've a data layer available.
try:
persisted_user = await data_layer.get_user(user.identifier)
if persisted_user is None:
persisted_user = await data_layer.create_user(user)
except Exception:
assert persisted_user
except Exception as e:
logger.exception("Unable to get persisted_user from data layer: %s", e)
return user

if user and user.display_name:
# Copy ephemeral display_name from authenticated user to persistent user.
persisted_user.display_name = user.display_name

return persisted_user
else:
return user

return user


async def get_current_user(token: str = Depends(reuseable_oauth)):
if not require_login():
return None

return await authenticate_user(token)


__all__ = ["create_jwt", "get_configuration", "get_current_user"]
124 changes: 124 additions & 0 deletions backend/chainlit/auth/cookie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import os
from typing import Literal, Optional, cast

from fastapi import Request, Response
from fastapi.exceptions import HTTPException
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from starlette.status import HTTP_401_UNAUTHORIZED

""" Module level cookie settings. """
_cookie_samesite = cast(
Literal["lax", "strict", "none"],
os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax"),
)

assert (
_cookie_samesite
in [
"lax",
"strict",
"none",
]
), "Invalid value for CHAINLIT_COOKIE_SAMESITE. Must be one of 'lax', 'strict' or 'none'."
_cookie_secure = _cookie_samesite == "none"

_auth_cookie_lifetime = 60 * 60 # 1 hour
_state_cookie_lifetime = 3 * 60 # 3m
_auth_cookie_name = "access_token"
_state_cookie_name = "oauth_state"


class OAuth2PasswordBearerWithCookie(SecurityBase):
"""
OAuth2 password flow with cookie support with fallback to bearer token.
"""

def __init__(
self,
tokenUrl: str,
scheme_name: Optional[str] = None,
auto_error: bool = True,
):
self.tokenUrl = tokenUrl
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error

async def __call__(self, request: Request) -> Optional[str]:
# First try to get the token from the cookie
token = request.cookies.get(_auth_cookie_name)

# If no cookie, try the Authorization header as fallback
if not token:
# TODO: Only bother to check if cookie auth is explicitly disabled.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we leaving this todo?

authorization = request.headers.get("Authorization")
if authorization:
scheme, token = get_authorization_scheme_param(authorization)
if scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
else:
return None
else:
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
else:
return None

return token


def set_auth_cookie(response: Response, token: str):
"""
Helper function to set the authentication cookie with secure parameters
"""

response.set_cookie(
key=_auth_cookie_name,
value=token,
httponly=True,
secure=_cookie_secure,
samesite=_cookie_samesite,
max_age=_auth_cookie_lifetime,
path="/", # Why is path set here and not below?
)


def clear_auth_cookie(response: Response):
"""
Helper function to clear the authentication cookie
"""
response.delete_cookie(key=_auth_cookie_name, path="/")


def set_oauth_state_cookie(response: Response, token: str):
response.set_cookie(
_state_cookie_name,
token,
httponly=True,
samesite=_cookie_samesite,
secure=_cookie_secure,
max_age=_state_cookie_lifetime,
)


def validate_oauth_state_cookie(request: Request, state: str):
"""Check the state from the oauth provider against the browser cookie."""

oauth_state = request.cookies.get(_state_cookie_name)

if oauth_state != state:
raise Exception("oauth state does not correspond")


def clear_oauth_state_cookie(response: Response):
"""Oauth complete, delete state token."""
response.delete_cookie(_state_cookie_name) # Do we set path here?
37 changes: 37 additions & 0 deletions backend/chainlit/auth/jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import datetime
import os
from typing import Any, Dict, Optional

import jwt as pyjwt

from chainlit.config import config
from chainlit.user import User


def get_jwt_secret() -> Optional[str]:
return os.environ.get("CHAINLIT_AUTH_SECRET")


def create_jwt(data: User) -> str:
to_encode: Dict[str, Any] = data.to_dict()
to_encode.update(
{
"exp": datetime.datetime.utcnow()
+ datetime.timedelta(seconds=config.project.user_session_timeout),
}
)
secret = get_jwt_secret()
assert secret
encoded_jwt = pyjwt.encode(to_encode, secret, algorithm="HS256")
return encoded_jwt


def decode_jwt(token: str) -> User:
dict = pyjwt.decode(
token,
get_jwt_secret(),
algorithms=["HS256"],
options={"verify_signature": True},
)
del dict["exp"]
return User(**dict)
15 changes: 14 additions & 1 deletion backend/chainlit/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
nest_asyncio.apply()

# ruff: noqa: E402
from chainlit.auth import ensure_jwt_secret
from chainlit.cache import init_lc_cache
from chainlit.config import (
BACKEND_ROOT,
Expand All @@ -24,7 +25,18 @@
from chainlit.markdown import init_markdown
from chainlit.secret import random_secret
from chainlit.telemetry import trace_event
from chainlit.utils import check_file, ensure_jwt_secret
from chainlit.utils import check_file


def assert_app():
if (
not config.code.on_chat_start
and not config.code.on_message
and not config.code.on_audio_chunk
):
raise Exception(
"You need to configure at least one of on_chat_start, on_message or on_audio_chunk callback"
)


# Create the main command group for Chainlit CLI
Expand Down Expand Up @@ -66,6 +78,7 @@ def run_chainlit(target: str):
load_module(config.run.module_name)

ensure_jwt_secret()
assert_app()

# Create the chainlit.md file if it doesn't exist
init_markdown(config.root)
Expand Down
5 changes: 5 additions & 0 deletions backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@
# Allow users to edit their own messages
edit_message = true

# Use httponly cookie for client->server authentication, required to be able to use file upload and elements.
cookie_auth = true

# Authorize users to spontaneously upload files with messages
[features.spontaneous_file_upload]
enabled = true
Expand Down Expand Up @@ -327,6 +330,8 @@ class ProjectSettings(DataClassJsonMixin):
cache: bool = False
# Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317)
follow_symlink: bool = False
# Use httponly cookie for client->server authentication, required to be able to use file upload and elements.
cookie_auth: bool = True


@dataclass()
Expand Down
Loading
Loading