Skip to content

Commit

Permalink
Add support for individual model permissions
Browse files Browse the repository at this point in the history
  • Loading branch information
jonstrutz11 authored and zacps committed Dec 20, 2024
1 parent 6dd3ae0 commit f9fd4b2
Show file tree
Hide file tree
Showing 18 changed files with 499 additions and 119 deletions.
43 changes: 43 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,49 @@ We currently support [IDEFICS](https://huggingface.co/blog/idefics) (hosted on T
}
```

#### Group-based Model Permissions

If [logging in with OpenID](#openid-connect) via a supported provider, then user groups can be used in combination with the `allowed_groups` field for each model to show/hide models to users based on their group membership.

For all providers, see the following. Then, see additional instructions for your provider below.

1. Add `PROVIDER: "<provider-name-here>"` to your `.env.local` (you will enter the actual provider name later). Also, add `groups` to the `OPENID_CONFIG.SCOPES` field in your `.env.local` file:
```env
OPENID_CONFIG=`{
// rest of OPENID_CONFIG here
PROVIDER: "<provider-name-here>",
SCOPES: "openid profile groups",
// rest of OPENID_CONFIG here
}`
```

2. Use the `allowed_groups` parameter for each model to specify which group(s) should have access to that model. If not specified, all users will be able to access the model.

> [!WARNING]
> The first model in your `.env.local` file is considered the "default" model and should be available to all users, so we strongly recommend against setting `allowed_groups` for this model.
> Note that during development, it is common to have `APP_BASE=""` in your `.env.local` - however, due to the cookies created by using a provider, this value should not be empty (e.g. setting `APP_BASE="/"` in `.env.local` would work).
#### Provider: Microsoft Entra

In order to enable use of [Microsoft Entra Security Groups](https://learn.microsoft.com/en-us/entra/fundamentals/concept-learn-about-groups) to show/hide models, do the following:

1. Replace `<provider-name-here>` with `entra` in `.env.local`.

2. `allowed_groups` for each model in `.env.local` should be a list of Microsoft Entra **Group IDs** (not group names), e.g.:

```env
{
// rest of the model config here
"allowed_groups": ["123abcde-1234-abcd-cdef-1234567890ab", "abcde123-abcd-1234-cdef-abcdef123456"]
}
```

3. Finally, configure your app in Microsoft Entra so that the app can access user groups via the MS Graph API:
- [Add groups claim](https://learn.microsoft.com/en-gb/entra/identity-platform/optional-claims?tabs=appui#configure-groups-optional-claims) to your app
- [Enable ID Tokens](https://learn.microsoft.com/en-us/entra/identity-platform/v2-protocols-oidc#enable-id-tokens) for your app


#### Running your own models using a custom endpoint

If you want to, instead of hitting models on the Hugging Face Inference API, you can run your own models locally.
Expand Down
2 changes: 1 addition & 1 deletion src/app.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ declare global {
// interface Error {}
interface Locals {
sessionId: string;
user?: User & { logoutDisabled?: boolean };
user?: User & { logoutDisabled?: boolean; groups?: string[] };
}

interface Error {
Expand Down
57 changes: 56 additions & 1 deletion src/hooks.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import { sha256 } from "$lib/utils/sha256";
import { addWeeks } from "date-fns";
import { checkAndRunMigrations } from "$lib/migrations/migrations";
import { building } from "$app/environment";
import { logout, OIDConfig, ProviderCookieNames } from "$lib/server/auth";
import { type AccessToken, providers } from "$lib/server/providers/providers";
import { logger } from "$lib/server/logger";
import { AbortedGenerations } from "$lib/server/abortedGenerations";
import { MetricsServer } from "$lib/server/metrics";
Expand Down Expand Up @@ -229,7 +231,11 @@ export const handle: Handle = async ({ event, resolve }) => {
...(envPublic.PUBLIC_ORIGIN ? [new URL(envPublic.PUBLIC_ORIGIN).host] : []),
];

if (!validOrigins.includes(new URL(origin).host)) {
// origin is null for some reason when the POST request callback comes from an auth provider like MS entra so we skip this check (CSRF token is still validated)
if (
event.url.pathname !== `${base}/login/callback` &&
!validOrigins.includes(new URL(origin).host)
) {
return errorResponse(403, "Invalid referer for POST request");
}
}
Expand Down Expand Up @@ -278,6 +284,55 @@ export const handle: Handle = async ({ event, resolve }) => {
}
}

// Get user groups for allowed models
if (OIDConfig.PROVIDER && OIDConfig.SCOPES.includes("groups")) {
const provider = providers[OIDConfig.PROVIDER];
const session_exists = event.cookies.get(env.COOKIE_NAME) !== undefined;

let accessToken: AccessToken = JSON.parse(
event.cookies.get(ProviderCookieNames.ACCESS_TOKEN)?.toString() || "{}"
);
let providerParameters = JSON.parse(
event.cookies.get(ProviderCookieNames.PROVIDER_PARAMS)?.toString() || "{}"
);

// If user is logged in, get/refresh access token and use it to retrieve user groups
if (event.locals.user) {
// Get access token upon login with id token
if (accessToken && providerParameters.idToken) {
[accessToken, providerParameters] = await provider.getAccessToken(
event.cookies,
providerParameters
);
event.locals.user.groups = await provider.getUserGroups(accessToken, providerParameters);
}
// Refresh access token on subsequent requests
else if (accessToken.refreshToken && providerParameters.userTid) {
accessToken = await provider.refreshAccessToken(
event.cookies,
accessToken,
providerParameters
);
event.locals.user.groups = await provider.getUserGroups(accessToken, providerParameters);
}
// Logout user automatically if session exists but access token and/or provider params cookies have expired
else if (session_exists) {
event.locals.user.groups = undefined;
await logout(event.cookies, event.locals);
}
}
} else if (OIDConfig.SCOPES.includes("groups")) {
return errorResponse(
500,
"'groups' has been set in OPENID_CONFIG.SCOPES, but OPENID_CONFIG.PROVIDER is undefined in .env file"
);
} else if (OIDConfig.PROVIDER) {
return errorResponse(
500,
"OPENID_CONFIG.PROVIDER has been set, but 'groups' scope not set in OPENID_CONFIG.SCOPES in .env file"
);
}

let replaced = false;

const response = await resolve(event, {
Expand Down
2 changes: 1 addition & 1 deletion src/lib/components/AssistantSettings.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@
class="w-full rounded-lg border-2 border-gray-200 bg-gray-100 p-2"
bind:value={modelId}
>
{#each models.filter((model) => !model.unlisted) as model}
{#each models as model}
<option value={model.id}>{model.displayName}</option>
{/each}
<p class="text-xs text-red-500">{getError("modelId", form)}</p>
Expand Down
2 changes: 1 addition & 1 deletion src/lib/components/NavMenu.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
older: "Older",
} as const;
const nModels: number = $page.data.models.filter((el: Model) => !el.unlisted).length;
const nModels: number = $page.data.models.length;
</script>

<div class="sticky top-0 flex flex-none items-center justify-between px-1.5 py-3.5 max-sm:pt-0">
Expand Down
51 changes: 48 additions & 3 deletions src/lib/server/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@ import {
type TokenSet,
custom,
} from "openid-client";
import { redirect } from "@sveltejs/kit";
import { addHours, addWeeks } from "date-fns";
import { env } from "$env/dynamic/private";
import { sha256 } from "$lib/utils/sha256";
import { z } from "zod";
import { dev } from "$app/environment";
import { base } from "$app/paths";
import type { Cookies } from "@sveltejs/kit";
import { collections } from "$lib/server/database";
import JSON5 from "json5";
import { logger } from "$lib/server/logger";

export interface OIDCSettings {
redirectURI: string;
response_type?: string;
response_mode?: string | undefined;
nonce?: string | undefined;
}

export interface OIDCUserInfo {
Expand All @@ -34,6 +39,7 @@ export const OIDConfig = z
.object({
CLIENT_ID: stringWithDefault(env.OPENID_CLIENT_ID),
CLIENT_SECRET: stringWithDefault(env.OPENID_CLIENT_SECRET),
PROVIDER: stringWithDefault(env.OPENID_PROVIDER || ""),
PROVIDER_URL: stringWithDefault(env.OPENID_PROVIDER_URL),
SCOPES: stringWithDefault(env.OPENID_SCOPES),
NAME_CLAIM: stringWithDefault(env.OPENID_NAME_CLAIM).refine(
Expand All @@ -46,8 +52,15 @@ export const OIDConfig = z
})
.parse(JSON5.parse(env.OPENID_CONFIG || "{}"));

export const ProviderCookieNames = {
ACCESS_TOKEN: OIDConfig.PROVIDER !== "" ? OIDConfig.PROVIDER + "-access-token" : "",
PROVIDER_PARAMS: OIDConfig.PROVIDER !== "" ? OIDConfig.PROVIDER + "-params" : "",
};

export const requiresUser = !!OIDConfig.CLIENT_ID && !!OIDConfig.CLIENT_SECRET;

export const responseType = OIDConfig.SCOPES.includes("groups") ? "code id_token" : "code";

const sameSite = z
.enum(["lax", "none", "strict"])
.default(dev || env.ALLOW_INSECURE_COOKIES === "true" ? "lax" : "none")
Expand Down Expand Up @@ -108,7 +121,7 @@ async function getOIDCClient(settings: OIDCSettings): Promise<BaseClient> {
client_id: OIDConfig.CLIENT_ID,
client_secret: OIDConfig.CLIENT_SECRET,
redirect_uris: [settings.redirectURI],
response_types: ["code"],
response_types: ["code", "id_token"],
[custom.clock_tolerance]: OIDConfig.TOLERANCE || undefined,
id_token_signed_response_alg: OIDConfig.ID_TOKEN_SIGNED_RESPONSE_ALG || undefined,
};
Expand All @@ -131,8 +144,13 @@ export async function getOIDCAuthorizationUrl(

return client.authorizationUrl({
scope: OIDConfig.SCOPES,
state: csrfToken,
state: Buffer.from(JSON.stringify({ csrfToken, sessionId: params.sessionId })).toString(
"base64"
),
resource: OIDConfig.RESOURCE || undefined,
response_type: settings.response_type,
response_mode: settings.response_mode,
nonce: settings.nonce,
});
}

Expand All @@ -142,7 +160,11 @@ export async function getOIDCUserData(
iss?: string
): Promise<OIDCUserInfo> {
const client = await getOIDCClient(settings);
const token = await client.callback(settings.redirectURI, { code, iss });
const token = await client.callback(
settings.redirectURI,
{ code, iss },
{ nonce: settings.nonce }
);
const userData = await client.userinfo(token);

return { token, userData };
Expand Down Expand Up @@ -175,3 +197,26 @@ export async function validateAndParseCsrfToken(
}
return null;
}

export async function logout(cookies: Cookies, locals: App.Locals) {
await collections.sessions.deleteOne({ sessionId: locals.sessionId });

const cookie_names = [env.COOKIE_NAME];
if (ProviderCookieNames.ACCESS_TOKEN) {
cookie_names.push(ProviderCookieNames.ACCESS_TOKEN);
}
if (ProviderCookieNames.PROVIDER_PARAMS) {
cookie_names.push(ProviderCookieNames.PROVIDER_PARAMS);
}

for (const cookie_name of cookie_names) {
cookies.delete(cookie_name, {
path: env.APP_BASE,
// So that it works inside the space's iframe
sameSite: dev || env.ALLOW_INSECURE_COOKIES === "true" ? "lax" : "none",
secure: !dev && !(env.ALLOW_INSECURE_COOKIES === "true"),
httpOnly: true,
});
}
redirect(303, `${base}/`);
}
1 change: 1 addition & 0 deletions src/lib/server/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ const modelConfig = z.object({
multimodal: z.boolean().default(false),
multimodalAcceptedMimetypes: z.array(z.string()).optional(),
tools: z.boolean().default(false),
allowed_groups: z.array(z.string()).optional(),
unlisted: z.boolean().default(false),
embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(),
/** Used to enable/disable system prompt usage */
Expand Down
Loading

0 comments on commit f9fd4b2

Please sign in to comment.