Skip to content

Commit

Permalink
Config/baseurl (#436)
Browse files Browse the repository at this point in the history
  • Loading branch information
willydouhard authored Sep 29, 2023
1 parent 8e6e46f commit c4af86e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
8 changes: 7 additions & 1 deletion backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@
# Specify a CSS file that can be used to customize the user interface.
# The CSS file can be served from the public directory or via an external link.
# custom_css = '/public/test.css'
# custom_css = "/public/test.css"
# If the app is served behind a reverse proxy (like cloud run) we need to know the base url for oauth
# base_url = "https://mydomain.com"
# Override default MUI light theme. (Check theme.ts)
[UI.theme.light]
Expand Down Expand Up @@ -150,6 +153,9 @@ class UISettings(DataClassJsonMixin):
theme: Optional[Theme] = None
# Optional custom CSS file that allows you to customize the UI
custom_css: Optional[str] = None
# If the app is served behind a reverse proxy (like cloud run) we need to know the base url for oauth
# Example: https://mydomain.com
base_url: Optional[str] = None


@dataclass()
Expand Down
27 changes: 25 additions & 2 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from fastapi_socketio import SocketManager
from starlette.datastructures import URL
from starlette.middleware.cors import CORSMiddleware
from typing_extensions import Annotated
from watchfiles import awatch
Expand Down Expand Up @@ -206,6 +207,28 @@ def get_html_template():
return content


def get_user_facing_url(url: URL):
"""
Return the user facing URL for a given URL.
Handles deployment with proxies (like cloud run).
"""
url = url.replace(query="", fragment="")

# No config, we keep the URL as is
if not config.ui.base_url:
return url.__str__()

config_url = URL(config.ui.base_url).replace(
query="",
fragment="",
)
# Remove trailing slash from config URL
if config_url.path.endswith("/"):
config_url = config_url.replace(path=config_url.path[:-1])

return config_url.__str__() + url.path


@app.get("/auth/config")
async def auth(request: Request):
return get_configuration()
Expand Down Expand Up @@ -281,7 +304,7 @@ async def oauth_login(provider_id: str, request: Request):
params = urllib.parse.urlencode(
{
"client_id": provider.client_id,
"redirect_uri": f"{request.url}/callback",
"redirect_uri": f"{get_user_facing_url(request.url)}/callback",
"state": random,
**provider.authorize_params,
}
Expand Down Expand Up @@ -340,7 +363,7 @@ async def oauth_callback(
detail="Unauthorized",
)

url = request.url.replace(query="").__str__()
url = get_user_facing_url(request.url)
token = await provider.get_token(code, url)

(raw_user_data, default_app_user) = await provider.get_user_info(token)
Expand Down

0 comments on commit c4af86e

Please sign in to comment.