diff --git a/.env b/.env
index c965efba7cc..63370165f6c 100644
--- a/.env
+++ b/.env
@@ -128,4 +128,6 @@ LLM_SUMMERIZATION=true
# PUBLIC_APP_DATA_SHARING=1
# PUBLIC_APP_DISCLAIMER=1
-ENABLE_ASSISTANTS=false #set to true to enable assistants feature
\ No newline at end of file
+ENABLE_ASSISTANTS=false #set to true to enable assistants feature
+ASSISTANTS_GENERATE_AVATAR=true #requires an hf token, uses the model description and name to generate an avatar using a text to image model
+TEXT_TO_IMAGE_MODEL="runwayml/stable-diffusion-v1-5"
diff --git a/src/lib/components/AssistantSettings.svelte b/src/lib/components/AssistantSettings.svelte
index b8ec68f6b0e..a2dd7bb9804 100644
--- a/src/lib/components/AssistantSettings.svelte
+++ b/src/lib/components/AssistantSettings.svelte
@@ -4,10 +4,12 @@
import type { Assistant } from "$lib/types/Assistant";
import { onMount } from "svelte";
- import { enhance } from "$app/forms";
+ import { applyAction, enhance } from "$app/forms";
import { base } from "$app/paths";
import CarbonPen from "~icons/carbon/pen";
import { useSettingsStore } from "$lib/stores/settings";
+ import { page } from "$app/stores";
+ import IconLoading from "./icons/IconLoading.svelte";
type ActionData = {
error: boolean;
@@ -49,6 +51,10 @@
function getError(field: string, returnForm: ActionData) {
return returnForm?.errors.find((error) => error.field === field)?.message ?? "";
}
+
+ let loading = false;
+
+ let generateAvatar = false;
diff --git a/src/lib/utils/generateAvatar.ts b/src/lib/utils/generateAvatar.ts
new file mode 100644
index 00000000000..2cd1b291dcd
--- /dev/null
+++ b/src/lib/utils/generateAvatar.ts
@@ -0,0 +1,24 @@
+import { HF_TOKEN, TEXT_TO_IMAGE_MODEL } from "$env/static/private";
+import { generateFromDefaultEndpoint } from "$lib/server/generateFromDefaultEndpoint";
+import { HfInference } from "@huggingface/inference";
+
+export async function generateAvatar(description?: string, name?: string): Promise {
+ const queryPrompt = `Generate a prompt for an image-generation model for the following:
+Name: ${name}
+Description: ${description}
+`;
+ const imagePrompt = await generateFromDefaultEndpoint({
+ messages: [{ from: "user", content: queryPrompt }],
+ preprompt:
+ "You are an assistant tasked with generating simple image descriptions. The user will ask you for an image, based on the name and a description of what they want, and you should reply with a short, concise, safe, descriptive sentence.",
+ });
+
+ const hf = new HfInference(HF_TOKEN);
+
+ const blob = await hf.textToImage({
+ inputs: imagePrompt,
+ model: TEXT_TO_IMAGE_MODEL,
+ });
+
+ return new File([blob], "avatar.png");
+}
diff --git a/src/lib/utils/timeout.ts b/src/lib/utils/timeout.ts
new file mode 100644
index 00000000000..65d229f155e
--- /dev/null
+++ b/src/lib/utils/timeout.ts
@@ -0,0 +1,6 @@
+export const timeout = (prom: Promise, time: number): Promise => {
+ let timer: NodeJS.Timeout;
+ return Promise.race([prom, new Promise((_r, rej) => (timer = setTimeout(rej, time)))]).finally(
+ () => clearTimeout(timer)
+ );
+};
diff --git a/src/routes/+layout.server.ts b/src/routes/+layout.server.ts
index fcf4069d50a..36682018e9b 100644
--- a/src/routes/+layout.server.ts
+++ b/src/routes/+layout.server.ts
@@ -13,6 +13,8 @@ import {
YDC_API_KEY,
USE_LOCAL_WEBSEARCH,
ENABLE_ASSISTANTS,
+ ASSISTANTS_GENERATE_AVATAR,
+ TEXT_TO_IMAGE_MODEL,
} from "$env/static/private";
import { ObjectId } from "mongodb";
import type { ConvSidebar } from "$lib/types/ConvSidebar";
@@ -161,6 +163,7 @@ export const load: LayoutServerLoad = async ({ locals, depends }) => {
email: locals.user.email,
},
assistant,
+ avatarGeneration: ASSISTANTS_GENERATE_AVATAR === "true" && !!TEXT_TO_IMAGE_MODEL,
enableAssistants,
loginRequired,
loginEnabled: requiresUser,
diff --git a/src/routes/settings/assistants/[assistantId]/edit/+page.server.ts b/src/routes/settings/assistants/[assistantId]/edit/+page.server.ts
index c4be53763a1..ec709c1189a 100644
--- a/src/routes/settings/assistants/[assistantId]/edit/+page.server.ts
+++ b/src/routes/settings/assistants/[assistantId]/edit/+page.server.ts
@@ -7,6 +7,9 @@ import { ObjectId } from "mongodb";
import { z } from "zod";
import sizeof from "image-size";
import { sha256 } from "$lib/utils/sha256";
+import { ASSISTANTS_GENERATE_AVATAR, HF_TOKEN } from "$env/static/private";
+import { generateAvatar } from "$lib/utils/generateAvatar";
+import { timeout } from "$lib/utils/timeout";
const newAsssistantSchema = z.object({
name: z.string().min(1),
@@ -18,6 +21,10 @@ const newAsssistantSchema = z.object({
exampleInput3: z.string().optional(),
exampleInput4: z.string().optional(),
avatar: z.instanceof(File).optional(),
+ generateAvatar: z
+ .literal("on")
+ .optional()
+ .transform((el) => !!el),
});
const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise => {
@@ -99,6 +106,29 @@ export const actions: Actions = {
}
hash = await uploadAvatar(parse.data.avatar, assistant._id);
+ } else if (
+ ASSISTANTS_GENERATE_AVATAR === "true" &&
+ HF_TOKEN !== "" &&
+ parse.data.generateAvatar
+ ) {
+ try {
+ const avatar = await timeout(
+ generateAvatar(parse.data.description, parse.data.name),
+ 30000
+ );
+
+ hash = await uploadAvatar(avatar, assistant._id);
+ } catch (err) {
+ return fail(400, {
+ error: true,
+ errors: [
+ {
+ field: "avatar",
+ message: "Avatar generation failed. Try again or disable the feature.",
+ },
+ ],
+ });
+ }
}
const { acknowledged } = await collections.assistants.replaceOne(
diff --git a/src/routes/settings/assistants/new/+page.server.ts b/src/routes/settings/assistants/new/+page.server.ts
index 58e4519e6fd..f5c0ded4c73 100644
--- a/src/routes/settings/assistants/new/+page.server.ts
+++ b/src/routes/settings/assistants/new/+page.server.ts
@@ -7,6 +7,9 @@ import { ObjectId } from "mongodb";
import { z } from "zod";
import sizeof from "image-size";
import { sha256 } from "$lib/utils/sha256";
+import { ASSISTANTS_GENERATE_AVATAR, HF_TOKEN } from "$env/static/private";
+import { timeout } from "$lib/utils/timeout";
+import { generateAvatar } from "$lib/utils/generateAvatar";
const newAsssistantSchema = z.object({
name: z.string().min(1),
@@ -18,6 +21,10 @@ const newAsssistantSchema = z.object({
exampleInput3: z.string().optional(),
exampleInput4: z.string().optional(),
avatar: z.instanceof(File).optional(),
+ generateAvatar: z
+ .literal("on")
+ .optional()
+ .transform((el) => !!el),
});
const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise => {
@@ -88,6 +95,29 @@ export const actions: Actions = {
}
hash = await uploadAvatar(parse.data.avatar, newAssistantId);
+ } else if (
+ ASSISTANTS_GENERATE_AVATAR === "true" &&
+ HF_TOKEN !== "" &&
+ parse.data.generateAvatar
+ ) {
+ try {
+ const avatar = await timeout(
+ generateAvatar(parse.data.description, parse.data.name),
+ 30000
+ );
+
+ hash = await uploadAvatar(avatar, newAssistantId);
+ } catch (err) {
+ return fail(400, {
+ error: true,
+ errors: [
+ {
+ field: "avatar",
+ message: "Avatar generation failed. Try again or disable the feature.",
+ },
+ ],
+ });
+ }
}
const { insertedId } = await collections.assistants.insertOne({