diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 983e17182b0..f8b07d15b5a 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -5,6 +5,8 @@ from datetime import timezone from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from typing import Dict +from typing import List from typing import Optional from typing import Tuple @@ -15,9 +17,11 @@ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException +from fastapi import Query from fastapi import Request from fastapi import Response from fastapi import status +from fastapi.responses import RedirectResponse from fastapi.security import OAuth2PasswordRequestForm from fastapi_users import BaseUserManager from fastapi_users import exceptions @@ -31,8 +35,19 @@ from fastapi_users.authentication import Strategy from fastapi_users.authentication.strategy.db import AccessTokenDatabase from fastapi_users.authentication.strategy.db import DatabaseStrategy +from fastapi_users.exceptions import UserAlreadyExists +from fastapi_users.jwt import decode_jwt +from fastapi_users.jwt import generate_jwt +from fastapi_users.jwt import SecretType +from fastapi_users.manager import UserManagerDependency from fastapi_users.openapi import OpenAPIResponseType +from fastapi_users.router.common import ErrorCode +from fastapi_users.router.common import ErrorModel from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase +from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback +from httpx_oauth.oauth2 import BaseOAuth2 +from httpx_oauth.oauth2 import OAuth2Token +from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import attributes from sqlalchemy.orm import Session @@ -298,7 +313,7 @@ async def oauth_callback( token = None async with get_async_session_with_tenant(tenant_id) as db_session: token = current_tenant_id.set(tenant_id) - # Print a list of tables in the current database session + verify_email_in_whitelist(account_email, tenant_id) verify_email_domain(account_email) if MULTI_TENANT: @@ -422,7 +437,6 @@ async def authenticate( email = credentials.username # Get tenant_id from mapping table - tenant_id = get_tenant_id_for_email(email) if not tenant_id: # User not found in mapping @@ -654,3 +668,186 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User def get_default_admin_user_emails_() -> list[str]: # No default seeding available for Danswer MIT return [] + + +STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state" + + +class OAuth2AuthorizeResponse(BaseModel): + authorization_url: str + + +def generate_state_token( + data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600 +) -> str: + data["aud"] = STATE_TOKEN_AUDIENCE + + return generate_jwt(data, secret, lifetime_seconds) + + +# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91 + + +def create_danswer_oauth_router( + oauth_client: BaseOAuth2, + backend: AuthenticationBackend, + state_secret: SecretType, + redirect_url: Optional[str] = None, + associate_by_email: bool = False, + is_verified_by_default: bool = False, +) -> APIRouter: + return get_oauth_router( + oauth_client, + backend, + get_user_manager, + state_secret, + redirect_url, + associate_by_email, + is_verified_by_default, + ) + + +def get_oauth_router( + oauth_client: BaseOAuth2, + backend: AuthenticationBackend, + get_user_manager: UserManagerDependency[models.UP, models.ID], + state_secret: SecretType, + redirect_url: Optional[str] = None, + associate_by_email: bool = False, + is_verified_by_default: bool = False, +) -> APIRouter: + """Generate a router with the OAuth routes.""" + router = APIRouter() + callback_route_name = f"oauth:{oauth_client.name}.{backend.name}.callback" + + if redirect_url is not None: + oauth2_authorize_callback = OAuth2AuthorizeCallback( + oauth_client, + redirect_url=redirect_url, + ) + else: + oauth2_authorize_callback = OAuth2AuthorizeCallback( + oauth_client, + route_name=callback_route_name, + ) + + @router.get( + "/authorize", + name=f"oauth:{oauth_client.name}.{backend.name}.authorize", + response_model=OAuth2AuthorizeResponse, + ) + async def authorize( + request: Request, scopes: List[str] = Query(None) + ) -> OAuth2AuthorizeResponse: + if redirect_url is not None: + authorize_redirect_url = redirect_url + else: + authorize_redirect_url = str(request.url_for(callback_route_name)) + + next_url = request.query_params.get("next", "/") + state_data: Dict[str, str] = {"next_url": next_url} + state = generate_state_token(state_data, state_secret) + authorization_url = await oauth_client.get_authorization_url( + authorize_redirect_url, + state, + scopes, + ) + + return OAuth2AuthorizeResponse(authorization_url=authorization_url) + + @router.get( + "/callback", + name=callback_route_name, + description="The response varies based on the authentication backend used.", + responses={ + status.HTTP_400_BAD_REQUEST: { + "model": ErrorModel, + "content": { + "application/json": { + "examples": { + "INVALID_STATE_TOKEN": { + "summary": "Invalid state token.", + "value": None, + }, + ErrorCode.LOGIN_BAD_CREDENTIALS: { + "summary": "User is inactive.", + "value": {"detail": ErrorCode.LOGIN_BAD_CREDENTIALS}, + }, + } + } + }, + }, + }, + ) + async def callback( + request: Request, + access_token_state: Tuple[OAuth2Token, str] = Depends( + oauth2_authorize_callback + ), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), + strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), + ) -> RedirectResponse: + token, state = access_token_state + account_id, account_email = await oauth_client.get_id_email( + token["access_token"] + ) + + if account_email is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL, + ) + + try: + state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE]) + except jwt.DecodeError: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + + next_url = state_data.get("next_url", "/") + + # Authenticate user + try: + user = await user_manager.oauth_callback( + oauth_client.name, + token["access_token"], + account_id, + account_email, + token.get("expires_at"), + token.get("refresh_token"), + request, + associate_by_email=associate_by_email, + is_verified_by_default=is_verified_by_default, + ) + except UserAlreadyExists: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS, + ) + + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.LOGIN_BAD_CREDENTIALS, + ) + + # Login user + response = await backend.login(strategy, user) + await user_manager.on_after_login(user, request, response) + + # Prepare redirect response + redirect_response = RedirectResponse(next_url, status_code=302) + + # Copy headers and other attributes from 'response' to 'redirect_response' + for header_name, header_value in response.headers.items(): + redirect_response.headers[header_name] = header_value + + if hasattr(response, "body"): + redirect_response.body = response.body + if hasattr(response, "status_code"): + redirect_response.status_code = response.status_code + if hasattr(response, "media_type"): + redirect_response.media_type = response.media_type + + return redirect_response + + return router diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 151f852486c..cd0c5c195a6 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -81,7 +81,6 @@ router as token_rate_limit_settings_router, ) from danswer.setup import setup_danswer -from danswer.setup import setup_multitenant_danswer from danswer.utils.logger import setup_logger from danswer.utils.telemetry import get_or_generate_uuid from danswer.utils.telemetry import optional_telemetry @@ -176,12 +175,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: # We cache this at the beginning so there is no delay in the first telemetry get_or_generate_uuid() + # If we are multi-tenant, we need to only set up initial public tables with Session(engine) as db_session: setup_danswer(db_session) - else: - setup_multitenant_danswer() - optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__}) yield diff --git a/backend/ee/danswer/main.py b/backend/ee/danswer/main.py index 4584e06a00b..2d1793b8b27 100644 --- a/backend/ee/danswer/main.py +++ b/backend/ee/danswer/main.py @@ -2,6 +2,7 @@ from httpx_oauth.clients.openid import OpenID from danswer.auth.users import auth_backend +from danswer.auth.users import create_danswer_oauth_router from danswer.auth.users import fastapi_users from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import MULTI_TENANT @@ -61,7 +62,7 @@ def get_application() -> FastAPI: if AUTH_TYPE == AuthType.OIDC: include_router_with_global_prefix_prepended( application, - fastapi_users.get_oauth_router( + create_danswer_oauth_router( OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL), auth_backend, USER_AUTH_SECRET, diff --git a/web/src/app/auth/login/page.tsx b/web/src/app/auth/login/page.tsx index 9ec047d61e2..e47bdb420be 100644 --- a/web/src/app/auth/login/page.tsx +++ b/web/src/app/auth/login/page.tsx @@ -11,7 +11,7 @@ import { SignInButton } from "./SignInButton"; import { EmailPasswordForm } from "./EmailPasswordForm"; import { Card, Title, Text } from "@tremor/react"; import Link from "next/link"; -import { Logo } from "@/components/Logo"; + import { LoginText } from "./LoginText"; import { getSecondsUntilExpiration } from "@/lib/time"; import AuthFlowContainer from "@/components/auth/AuthFlowContainer"; @@ -37,6 +37,10 @@ const Page = async ({ console.log(`Some fetch failed for the login page - ${e}`); } + const nextUrl = Array.isArray(searchParams?.next) + ? searchParams?.next[0] + : searchParams?.next || null; + // simply take the user to the home page if Auth is disabled if (authTypeMetadata?.authType === "disabled") { return redirect("/"); @@ -59,7 +63,7 @@ const Page = async ({ let authUrl: string | null = null; if (authTypeMetadata) { try { - authUrl = await getAuthUrlSS(authTypeMetadata.authType); + authUrl = await getAuthUrlSS(authTypeMetadata.authType, nextUrl!); } catch (e) { console.log(`Some fetch failed for the login page - ${e}`); } @@ -88,6 +92,7 @@ const Page = async ({ /> )} + {authTypeMetadata?.authType === "basic" && (
diff --git a/web/src/app/auth/oauth/callback/route.ts b/web/src/app/auth/oauth/callback/route.ts index 6e8f290a65f..ca5a82743d3 100644 --- a/web/src/app/auth/oauth/callback/route.ts +++ b/web/src/app/auth/oauth/callback/route.ts @@ -8,7 +8,8 @@ export const GET = async (request: NextRequest) => { const url = new URL(buildUrl("/auth/oauth/callback")); url.search = request.nextUrl.search; - const response = await fetch(url.toString()); + // Set 'redirect' to 'manual' to prevent automatic redirection + const response = await fetch(url.toString(), { redirect: "manual" }); const setCookieHeader = response.headers.get("set-cookie"); if (response.status === 401) { @@ -21,9 +22,13 @@ export const GET = async (request: NextRequest) => { return NextResponse.redirect(new URL("/auth/error", getDomain(request))); } + // Get the redirect URL from the backend's 'Location' header, or default to '/' + const redirectUrl = response.headers.get("location") || "/"; + const redirectResponse = NextResponse.redirect( - new URL("/", getDomain(request)) + new URL(redirectUrl, getDomain(request)) ); + redirectResponse.headers.set("set-cookie", setCookieHeader); return redirectResponse; }; diff --git a/web/src/app/auth/oidc/callback/route.ts b/web/src/app/auth/oidc/callback/route.ts index 353119409b9..1bdf2b61db1 100644 --- a/web/src/app/auth/oidc/callback/route.ts +++ b/web/src/app/auth/oidc/callback/route.ts @@ -7,17 +7,27 @@ export const GET = async (request: NextRequest) => { // which adds back a redirect to the main app. const url = new URL(buildUrl("/auth/oidc/callback")); url.search = request.nextUrl.search; - - const response = await fetch(url.toString()); + // Set 'redirect' to 'manual' to prevent automatic redirection + const response = await fetch(url.toString(), { redirect: "manual" }); const setCookieHeader = response.headers.get("set-cookie"); + if (response.status === 401) { + return NextResponse.redirect( + new URL("/auth/create-account", getDomain(request)) + ); + } + if (!setCookieHeader) { return NextResponse.redirect(new URL("/auth/error", getDomain(request))); } + // Get the redirect URL from the backend's 'Location' header, or default to '/' + const redirectUrl = response.headers.get("location") || "/"; + const redirectResponse = NextResponse.redirect( - new URL("/", getDomain(request)) + new URL(redirectUrl, getDomain(request)) ); + redirectResponse.headers.set("set-cookie", setCookieHeader); return redirectResponse; }; diff --git a/web/src/app/page.tsx b/web/src/app/page.tsx index 9cc0c56c5e2..00776084a59 100644 --- a/web/src/app/page.tsx +++ b/web/src/app/page.tsx @@ -3,7 +3,6 @@ import { redirect } from "next/navigation"; export default async function Page() { const settings = await fetchSettingsSS(); - if (!settings) { redirect("/search"); } diff --git a/web/src/app/search/page.tsx b/web/src/app/search/page.tsx index 818dfe6b965..3d26053dd41 100644 --- a/web/src/app/search/page.tsx +++ b/web/src/app/search/page.tsx @@ -36,8 +36,13 @@ import WrappedSearch from "./WrappedSearch"; import { SearchProvider } from "@/components/context/SearchContext"; import { fetchLLMProvidersSS } from "@/lib/llm/fetchLLMs"; import { LLMProviderDescriptor } from "../admin/configuration/llm/interfaces"; +import { headers } from "next/headers"; -export default async function Home() { +export default async function Home({ + searchParams, +}: { + searchParams: { [key: string]: string | string[] | undefined }; +}) { // Disable caching so we always get the up to date connector / document set / persona info // importantly, this prevents users from adding a connector, going back to the main page, // and then getting hit with a "No Connectors" popup @@ -82,8 +87,17 @@ export default async function Home() { const llmProviders = (results[7] || []) as LLMProviderDescriptor[]; const authDisabled = authTypeMetadata?.authType === "disabled"; + if (!authDisabled && !user) { - return redirect("/auth/login"); + const headersList = headers(); + const fullUrl = headersList.get("x-url") || "/search"; + const searchParamsString = new URLSearchParams( + searchParams as unknown as Record + ).toString(); + const redirectUrl = searchParamsString + ? `${fullUrl}?${searchParamsString}` + : fullUrl; + return redirect(`/auth/login?next=${encodeURIComponent(redirectUrl)}`); } if (user && !user.is_verified && authTypeMetadata?.requiresVerification) { diff --git a/web/src/components/UserDropdown.tsx b/web/src/components/UserDropdown.tsx index e59e7291c52..00c3b83c1fa 100644 --- a/web/src/components/UserDropdown.tsx +++ b/web/src/components/UserDropdown.tsx @@ -3,7 +3,7 @@ import { useState, useRef, useContext, useEffect, useMemo } from "react"; import { FiLogOut } from "react-icons/fi"; import Link from "next/link"; -import { useRouter } from "next/navigation"; +import { useRouter, usePathname, useSearchParams } from "next/navigation"; import { User, UserRole } from "@/lib/types"; import { checkUserIsNoAuthUser, logout } from "@/lib/user"; import { Popover } from "./popover/Popover"; @@ -65,6 +65,8 @@ export function UserDropdown({ const [userInfoVisible, setUserInfoVisible] = useState(false); const userInfoRef = useRef(null); const router = useRouter(); + const pathname = usePathname(); + const searchParams = useSearchParams(); const combinedSettings = useContext(SettingsContext); const customNavItems: NavigationItem[] = useMemo( @@ -87,8 +89,17 @@ export function UserDropdown({ logout().then((isSuccess) => { if (!isSuccess) { alert("Failed to logout"); + return; } - router.push("/auth/login"); + + // Construct the current URL + const currentUrl = `${pathname}${searchParams.toString() ? `?${searchParams.toString()}` : ""}`; + + // Encode the current URL to use as a redirect parameter + const encodedRedirect = encodeURIComponent(currentUrl); + + // Redirect to login page with the current page as a redirect parameter + router.push(`/auth/login?next=${encodedRedirect}`); }); }; diff --git a/web/src/lib/chat/fetchChatData.ts b/web/src/lib/chat/fetchChatData.ts index 144a839cd73..c4188720689 100644 --- a/web/src/lib/chat/fetchChatData.ts +++ b/web/src/lib/chat/fetchChatData.ts @@ -20,7 +20,7 @@ import { fetchLLMProvidersSS } from "@/lib/llm/fetchLLMs"; import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces"; import { Folder } from "@/app/chat/folders/interfaces"; import { personaComparator } from "@/app/admin/assistants/lib"; -import { cookies } from "next/headers"; +import { cookies, headers } from "next/headers"; import { SIDEBAR_TOGGLED_COOKIE_NAME, DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME, @@ -29,6 +29,7 @@ import { hasCompletedWelcomeFlowSS } from "@/components/initialSetup/welcome/Wel import { fetchAssistantsSS } from "../assistants/fetchAssistantsSS"; import { NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN } from "../constants"; import { checkLLMSupportsImageInput } from "../llm/utils"; +import { redirect } from "next/navigation"; interface FetchChatDataResult { user: User | null; @@ -98,7 +99,15 @@ export async function fetchChatData(searchParams: { const authDisabled = authTypeMetadata?.authType === "disabled"; if (!authDisabled && !user) { - return { redirect: "/auth/login" }; + const headersList = headers(); + const fullUrl = headersList.get("x-url") || "/chat"; + const searchParamsString = new URLSearchParams( + searchParams as unknown as Record + ).toString(); + const redirectUrl = searchParamsString + ? `${fullUrl}?${searchParamsString}` + : fullUrl; + return redirect(`/auth/login?next=${encodeURIComponent(redirectUrl)}`); } if (user && !user.is_verified && authTypeMetadata?.requiresVerification) { diff --git a/web/src/lib/userSS.ts b/web/src/lib/userSS.ts index c1c5fc9d60e..81261cebea0 100644 --- a/web/src/lib/userSS.ts +++ b/web/src/lib/userSS.ts @@ -40,8 +40,12 @@ export const getAuthDisabledSS = async (): Promise => { return (await getAuthTypeMetadataSS()).authType === "disabled"; }; -const geOIDCAuthUrlSS = async (): Promise => { - const res = await fetch(buildUrl("/auth/oidc/authorize")); +const getOIDCAuthUrlSS = async (nextUrl: string | null): Promise => { + const res = await fetch( + buildUrl( + `/auth/oidc/authorize${nextUrl ? `?next=${encodeURIComponent(nextUrl)}` : ""}` + ) + ); if (!res.ok) { throw new Error("Failed to fetch data"); } @@ -51,7 +55,7 @@ const geOIDCAuthUrlSS = async (): Promise => { }; const getGoogleOAuthUrlSS = async (): Promise => { - const res = await fetch(buildUrl("/auth/oauth/authorize")); + const res = await fetch(buildUrl(`/auth/oauth/authorize`)); if (!res.ok) { throw new Error("Failed to fetch data"); } @@ -70,7 +74,10 @@ const getSAMLAuthUrlSS = async (): Promise => { return data.authorization_url; }; -export const getAuthUrlSS = async (authType: AuthType): Promise => { +export const getAuthUrlSS = async ( + authType: AuthType, + nextUrl: string | null +): Promise => { // Returns the auth url for the given auth type switch (authType) { case "disabled": @@ -84,7 +91,7 @@ export const getAuthUrlSS = async (authType: AuthType): Promise => { return await getSAMLAuthUrlSS(); } case "oidc": { - return await geOIDCAuthUrlSS(); + return await getOIDCAuthUrlSS(nextUrl); } } };