diff --git a/README.md b/README.md index 598d1e37852..3b07d201d29 100644 --- a/README.md +++ b/README.md @@ -337,6 +337,49 @@ We currently support [IDEFICS](https://huggingface.co/blog/idefics) (hosted on T } ``` +#### Group-based Model Permissions + +If [logging in with OpenID](#openid-connect) via a supported provider, then user groups can be used in combination with the `allowed_groups` field for each model to show/hide models to users based on their group membership. + +For all providers, see the following. Then, see additional instructions for your provider below. + +1. Add `PROVIDER: ""` to your `.env.local` (you will enter the actual provider name later). Also, add `groups` to the `OPENID_CONFIG.SCOPES` field in your `.env.local` file: +```env +OPENID_CONFIG=`{ + // rest of OPENID_CONFIG here + PROVIDER: "", + SCOPES: "openid profile groups", + // rest of OPENID_CONFIG here +}` +``` + +2. Use the `allowed_groups` parameter for each model to specify which group(s) should have access to that model. If not specified, all users will be able to access the model. + +> [!WARNING] +> The first model in your `.env.local` file is considered the "default" model and should be available to all users, so we strongly recommend against setting `allowed_groups` for this model. + +> Note that during development, it is common to have `APP_BASE=""` in your `.env.local` - however, due to the cookies created by using a provider, this value should not be empty (e.g. setting `APP_BASE="/"` in `.env.local` would work). + +#### Provider: Microsoft Entra + +In order to enable use of [Microsoft Entra Security Groups](https://learn.microsoft.com/en-us/entra/fundamentals/concept-learn-about-groups) to show/hide models, do the following: + +1. Replace `` with `entra` in `.env.local`. + +2. `allowed_groups` for each model in `.env.local` should be a list of Microsoft Entra **Group IDs** (not group names), e.g.: + +```env +{ +// rest of the model config here +"allowed_groups": ["123abcde-1234-abcd-cdef-1234567890ab", "abcde123-abcd-1234-cdef-abcdef123456"] +} +``` + +3. Finally, configure your app in Microsoft Entra so that the app can access user groups via the MS Graph API: + - [Add groups claim](https://learn.microsoft.com/en-gb/entra/identity-platform/optional-claims?tabs=appui#configure-groups-optional-claims) to your app + - [Enable ID Tokens](https://learn.microsoft.com/en-us/entra/identity-platform/v2-protocols-oidc#enable-id-tokens) for your app + + #### Running your own models using a custom endpoint If you want to, instead of hitting models on the Hugging Face Inference API, you can run your own models locally. diff --git a/src/app.d.ts b/src/app.d.ts index 40a38728d8f..a4c64e90f22 100644 --- a/src/app.d.ts +++ b/src/app.d.ts @@ -10,7 +10,7 @@ declare global { // interface Error {} interface Locals { sessionId: string; - user?: User & { logoutDisabled?: boolean }; + user?: User & { logoutDisabled?: boolean; groups?: string[] }; } interface Error { diff --git a/src/hooks.server.ts b/src/hooks.server.ts index 8e86aa12c9b..61ce921c22c 100644 --- a/src/hooks.server.ts +++ b/src/hooks.server.ts @@ -9,6 +9,8 @@ import { sha256 } from "$lib/utils/sha256"; import { addWeeks } from "date-fns"; import { checkAndRunMigrations } from "$lib/migrations/migrations"; import { building } from "$app/environment"; +import { logout, OIDConfig, ProviderCookieNames } from "$lib/server/auth"; +import { type AccessToken, providers } from "$lib/server/providers/providers"; import { logger } from "$lib/server/logger"; import { AbortedGenerations } from "$lib/server/abortedGenerations"; import { MetricsServer } from "$lib/server/metrics"; @@ -229,7 +231,11 @@ export const handle: Handle = async ({ event, resolve }) => { ...(envPublic.PUBLIC_ORIGIN ? [new URL(envPublic.PUBLIC_ORIGIN).host] : []), ]; - if (!validOrigins.includes(new URL(origin).host)) { + // origin is null for some reason when the POST request callback comes from an auth provider like MS entra so we skip this check (CSRF token is still validated) + if ( + event.url.pathname !== `${base}/login/callback` && + !validOrigins.includes(new URL(origin).host) + ) { return errorResponse(403, "Invalid referer for POST request"); } } @@ -278,6 +284,55 @@ export const handle: Handle = async ({ event, resolve }) => { } } + // Get user groups for allowed models + if (OIDConfig.PROVIDER && OIDConfig.SCOPES.includes("groups")) { + const provider = providers[OIDConfig.PROVIDER]; + const session_exists = event.cookies.get(env.COOKIE_NAME) !== undefined; + + let accessToken: AccessToken = JSON.parse( + event.cookies.get(ProviderCookieNames.ACCESS_TOKEN)?.toString() || "{}" + ); + let providerParameters = JSON.parse( + event.cookies.get(ProviderCookieNames.PROVIDER_PARAMS)?.toString() || "{}" + ); + + // If user is logged in, get/refresh access token and use it to retrieve user groups + if (event.locals.user) { + // Get access token upon login with id token + if (accessToken && providerParameters.idToken) { + [accessToken, providerParameters] = await provider.getAccessToken( + event.cookies, + providerParameters + ); + event.locals.user.groups = await provider.getUserGroups(accessToken, providerParameters); + } + // Refresh access token on subsequent requests + else if (accessToken.refreshToken && providerParameters.userTid) { + accessToken = await provider.refreshAccessToken( + event.cookies, + accessToken, + providerParameters + ); + event.locals.user.groups = await provider.getUserGroups(accessToken, providerParameters); + } + // Logout user automatically if session exists but access token and/or provider params cookies have expired + else if (session_exists) { + event.locals.user.groups = undefined; + await logout(event.cookies, event.locals); + } + } + } else if (OIDConfig.SCOPES.includes("groups")) { + return errorResponse( + 500, + "'groups' has been set in OPENID_CONFIG.SCOPES, but OPENID_CONFIG.PROVIDER is undefined in .env file" + ); + } else if (OIDConfig.PROVIDER) { + return errorResponse( + 500, + "OPENID_CONFIG.PROVIDER has been set, but 'groups' scope not set in OPENID_CONFIG.SCOPES in .env file" + ); + } + let replaced = false; const response = await resolve(event, { diff --git a/src/lib/components/AssistantSettings.svelte b/src/lib/components/AssistantSettings.svelte index 6c1d2728550..ed32f1630f3 100644 --- a/src/lib/components/AssistantSettings.svelte +++ b/src/lib/components/AssistantSettings.svelte @@ -265,7 +265,7 @@ class="w-full rounded-lg border-2 border-gray-200 bg-gray-100 p-2" bind:value={modelId} > - {#each models.filter((model) => !model.unlisted) as model} + {#each models as model} {/each}

{getError("modelId", form)}

diff --git a/src/lib/components/NavMenu.svelte b/src/lib/components/NavMenu.svelte index 497e7e2115c..08d6a910e22 100644 --- a/src/lib/components/NavMenu.svelte +++ b/src/lib/components/NavMenu.svelte @@ -43,7 +43,7 @@ older: "Older", } as const; - const nModels: number = $page.data.models.filter((el: Model) => !el.unlisted).length; + const nModels: number = $page.data.models.length;
diff --git a/src/lib/server/auth.ts b/src/lib/server/auth.ts index ae170a8bc3f..09e10c57cae 100644 --- a/src/lib/server/auth.ts +++ b/src/lib/server/auth.ts @@ -5,11 +5,13 @@ import { type TokenSet, custom, } from "openid-client"; +import { redirect } from "@sveltejs/kit"; import { addHours, addWeeks } from "date-fns"; import { env } from "$env/dynamic/private"; import { sha256 } from "$lib/utils/sha256"; import { z } from "zod"; import { dev } from "$app/environment"; +import { base } from "$app/paths"; import type { Cookies } from "@sveltejs/kit"; import { collections } from "$lib/server/database"; import JSON5 from "json5"; @@ -17,6 +19,9 @@ import { logger } from "$lib/server/logger"; export interface OIDCSettings { redirectURI: string; + response_type?: string; + response_mode?: string | undefined; + nonce?: string | undefined; } export interface OIDCUserInfo { @@ -34,6 +39,7 @@ export const OIDConfig = z .object({ CLIENT_ID: stringWithDefault(env.OPENID_CLIENT_ID), CLIENT_SECRET: stringWithDefault(env.OPENID_CLIENT_SECRET), + PROVIDER: stringWithDefault(env.OPENID_PROVIDER || ""), PROVIDER_URL: stringWithDefault(env.OPENID_PROVIDER_URL), SCOPES: stringWithDefault(env.OPENID_SCOPES), NAME_CLAIM: stringWithDefault(env.OPENID_NAME_CLAIM).refine( @@ -46,8 +52,15 @@ export const OIDConfig = z }) .parse(JSON5.parse(env.OPENID_CONFIG || "{}")); +export const ProviderCookieNames = { + ACCESS_TOKEN: OIDConfig.PROVIDER !== "" ? OIDConfig.PROVIDER + "-access-token" : "", + PROVIDER_PARAMS: OIDConfig.PROVIDER !== "" ? OIDConfig.PROVIDER + "-params" : "", +}; + export const requiresUser = !!OIDConfig.CLIENT_ID && !!OIDConfig.CLIENT_SECRET; +export const responseType = OIDConfig.SCOPES.includes("groups") ? "code id_token" : "code"; + const sameSite = z .enum(["lax", "none", "strict"]) .default(dev || env.ALLOW_INSECURE_COOKIES === "true" ? "lax" : "none") @@ -108,7 +121,7 @@ async function getOIDCClient(settings: OIDCSettings): Promise { client_id: OIDConfig.CLIENT_ID, client_secret: OIDConfig.CLIENT_SECRET, redirect_uris: [settings.redirectURI], - response_types: ["code"], + response_types: ["code", "id_token"], [custom.clock_tolerance]: OIDConfig.TOLERANCE || undefined, id_token_signed_response_alg: OIDConfig.ID_TOKEN_SIGNED_RESPONSE_ALG || undefined, }; @@ -131,8 +144,13 @@ export async function getOIDCAuthorizationUrl( return client.authorizationUrl({ scope: OIDConfig.SCOPES, - state: csrfToken, + state: Buffer.from(JSON.stringify({ csrfToken, sessionId: params.sessionId })).toString( + "base64" + ), resource: OIDConfig.RESOURCE || undefined, + response_type: settings.response_type, + response_mode: settings.response_mode, + nonce: settings.nonce, }); } @@ -142,7 +160,11 @@ export async function getOIDCUserData( iss?: string ): Promise { const client = await getOIDCClient(settings); - const token = await client.callback(settings.redirectURI, { code, iss }); + const token = await client.callback( + settings.redirectURI, + { code, iss }, + { nonce: settings.nonce } + ); const userData = await client.userinfo(token); return { token, userData }; @@ -175,3 +197,26 @@ export async function validateAndParseCsrfToken( } return null; } + +export async function logout(cookies: Cookies, locals: App.Locals) { + await collections.sessions.deleteOne({ sessionId: locals.sessionId }); + + const cookie_names = [env.COOKIE_NAME]; + if (ProviderCookieNames.ACCESS_TOKEN) { + cookie_names.push(ProviderCookieNames.ACCESS_TOKEN); + } + if (ProviderCookieNames.PROVIDER_PARAMS) { + cookie_names.push(ProviderCookieNames.PROVIDER_PARAMS); + } + + for (const cookie_name of cookie_names) { + cookies.delete(cookie_name, { + path: env.APP_BASE, + // So that it works inside the space's iframe + sameSite: dev || env.ALLOW_INSECURE_COOKIES === "true" ? "lax" : "none", + secure: !dev && !(env.ALLOW_INSECURE_COOKIES === "true"), + httpOnly: true, + }); + } + redirect(303, `${base}/`); +} diff --git a/src/lib/server/models.ts b/src/lib/server/models.ts index a5c4ab49b4f..c6efb872678 100644 --- a/src/lib/server/models.ts +++ b/src/lib/server/models.ts @@ -81,6 +81,7 @@ const modelConfig = z.object({ multimodal: z.boolean().default(false), multimodalAcceptedMimetypes: z.array(z.string()).optional(), tools: z.boolean().default(false), + allowed_groups: z.array(z.string()).optional(), unlisted: z.boolean().default(false), embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(), /** Used to enable/disable system prompt usage */ diff --git a/src/lib/server/providers/microsoft_entra/providerEntra.ts b/src/lib/server/providers/microsoft_entra/providerEntra.ts new file mode 100644 index 00000000000..d7c6ac69519 --- /dev/null +++ b/src/lib/server/providers/microsoft_entra/providerEntra.ts @@ -0,0 +1,143 @@ +import { OIDConfig, ProviderCookieNames } from "$lib/server/auth"; +import type { Cookies } from "@sveltejs/kit"; +import { env } from "$env/dynamic/private"; +import { dev } from "$app/environment"; +import { addDays } from "date-fns"; +import type { AccessToken, Provider, ProviderParameters } from "$lib/server/providers/providers"; + +interface EntraProviderParameters extends ProviderParameters { + userTid: string; + userOid: string; +} + +export const providerEntra: Provider = { + getAccessToken, + refreshAccessToken: refreshMicrosoftGraphToken, + getUserGroups: getUserEntraGroups, +}; + +async function getAccessToken( + cookies: Cookies, + providerParameters: EntraProviderParameters +): Promise<[AccessToken, EntraProviderParameters]> { + if (!providerParameters.idToken) { + throw new Error("ID Token is initially required to get an access token."); + } + + const urlSearchParams = new URLSearchParams({ + client_id: OIDConfig.CLIENT_ID, + client_secret: OIDConfig.CLIENT_SECRET, + grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer", + assertion: providerParameters.idToken, + scope: "openid profile email offline_access", + requested_token_use: "on_behalf_of", + }); + const response = await fetch( + `https://login.microsoft.com/${providerParameters.userTid}/oauth2/v2.0/token`, + { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: urlSearchParams, + } + ); + const data = await response.json(); + + const accessToken: AccessToken = { + value: data.access_token, + refreshToken: data.refresh_token, + }; + + // Set idToken to undefined after we use it the first time + // This forces the app to use the refreshToken for subsequent requests + // While the idToken could be used for more requests (until it expires), + // it's simpler to just use the refreshToken for all requests after the first + const newProviderParameters = { + ...providerParameters, + idToken: undefined, + }; + + cookies.set(ProviderCookieNames.ACCESS_TOKEN, JSON.stringify(accessToken), { + path: env.APP_BASE, + // So that it works inside the space's iframe + sameSite: dev || env.ALLOW_INSECURE_COOKIES === "true" ? "lax" : "none", + secure: !dev && !(env.ALLOW_INSECURE_COOKIES === "true"), + httpOnly: true, + expires: addDays(new Date(), 1), + }); + + cookies.set(ProviderCookieNames.PROVIDER_PARAMS, JSON.stringify(newProviderParameters), { + path: env.APP_BASE, + // So that it works inside the space's iframe + sameSite: dev || env.ALLOW_INSECURE_COOKIES === "true" ? "lax" : "none", + secure: !dev && !(env.ALLOW_INSECURE_COOKIES === "true"), + httpOnly: true, + expires: addDays(new Date(), 1), + }); + + return [accessToken, newProviderParameters]; +} + +async function refreshMicrosoftGraphToken( + cookies: Cookies, + accessToken: AccessToken, + providerParameters: EntraProviderParameters +): Promise { + const urlSearchParams = new URLSearchParams({ + client_id: OIDConfig.CLIENT_ID, + client_secret: OIDConfig.CLIENT_SECRET, + grant_type: "refresh_token", + refresh_token: accessToken.refreshToken, + scope: "openid profile email offline_access", // offline_access required to get refresh_token + }); + const response = await fetch( + `https://login.microsoft.com/${providerParameters.userTid}/oauth2/v2.0/token`, + { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: urlSearchParams, + } + ); + const data = await response.json(); + + const refreshedAccessToken: AccessToken = { + value: data.access_token, + refreshToken: data.refresh_token, + }; + + cookies.set(ProviderCookieNames.ACCESS_TOKEN, JSON.stringify(refreshedAccessToken), { + path: env.APP_BASE, + // So that it works inside the space's iframe + sameSite: dev || env.ALLOW_INSECURE_COOKIES === "true" ? "lax" : "none", + secure: !dev && !(env.ALLOW_INSECURE_COOKIES === "true"), + httpOnly: true, + expires: addDays(new Date(), 1), + }); + + return refreshedAccessToken; +} + +async function getUserEntraGroups( + accessToken: AccessToken, + providerParameters: EntraProviderParameters +): Promise { + // Get this user's groups via Microsoft Graph API + const response = await fetch( + `https://graph.microsoft.com/v1.0/users/${providerParameters.userOid}/getMemberGroups`, + { + method: "POST", + headers: { + Authorization: `Bearer ${accessToken.value}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + securityEnabledOnly: false, + }), + } + ); + const data = await response.json(); + return data.value; +} diff --git a/src/lib/server/providers/providers.ts b/src/lib/server/providers/providers.ts new file mode 100644 index 00000000000..327ff0ec546 --- /dev/null +++ b/src/lib/server/providers/providers.ts @@ -0,0 +1,45 @@ +import { providerEntra } from "$lib/server/providers/microsoft_entra/providerEntra"; +import type { Cookies } from "@sveltejs/kit"; +import type { Model } from "$lib/types/Model"; + +export interface ProviderParameters { + idToken?: string; +} + +export interface AccessToken { + value: string; + refreshToken: string; +} + +export type Provider = { + // getAccessToken should also set providerParameters.idToken to undefined after use and return the new providerParameters + getAccessToken: ( + cookies: Cookies, + providerParameters: ProviderParametersType + ) => Promise<[AccessToken, ProviderParametersType]>; + refreshAccessToken: ( + cookies: Cookies, + accessToken: AccessToken, + providerParameters: ProviderParametersType + ) => Promise; + getUserGroups: ( + accessToken: AccessToken, + providerParameters: ProviderParametersType + ) => Promise; +}; + +// I'd like to annotate this type but could not figure out how to do so without getting errors +// The value for each entry needs to be Provider +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export const providers: Record> = { + entra: providerEntra, +}; + +export async function getAllowedModels(models: Model[], user_groups: string[]): Promise { + return models + .filter( + (model) => + !model.allowed_groups || model.allowed_groups.some((group) => user_groups.includes(group)) + ) + .map((model) => model.id); +} diff --git a/src/lib/types/Model.ts b/src/lib/types/Model.ts index d69ffbdeb95..01ac9fc6041 100644 --- a/src/lib/types/Model.ts +++ b/src/lib/types/Model.ts @@ -18,6 +18,7 @@ export type Model = Pick< | "multimodal" | "multimodalAcceptedMimetypes" | "unlisted" + | "allowed_groups" | "tools" | "hasInferenceAPI" >; diff --git a/src/routes/+layout.server.ts b/src/routes/+layout.server.ts index 5097054dc41..5fcfdfd7ec3 100644 --- a/src/routes/+layout.server.ts +++ b/src/routes/+layout.server.ts @@ -12,6 +12,7 @@ import { toolFromConfigs } from "$lib/server/tools"; import { MetricsServer } from "$lib/server/metrics"; import type { ToolFront, ToolInputFile } from "$lib/types/Tool"; import { ReviewStatus } from "$lib/types/Review"; +import { getAllowedModels } from "$lib/server/providers/providers"; export const load: LayoutServerLoad = async ({ locals, depends }) => { depends(UrlDependency.ConversationList); @@ -43,6 +44,14 @@ export const load: LayoutServerLoad = async ({ locals, depends }) => { const enableAssistants = env.ENABLE_ASSISTANTS === "true"; + const allowedModelsDefault = models + .filter((model) => !model.allowed_groups) + .map((model) => model.id); + const allowedModels = + locals.user && locals.user.groups + ? await getAllowedModels(models, locals.user.groups) + : allowedModelsDefault; + const assistantActive = !models.map(({ id }) => id).includes(settings?.activeModel ?? ""); const assistant = assistantActive @@ -209,26 +218,28 @@ export const load: LayoutServerLoad = async ({ locals, depends }) => { disableStream: settings?.disableStream ?? DEFAULT_SETTINGS.disableStream, directPaste: settings?.directPaste ?? DEFAULT_SETTINGS.directPaste, }, - models: models.map((model) => ({ - id: model.id, - name: model.name, - websiteUrl: model.websiteUrl, - modelUrl: model.modelUrl, - tokenizer: model.tokenizer, - datasetName: model.datasetName, - datasetUrl: model.datasetUrl, - displayName: model.displayName, - description: model.description, - logoUrl: model.logoUrl, - promptExamples: model.promptExamples, - parameters: model.parameters, - preprompt: model.preprompt, - multimodal: model.multimodal, - multimodalAcceptedMimetypes: model.multimodalAcceptedMimetypes, - tools: model.tools, - unlisted: model.unlisted, - hasInferenceAPI: model.hasInferenceAPI, - })), + models: models + .filter((model) => !model.unlisted) + .filter((model) => allowedModels.includes(model.id)) + .map((model) => ({ + id: model.id, + name: model.name, + websiteUrl: model.websiteUrl, + modelUrl: model.modelUrl, + tokenizer: model.tokenizer, + datasetName: model.datasetName, + datasetUrl: model.datasetUrl, + displayName: model.displayName, + description: model.description, + logoUrl: model.logoUrl, + promptExamples: model.promptExamples, + parameters: model.parameters, + preprompt: model.preprompt, + multimodal: model.multimodal, + multimodalAcceptedMimetypes: model.multimodalAcceptedMimetypes, + tools: model.tools, + hasInferenceAPI: model.hasInferenceAPI, + })), oldModels, tools: [...toolFromConfigs, ...communityTools] .filter((tool) => !tool?.isHidden) diff --git a/src/routes/assistants/+page.svelte b/src/routes/assistants/+page.svelte index 8105cb56a39..a39a579b557 100644 --- a/src/routes/assistants/+page.svelte +++ b/src/routes/assistants/+page.svelte @@ -142,7 +142,7 @@ aria-label="Filter assistants by model" > - {#each data.models.filter((model) => !model.unlisted) as model} + {#each data.models as model} {/each} diff --git a/src/routes/login/+page.server.ts b/src/routes/login/+page.server.ts index b813adc3af9..4201a43df22 100644 --- a/src/routes/login/+page.server.ts +++ b/src/routes/login/+page.server.ts @@ -1,5 +1,5 @@ import { redirect } from "@sveltejs/kit"; -import { getOIDCAuthorizationUrl } from "$lib/server/auth"; +import { getOIDCAuthorizationUrl, responseType } from "$lib/server/auth"; import { base } from "$app/paths"; import { env } from "$env/dynamic/private"; @@ -18,7 +18,12 @@ export const actions = { } const authorizationUrl = await getOIDCAuthorizationUrl( - { redirectURI }, + { + redirectURI, + response_type: responseType, + response_mode: responseType.includes("id_token") ? "form_post" : undefined, + nonce: responseType.includes("id_token") ? locals.sessionId : undefined, + }, { sessionId: locals.sessionId } ); diff --git a/src/routes/login/callback/+page.server.ts b/src/routes/login/callback/+page.server.ts deleted file mode 100644 index f3e4b3fae0b..00000000000 --- a/src/routes/login/callback/+page.server.ts +++ /dev/null @@ -1,72 +0,0 @@ -import { redirect, error } from "@sveltejs/kit"; -import { getOIDCUserData, validateAndParseCsrfToken } from "$lib/server/auth"; -import { z } from "zod"; -import { base } from "$app/paths"; -import { updateUser } from "./updateUser"; -import { env } from "$env/dynamic/private"; -import JSON5 from "json5"; - -const allowedUserEmails = z - .array(z.string().email()) - .optional() - .default([]) - .parse(JSON5.parse(env.ALLOWED_USER_EMAILS)); - -export async function load({ url, locals, cookies, request, getClientAddress }) { - const { error: errorName, error_description: errorDescription } = z - .object({ - error: z.string().optional(), - error_description: z.string().optional(), - }) - .parse(Object.fromEntries(url.searchParams.entries())); - - if (errorName) { - error(400, errorName + (errorDescription ? ": " + errorDescription : "")); - } - - const { code, state, iss } = z - .object({ - code: z.string(), - state: z.string(), - iss: z.string().optional(), - }) - .parse(Object.fromEntries(url.searchParams.entries())); - - const csrfToken = Buffer.from(state, "base64").toString("utf-8"); - - const validatedToken = await validateAndParseCsrfToken(csrfToken, locals.sessionId); - - if (!validatedToken) { - error(403, "Invalid or expired CSRF token"); - } - - const { userData } = await getOIDCUserData( - { redirectURI: validatedToken.redirectUrl }, - code, - iss - ); - - // Filter by allowed user emails - if (allowedUserEmails.length > 0) { - if (!userData.email) { - error(403, "User not allowed: email not returned"); - } - const emailVerified = userData.email_verified ?? true; - if (!emailVerified) { - error(403, "User not allowed: email not verified"); - } - if (!allowedUserEmails.includes(userData.email)) { - error(403, "User not allowed"); - } - } - - await updateUser({ - userData, - locals, - cookies, - userAgent: request.headers.get("user-agent") ?? undefined, - ip: getClientAddress(), - }); - - redirect(302, `${base}/`); -} diff --git a/src/routes/login/callback/+server.ts b/src/routes/login/callback/+server.ts new file mode 100644 index 00000000000..be59fd4f03b --- /dev/null +++ b/src/routes/login/callback/+server.ts @@ -0,0 +1,116 @@ +import { redirect, error, type RequestEvent, type RequestHandler } from "@sveltejs/kit"; +import { getOIDCUserData, ProviderCookieNames, validateAndParseCsrfToken } from "$lib/server/auth"; +import { z } from "zod"; +import { base } from "$app/paths"; +import { updateUser } from "./updateUser"; +import { env } from "$env/dynamic/private"; +import JSON5 from "json5"; + +const allowedUserEmails = z + .array(z.string().email()) + .optional() + .default([]) + .parse(JSON5.parse(env.ALLOWED_USER_EMAILS)); + +async function handleLogin(requestEvent: RequestEvent) { + const { url, locals, cookies, request, getClientAddress } = requestEvent; + + const { error: errorName, error_description: errorDescription } = z + .object({ + error: z.string().optional(), + error_description: z.string().optional(), + }) + .parse(Object.fromEntries(url.searchParams.entries())); + + if (errorName) { + error(400, errorName + (errorDescription ? ": " + errorDescription : "")); + } + + let entries: IterableIterator<[string, string | FormDataEntryValue]>; + if (request.method === "POST") { + const formData = await request.formData(); + entries = formData.entries(); + } else { + entries = url.searchParams.entries(); + } + + const { + code, + state, + iss, + id_token: idToken, + } = z + .object({ + code: z.string(), + state: z.string(), + iss: z.string().optional(), + id_token: z.string().optional(), + }) + .parse(Object.fromEntries(entries)); + + const { csrfToken: csrfTokenBase64, sessionId: loginSessionId } = JSON.parse( + Buffer.from(state, "base64").toString("utf-8") + ); + const csrfToken = Buffer.from(csrfTokenBase64, "base64").toString("utf-8"); + const validatedToken = await validateAndParseCsrfToken(csrfToken, loginSessionId); + + if (!validatedToken) { + error(403, "Invalid or expired CSRF token"); + } + + const { userData } = await getOIDCUserData( + { + redirectURI: validatedToken.redirectUrl, + nonce: idToken ? loginSessionId : undefined, + }, + code, + iss + ); + + // Filter by allowed user emails + if (allowedUserEmails.length > 0) { + if (!userData.email) { + error(403, "User not allowed: email not returned"); + } + const emailVerified = userData.email_verified ?? true; + if (!emailVerified) { + error(403, "User not allowed: email not verified"); + } + if (!allowedUserEmails.includes(userData.email)) { + error(403, "User not allowed"); + } + } + + if (idToken) { + cookies.set( + ProviderCookieNames.PROVIDER_PARAMS, + JSON.stringify({ idToken, userTid: userData.tid, userOid: userData.oid }), + { + httpOnly: true, + secure: true, + sameSite: "none", + path: env.APP_BASE, + } + ); + } + + await updateUser({ + userData, + locals, + cookies, + userAgent: request.headers.get("user-agent") ?? undefined, + ip: getClientAddress(), + }); + + redirect(302, `${base}/`); +} + +export const GET: RequestHandler = async (requestEvent) => { + await handleLogin(requestEvent); + throw redirect(302, `${base}/`); +}; + +export const POST: RequestHandler = async (requestEvent) => { + await handleLogin(requestEvent); + throw redirect(302, `${base}/`); +}; diff --git a/src/routes/logout/+page.server.ts b/src/routes/logout/+page.server.ts index 935846a5da6..f3c08f8401c 100644 --- a/src/routes/logout/+page.server.ts +++ b/src/routes/logout/+page.server.ts @@ -1,20 +1,7 @@ -import { dev } from "$app/environment"; -import { base } from "$app/paths"; -import { env } from "$env/dynamic/private"; -import { collections } from "$lib/server/database"; -import { redirect } from "@sveltejs/kit"; +import { logout } from "$lib/server/auth"; export const actions = { async default({ cookies, locals }) { - await collections.sessions.deleteOne({ sessionId: locals.sessionId }); - - cookies.delete(env.COOKIE_NAME, { - path: "/", - // So that it works inside the space's iframe - sameSite: dev || env.ALLOW_INSECURE_COOKIES === "true" ? "lax" : "none", - secure: !dev && !(env.ALLOW_INSECURE_COOKIES === "true"), - httpOnly: true, - }); - redirect(303, `${base}/`); + await logout(cookies, locals); }, }; diff --git a/src/routes/models/+page.svelte b/src/routes/models/+page.svelte index 0aa77a5d820..d08facf8f33 100644 --- a/src/routes/models/+page.svelte +++ b/src/routes/models/+page.svelte @@ -43,7 +43,7 @@

All models available on {envPublic.PUBLIC_APP_NAME}

- {#each data.models.filter((el) => !el.unlisted) as model, index (model.id)} + {#each data.models as model, index (model.id)}

Models

- {#each data.models.filter((el) => !el.unlisted) as model} + {#each data.models as model}