Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Assistants] Use textToImage task for avatar generation #662

Merged
merged 11 commits into from
Jan 15, 2024
4 changes: 3 additions & 1 deletion .env
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,6 @@ LLM_SUMMERIZATION=true
# PUBLIC_APP_DATA_SHARING=1
# PUBLIC_APP_DISCLAIMER=1

ENABLE_ASSISTANTS=false #set to true to enable assistants feature
ENABLE_ASSISTANTS=false #set to true to enable assistants feature
ASSISTANTS_GENERATE_AVATAR=true #requires an hf token, uses the model description and name to generate an avatar using a text to image model
TEXT_TO_IMAGE_MODEL="runwayml/stable-diffusion-v1-5"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: but we can use newer models like https://huggingface.co/latent-consistency/lcm-lora-ssd-1b (assuming that they are hosted on inference API) that produce higher quality images faster

  1. It is higher quality because, it is based on SD XL
  2. It produces images faster, because it is smaller model (ssd) & uses fewer sampling (lcm-lora)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah that's fair! I picked sd1.5 over sdxl because I assumed it'd be faster but if we have faster models available by all mean let's change it 😁

I think we should prio speed>image quality for this feature, most of the time this will be seen as only a small thumbnail, so if you have good model recommendations, feel free!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for example, https://huggingface.co/latent-consistency/lcm-lora-sdv1-5 would be faster than runway/sd-1.5. https://huggingface.co/latent-consistency/lcm-lora-sdxl might still be faster than runway/sd-1.5. Maybe @patil-suraj @sayakpaul can confirm?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's SDXL, it will not be SD speed. You can try out Segmind's SSD-1B.

Copy link
Collaborator

@mishig25 mishig25 Jan 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So https://huggingface.co/latent-consistency/lcm-lora-ssd-1b would be the best, in terms of speed/quality tradeoff?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say so. But, would also consider playing SD Turbo.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the license of SD Turbo ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried latent-consistency/lcm-lora-ssd-1bz but it doesn't seem to load, I think it would need to be pinned in the API on our side if we go for it, just a note not necessarily a blocker 😁

33 changes: 30 additions & 3 deletions src/lib/components/AssistantSettings.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import type { Assistant } from "$lib/types/Assistant";

import { onMount } from "svelte";
import { enhance } from "$app/forms";
import { applyAction, enhance } from "$app/forms";
import { base } from "$app/paths";
import CarbonPen from "~icons/carbon/pen";
import { useSettingsStore } from "$lib/stores/settings";
import { page } from "$app/stores";
import IconLoading from "./icons/IconLoading.svelte";

type ActionData = {
error: boolean;
Expand Down Expand Up @@ -49,13 +51,16 @@
function getError(field: string, returnForm: ActionData) {
return returnForm?.errors.find((error) => error.field === field)?.message ?? "";
}

let loading = false;
</script>

<form
method="POST"
class="h-full w-full overflow-x-clip"
enctype="multipart/form-data"
use:enhance={async ({ formData }) => {
loading = true;
const avatar = formData.get("avatar");

if (avatar && typeof avatar !== "string" && avatar.size > 0 && compress) {
Expand All @@ -67,6 +72,11 @@
formData.set("avatar", resizedImage);
});
}

return async ({ result }) => {
loading = false;
await applyAction(result);
};
}}
>
{#if assistant}
Expand Down Expand Up @@ -125,6 +135,12 @@
<span class="text-xs text-gray-500 hover:underline">Click to upload</span>
{/if}
<p class="text-xs text-red-500">{getError("avatar", form)}</p>
{#if (!files || !files[0]) && $page.data.avatarGeneration && !assistant?.avatar}
nsarrazin marked this conversation as resolved.
Show resolved Hide resolved
<label class="text-xs text-gray-500">
<input type="checkbox" name="generateAvatar" class="text-xs text-gray-500" />
Generate avatar using a text-to-image model.
nsarrazin marked this conversation as resolved.
Show resolved Hide resolved
</label>
{/if}
</label>

<label>
Expand Down Expand Up @@ -220,8 +236,19 @@
class="rounded-full bg-gray-200 px-8 py-2 font-semibold text-gray-600">Cancel</a
>

<button type="submit" class="rounded-full bg-black px-8 py-2 font-semibold text-white md:px-20"
>{assistant ? "Save" : "Create"}</button
<button
type="submit"
disabled={loading}
aria-disabled={loading}
class="rounded-full bg-black px-8 py-2 font-semibold md:px-20"
class:bg-gray-200={loading}
class:text-gray-600={loading}
class:text-white={!loading}
>
{assistant ? "Save" : "Create"}
{#if loading}
<IconLoading classNames="ml-2 h-min" />
{/if}
</button>
</div>
</form>
24 changes: 24 additions & 0 deletions src/lib/utils/generateAvatar.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { HF_TOKEN, TEXT_TO_IMAGE_MODEL } from "$env/static/private";
import { generateFromDefaultEndpoint } from "$lib/server/generateFromDefaultEndpoint";
import { HfInference } from "@huggingface/inference";

export async function generateAvatar(description?: string, name?: string): Promise<File> {
const queryPrompt = `Generate a prompt for an image-generation model for the following:
Name: ${name}
Description: ${description}
`;
const imagePrompt = await generateFromDefaultEndpoint({
nsarrazin marked this conversation as resolved.
Show resolved Hide resolved
messages: [{ from: "user", content: queryPrompt }],
preprompt:
"You are an assistant tasked with generating simple image descriptions. The user will ask you for an image, based on the name and a description of what they want, and you should reply with a short, concise, safe, descriptive sentence.",
});

const hf = new HfInference(HF_TOKEN);

const blob = await hf.textToImage({
inputs: imagePrompt,
model: TEXT_TO_IMAGE_MODEL,
});

return new File([blob], "avatar.png");
}
6 changes: 6 additions & 0 deletions src/lib/utils/timeout.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
export const timeout = <T>(prom: Promise<T>, time: number): Promise<T> => {
let timer: NodeJS.Timeout;
return Promise.race([prom, new Promise<T>((_r, rej) => (timer = setTimeout(rej, time)))]).finally(
() => clearTimeout(timer)
);
};
3 changes: 3 additions & 0 deletions src/routes/+layout.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import {
YDC_API_KEY,
USE_LOCAL_WEBSEARCH,
ENABLE_ASSISTANTS,
ASSISTANTS_GENERATE_AVATAR,
TEXT_TO_IMAGE_MODEL,
} from "$env/static/private";
import { ObjectId } from "mongodb";
import type { ConvSidebar } from "$lib/types/ConvSidebar";
Expand Down Expand Up @@ -161,6 +163,7 @@ export const load: LayoutServerLoad = async ({ locals, depends }) => {
email: locals.user.email,
},
assistant,
avatarGeneration: ASSISTANTS_GENERATE_AVATAR === "true" && TEXT_TO_IMAGE_MODEL !== "",
nsarrazin marked this conversation as resolved.
Show resolved Hide resolved
enableAssistants,
loginRequired,
loginEnabled: requiresUser,
Expand Down
30 changes: 30 additions & 0 deletions src/routes/settings/assistants/[assistantId]/edit/+page.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import { ObjectId } from "mongodb";
import { z } from "zod";
import sizeof from "image-size";
import { sha256 } from "$lib/utils/sha256";
import { ASSISTANTS_GENERATE_AVATAR, HF_TOKEN } from "$env/static/private";
import { generateAvatar } from "$lib/utils/generateAvatar";
import { timeout } from "$lib/utils/timeout";

const newAsssistantSchema = z.object({
name: z.string().min(1),
Expand All @@ -18,6 +21,10 @@ const newAsssistantSchema = z.object({
exampleInput3: z.string().optional(),
exampleInput4: z.string().optional(),
avatar: z.instanceof(File).optional(),
generateAvatar: z
.literal("on")
.optional()
.transform((el) => !!el),
});

const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise<string> => {
Expand Down Expand Up @@ -99,6 +106,29 @@ export const actions: Actions = {
}

hash = await uploadAvatar(parse.data.avatar, assistant._id);
} else if (
ASSISTANTS_GENERATE_AVATAR === "true" &&
HF_TOKEN !== "" &&
parse.data.generateAvatar
) {
try {
const avatar = await timeout(
generateAvatar(parse.data.description, parse.data.name),
30000
);

hash = await uploadAvatar(avatar, assistant._id);
} catch (err) {
return fail(400, {
error: true,
errors: [
{
field: "avatar",
message: "Avatar generation failed. Try again or disable the feature.",
},
],
});
}
}

const { acknowledged } = await collections.assistants.replaceOne(
Expand Down
30 changes: 30 additions & 0 deletions src/routes/settings/assistants/new/+page.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import { ObjectId } from "mongodb";
import { z } from "zod";
import sizeof from "image-size";
import { sha256 } from "$lib/utils/sha256";
import { ASSISTANTS_GENERATE_AVATAR, HF_TOKEN } from "$env/static/private";
import { timeout } from "$lib/utils/timeout";
import { generateAvatar } from "$lib/utils/generateAvatar";

const newAsssistantSchema = z.object({
name: z.string().min(1),
Expand All @@ -18,6 +21,10 @@ const newAsssistantSchema = z.object({
exampleInput3: z.string().optional(),
exampleInput4: z.string().optional(),
avatar: z.instanceof(File).optional(),
generateAvatar: z
mishig25 marked this conversation as resolved.
Show resolved Hide resolved
.literal("on")
.optional()
.transform((el) => !!el),
});

const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise<string> => {
Expand Down Expand Up @@ -88,6 +95,29 @@ export const actions: Actions = {
}

hash = await uploadAvatar(parse.data.avatar, newAssistantId);
} else if (
ASSISTANTS_GENERATE_AVATAR === "true" &&
HF_TOKEN !== "" &&
parse.data.generateAvatar
) {
try {
const avatar = await timeout(
generateAvatar(parse.data.description, parse.data.name),
30000
);

hash = await uploadAvatar(avatar, newAssistantId);
} catch (err) {
return fail(400, {
error: true,
errors: [
{
field: "avatar",
message: "Avatar generation failed. Try again or disable the feature.",
},
],
});
}
}

const { insertedId } = await collections.assistants.insertOne({
Expand Down