From 3a016229728600c4b4f79f1ceca95a9348b8243a Mon Sep 17 00:00:00 2001 From: Michael Fried Date: Tue, 9 Jan 2024 17:36:42 +0200 Subject: [PATCH] Add embedding models configurable, from both transformers.js and TEI (#646) * Add embedding models configurable, from both Xenova and TEI * fix lint and format * Fix bug in sentenceSimilarity * Batches for TEI using /info route * Fix web search disapear when finish searching * Fix lint and format * Add more options for better embedding model usage * Fixing CR issues * Fix websearch disapear in later PR * Fix lint * Fix more minor code CR * Valiadate embeddingModelName field in model config * Add embeddingModel into shared conversation * Fix lint and format * Add default embedding model, and more readme explanation * Fix minor embedding model readme detailed * Update settings.json * Update README.md Co-authored-by: Mishig * Update README.md Co-authored-by: Mishig * Apply suggestions from code review Co-authored-by: Mishig * Resolved more issues * lint * Fix more issues * Fix format * fix small typo * lint * fix default model * Rn `maxSequenceLength` -> `chunkCharLength` * format * add "authorization" example * format --------- Co-authored-by: Mishig Co-authored-by: Nathan Sarrazin Co-authored-by: Mishig Davaadorj --- .env | 12 +++ .env.template | 1 - README.md | 88 ++++++++++++++++- .../components/OpenWebSearchResults.svelte | 4 +- .../tei/embeddingEndpoints.ts | 65 ++++++++++++ .../transformersjs/embeddingEndpoints.ts | 46 +++++++++ src/lib/server/embeddingModels.ts | 99 +++++++++++++++++++ src/lib/server/models.ts | 2 + src/lib/server/sentenceSimilarity.ts | 42 ++++++++ src/lib/server/websearch/runWebSearch.ts | 18 ++-- .../server/websearch/sentenceSimilarity.ts | 52 ---------- src/lib/types/Conversation.ts | 1 + src/lib/types/EmbeddingEndpoints.ts | 41 ++++++++ src/lib/types/SharedConversation.ts | 2 + src/routes/conversation/+server.ts | 6 ++ src/routes/conversation/[id]/+page.svelte | 3 +- src/routes/conversation/[id]/share/+server.ts | 1 + src/routes/login/callback/updateUser.spec.ts | 2 + 18 files changed, 419 insertions(+), 66 deletions(-) create mode 100644 src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts create mode 100644 src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts create mode 100644 src/lib/server/embeddingModels.ts create mode 100644 src/lib/server/sentenceSimilarity.ts delete mode 100644 src/lib/server/websearch/sentenceSimilarity.ts create mode 100644 src/lib/types/EmbeddingEndpoints.ts diff --git a/.env b/.env index cf5ed56e744..01ee88b4cf2 100644 --- a/.env +++ b/.env @@ -46,6 +46,18 @@ CA_PATH=# CLIENT_KEY_PASSWORD=# REJECT_UNAUTHORIZED=true +TEXT_EMBEDDING_MODELS = `[ + { + "name": "Xenova/gte-small", + "displayName": "Xenova/gte-small", + "description": "Local embedding model running on the server.", + "chunkCharLength": 512, + "endpoints": [ + { "type": "transformersjs" } + ] + } +]` + # 'name', 'userMessageToken', 'assistantMessageToken' are required MODELS=`[ { diff --git a/.env.template b/.env.template index a7e33c0b8ff..f49b646eaed 100644 --- a/.env.template +++ b/.env.template @@ -204,7 +204,6 @@ TASK_MODEL='mistralai/Mistral-7B-Instruct-v0.2' # "stop": [""] # }}` - APP_BASE="/chat" PUBLIC_ORIGIN=https://huggingface.co PUBLIC_SHARE_PREFIX=https://hf.co/chat diff --git a/README.md b/README.md index 9c0f221252f..6811afd4fea 100644 --- a/README.md +++ b/README.md @@ -20,9 +20,10 @@ A chat interface using open source models, eg OpenAssistant or Llama. It is a Sv 1. [Setup](#setup) 2. [Launch](#launch) 3. [Web Search](#web-search) -4. [Extra parameters](#extra-parameters) -5. [Deploying to a HF Space](#deploying-to-a-hf-space) -6. [Building](#building) +4. [Text Embedding Models](#text-embedding-models) +5. [Extra parameters](#extra-parameters) +6. [Deploying to a HF Space](#deploying-to-a-hf-space) +7. [Building](#building) ## No Setup Deploy @@ -78,10 +79,50 @@ Chat UI features a powerful Web Search feature. It works by: 1. Generating an appropriate search query from the user prompt. 2. Performing web search and extracting content from webpages. -3. Creating embeddings from texts using [transformers.js](https://huggingface.co/docs/transformers.js). Specifically, using [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model. +3. Creating embeddings from texts using a text embedding model. 4. From these embeddings, find the ones that are closest to the user query using a vector similarity search. Specifically, we use `inner product` distance. 5. Get the corresponding texts to those closest embeddings and perform [Retrieval-Augmented Generation](https://huggingface.co/papers/2005.11401) (i.e. expand user prompt by adding those texts so that an LLM can use this information). +## Text Embedding Models + +By default (for backward compatibility), when `TEXT_EMBEDDING_MODELS` environment variable is not defined, [transformers.js](https://huggingface.co/docs/transformers.js) embedding models will be used for embedding tasks, specifically, [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model. + +You can customize the embedding model by setting `TEXT_EMBEDDING_MODELS` in your `.env.local` file. For example: + +```env +TEXT_EMBEDDING_MODELS = `[ + { + "name": "Xenova/gte-small", + "displayName": "Xenova/gte-small", + "description": "locally running embedding", + "chunkCharLength": 512, + "endpoints": [ + {"type": "transformersjs"} + ] + }, + { + "name": "intfloat/e5-base-v2", + "displayName": "intfloat/e5-base-v2", + "description": "hosted embedding model", + "chunkCharLength": 768, + "preQuery": "query: ", # See https://huggingface.co/intfloat/e5-base-v2#faq + "prePassage": "passage: ", # See https://huggingface.co/intfloat/e5-base-v2#faq + "endpoints": [ + { + "type": "tei", + "url": "http://127.0.0.1:8080/", + "authorization": "TOKEN_TYPE TOKEN" // optional authorization field. Example: "Basic VVNFUjpQQVNT" + } + ] + } +]` +``` + +The required fields are `name`, `chunkCharLength` and `endpoints`. +Supported text embedding backends are: [`transformers.js`](https://huggingface.co/docs/transformers.js) and [`TEI`](https://github.com/huggingface/text-embeddings-inference). `transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a different environment & accessed through an API endpoint. + +When more than one embedding models are supplied in `.env.local` file, the first will be used by default, and the others will only be used on LLM's which configured `embeddingModel` to the name of the model. + ## Extra parameters ### OpenID connect @@ -425,6 +466,45 @@ If you're using a certificate signed by a private CA, you will also need to add If you're using a self-signed certificate, e.g. for testing or development purposes, you can set the `REJECT_UNAUTHORIZED` parameter to `false` in your `.env.local`. This will disable certificate validation, and allow Chat UI to connect to your custom endpoint. +#### Specific Embedding Model + +A model can use any of the embedding models defined in `.env.local`, (currently used when web searching), +by default it will use the first embedding model, but it can be changed with the field `embeddingModel`: + +```env +TEXT_EMBEDDING_MODELS = `[ + { + "name": "Xenova/gte-small", + "chunkCharLength": 512, + "endpoints": [ + {"type": "transformersjs"} + ] + }, + { + "name": "intfloat/e5-base-v2", + "chunkCharLength": 768, + "endpoints": [ + {"type": "tei", "url": "http://127.0.0.1:8080/", "authorization": "Basic VVNFUjpQQVNT"}, + {"type": "tei", "url": "http://127.0.0.1:8081/"} + ] + } +]` + +MODELS=`[ + { + "name": "Ollama Mistral", + "chatPromptTemplate": "...", + "embeddingModel": "intfloat/e5-base-v2" + "parameters": { + ... + }, + "endpoints": [ + ... + ] + } +]` +``` + ## Deploying to a HF Space Create a `DOTENV_LOCAL` secret to your HF space with the content of your .env.local, and they will be picked up automatically when you run. diff --git a/src/lib/components/OpenWebSearchResults.svelte b/src/lib/components/OpenWebSearchResults.svelte index aac5fa54141..3e8c8190410 100644 --- a/src/lib/components/OpenWebSearchResults.svelte +++ b/src/lib/components/OpenWebSearchResults.svelte @@ -30,8 +30,8 @@ {:else} {/if} - Web search + + Web search
diff --git a/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts new file mode 100644 index 00000000000..17bdc34ae64 --- /dev/null +++ b/src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts @@ -0,0 +1,65 @@ +import { z } from "zod"; +import type { EmbeddingEndpoint, Embedding } from "$lib/types/EmbeddingEndpoints"; +import { chunk } from "$lib/utils/chunk"; + +export const embeddingEndpointTeiParametersSchema = z.object({ + weight: z.number().int().positive().default(1), + model: z.any(), + type: z.literal("tei"), + url: z.string().url(), + authorization: z.string().optional(), +}); + +const getModelInfoByUrl = async (url: string, authorization?: string) => { + const { origin } = new URL(url); + + const response = await fetch(`${origin}/info`, { + headers: { + Accept: "application/json", + "Content-Type": "application/json", + ...(authorization ? { Authorization: authorization } : {}), + }, + }); + + const json = await response.json(); + return json; +}; + +export async function embeddingEndpointTei( + input: z.input +): Promise { + const { url, model, authorization } = embeddingEndpointTeiParametersSchema.parse(input); + + const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url); + const maxBatchSize = Math.min( + max_client_batch_size, + Math.floor(max_batch_tokens / model.chunkCharLength) + ); + + return async ({ inputs }) => { + const { origin } = new URL(url); + + const batchesInputs = chunk(inputs, maxBatchSize); + + const batchesResults = await Promise.all( + batchesInputs.map(async (batchInputs) => { + const response = await fetch(`${origin}/embed`, { + method: "POST", + headers: { + Accept: "application/json", + "Content-Type": "application/json", + ...(authorization ? { Authorization: authorization } : {}), + }, + body: JSON.stringify({ inputs: batchInputs, normalize: true, truncate: true }), + }); + + const embeddings: Embedding[] = await response.json(); + return embeddings; + }) + ); + + const flatAllEmbeddings = batchesResults.flat(); + + return flatAllEmbeddings; + }; +} diff --git a/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts new file mode 100644 index 00000000000..7cedddcfe15 --- /dev/null +++ b/src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts @@ -0,0 +1,46 @@ +import { z } from "zod"; +import type { EmbeddingEndpoint } from "$lib/types/EmbeddingEndpoints"; +import type { Tensor, Pipeline } from "@xenova/transformers"; +import { pipeline } from "@xenova/transformers"; + +export const embeddingEndpointTransformersJSParametersSchema = z.object({ + weight: z.number().int().positive().default(1), + model: z.any(), + type: z.literal("transformersjs"), +}); + +// Use the Singleton pattern to enable lazy construction of the pipeline. +class TransformersJSModelsSingleton { + static instances: Array<[string, Promise]> = []; + + static async getInstance(modelName: string): Promise { + const modelPipelineInstance = this.instances.find(([name]) => name === modelName); + + if (modelPipelineInstance) { + const [, modelPipeline] = modelPipelineInstance; + return modelPipeline; + } + + const newModelPipeline = pipeline("feature-extraction", modelName); + this.instances.push([modelName, newModelPipeline]); + + return newModelPipeline; + } +} + +export async function calculateEmbedding(modelName: string, inputs: string[]) { + const extractor = await TransformersJSModelsSingleton.getInstance(modelName); + const output: Tensor = await extractor(inputs, { pooling: "mean", normalize: true }); + + return output.tolist(); +} + +export function embeddingEndpointTransformersJS( + input: z.input +): EmbeddingEndpoint { + const { model } = embeddingEndpointTransformersJSParametersSchema.parse(input); + + return async ({ inputs }) => { + return calculateEmbedding(model.name, inputs); + }; +} diff --git a/src/lib/server/embeddingModels.ts b/src/lib/server/embeddingModels.ts new file mode 100644 index 00000000000..13305867d95 --- /dev/null +++ b/src/lib/server/embeddingModels.ts @@ -0,0 +1,99 @@ +import { TEXT_EMBEDDING_MODELS } from "$env/static/private"; + +import { z } from "zod"; +import { sum } from "$lib/utils/sum"; +import { + embeddingEndpoints, + embeddingEndpointSchema, + type EmbeddingEndpoint, +} from "$lib/types/EmbeddingEndpoints"; +import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints"; + +const modelConfig = z.object({ + /** Used as an identifier in DB */ + id: z.string().optional(), + /** Used to link to the model page, and for inference */ + name: z.string().min(1), + displayName: z.string().min(1).optional(), + description: z.string().min(1).optional(), + websiteUrl: z.string().url().optional(), + modelUrl: z.string().url().optional(), + endpoints: z.array(embeddingEndpointSchema).nonempty(), + chunkCharLength: z.number().positive(), + preQuery: z.string().default(""), + prePassage: z.string().default(""), +}); + +// Default embedding model for backward compatibility +const rawEmbeddingModelJSON = + TEXT_EMBEDDING_MODELS || + `[ + { + "name": "Xenova/gte-small", + "chunkCharLength": 512, + "endpoints": [ + { "type": "transformersjs" } + ] + } +]`; + +const embeddingModelsRaw = z.array(modelConfig).parse(JSON.parse(rawEmbeddingModelJSON)); + +const processEmbeddingModel = async (m: z.infer) => ({ + ...m, + id: m.id || m.name, +}); + +const addEndpoint = (m: Awaited>) => ({ + ...m, + getEndpoint: async (): Promise => { + if (!m.endpoints) { + return embeddingEndpointTransformersJS({ + type: "transformersjs", + weight: 1, + model: m, + }); + } + + const totalWeight = sum(m.endpoints.map((e) => e.weight)); + + let random = Math.random() * totalWeight; + + for (const endpoint of m.endpoints) { + if (random < endpoint.weight) { + const args = { ...endpoint, model: m }; + + switch (args.type) { + case "tei": + return embeddingEndpoints.tei(args); + case "transformersjs": + return embeddingEndpoints.transformersjs(args); + } + } + + random -= endpoint.weight; + } + + throw new Error(`Failed to select embedding endpoint`); + }, +}); + +export const embeddingModels = await Promise.all( + embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint)) +); + +export const defaultEmbeddingModel = embeddingModels[0]; + +const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => { + return z.enum([_models[0][key], ..._models.slice(1).map((m) => m[key])]); +}; + +export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) => { + return validateEmbeddingModel(_models, "id"); +}; + +export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) => { + return validateEmbeddingModel(_models, "name"); +}; + +export type EmbeddingBackendModel = typeof defaultEmbeddingModel; diff --git a/src/lib/server/models.ts b/src/lib/server/models.ts index 0e6e5320b68..9bfdef8a80a 100644 --- a/src/lib/server/models.ts +++ b/src/lib/server/models.ts @@ -12,6 +12,7 @@ import { z } from "zod"; import endpoints, { endpointSchema, type Endpoint } from "./endpoints/endpoints"; import endpointTgi from "./endpoints/tgi/endpointTgi"; import { sum } from "$lib/utils/sum"; +import { embeddingModels, validateEmbeddingModelByName } from "./embeddingModels"; import JSON5 from "json5"; @@ -68,6 +69,7 @@ const modelConfig = z.object({ .optional(), multimodal: z.boolean().default(false), unlisted: z.boolean().default(false), + embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(), }); const modelsRaw = z.array(modelConfig).parse(JSON5.parse(MODELS)); diff --git a/src/lib/server/sentenceSimilarity.ts b/src/lib/server/sentenceSimilarity.ts new file mode 100644 index 00000000000..455b25d4d06 --- /dev/null +++ b/src/lib/server/sentenceSimilarity.ts @@ -0,0 +1,42 @@ +import { dot } from "@xenova/transformers"; +import type { EmbeddingBackendModel } from "$lib/server/embeddingModels"; +import type { Embedding } from "$lib/types/EmbeddingEndpoints"; + +// see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34 +function innerProduct(embeddingA: Embedding, embeddingB: Embedding) { + return 1.0 - dot(embeddingA, embeddingB); +} + +export async function findSimilarSentences( + embeddingModel: EmbeddingBackendModel, + query: string, + sentences: string[], + { topK = 5 }: { topK: number } +): Promise { + const inputs = [ + `${embeddingModel.preQuery}${query}`, + ...sentences.map((sentence) => `${embeddingModel.prePassage}${sentence}`), + ]; + + const embeddingEndpoint = await embeddingModel.getEndpoint(); + const output = await embeddingEndpoint({ inputs }); + + const queryEmbedding: Embedding = output[0]; + const sentencesEmbeddings: Embedding[] = output.slice(1, inputs.length - 1); + + const distancesFromQuery: { distance: number; index: number }[] = [...sentencesEmbeddings].map( + (sentenceEmbedding: Embedding, index: number) => { + return { + distance: innerProduct(queryEmbedding, sentenceEmbedding), + index: index, + }; + } + ); + + distancesFromQuery.sort((a, b) => { + return a.distance - b.distance; + }); + + // Return the indexes of the closest topK sentences + return distancesFromQuery.slice(0, topK).map((item) => item.index); +} diff --git a/src/lib/server/websearch/runWebSearch.ts b/src/lib/server/websearch/runWebSearch.ts index 041946bb16b..76b5106fdf4 100644 --- a/src/lib/server/websearch/runWebSearch.ts +++ b/src/lib/server/websearch/runWebSearch.ts @@ -4,13 +4,11 @@ import type { WebSearch, WebSearchSource } from "$lib/types/WebSearch"; import { generateQuery } from "$lib/server/websearch/generateQuery"; import { parseWeb } from "$lib/server/websearch/parseWeb"; import { chunk } from "$lib/utils/chunk"; -import { - MAX_SEQ_LEN as CHUNK_CAR_LEN, - findSimilarSentences, -} from "$lib/server/websearch/sentenceSimilarity"; +import { findSimilarSentences } from "$lib/server/sentenceSimilarity"; import type { Conversation } from "$lib/types/Conversation"; import type { MessageUpdate } from "$lib/types/MessageUpdate"; import { getWebSearchProvider } from "./searchWeb"; +import { defaultEmbeddingModel, embeddingModels } from "$lib/server/embeddingModels"; const MAX_N_PAGES_SCRAPE = 10 as const; const MAX_N_PAGES_EMBED = 5 as const; @@ -63,6 +61,14 @@ export async function runWebSearch( .filter(({ link }) => !DOMAIN_BLOCKLIST.some((el) => link.includes(el))) // filter out blocklist links .slice(0, MAX_N_PAGES_SCRAPE); // limit to first 10 links only + // fetch the model + const embeddingModel = + embeddingModels.find((m) => m.id === conv.embeddingModel) ?? defaultEmbeddingModel; + + if (!embeddingModel) { + throw new Error(`Embedding model ${conv.embeddingModel} not available anymore`); + } + let paragraphChunks: { source: WebSearchSource; text: string }[] = []; if (webSearch.results.length > 0) { appendUpdate("Browsing results"); @@ -78,7 +84,7 @@ export async function runWebSearch( } } const MAX_N_CHUNKS = 100; - const texts = chunk(text, CHUNK_CAR_LEN).slice(0, MAX_N_CHUNKS); + const texts = chunk(text, embeddingModel.chunkCharLength).slice(0, MAX_N_CHUNKS); return texts.map((t) => ({ source: result, text: t })); }); const nestedParagraphChunks = (await Promise.all(promises)).slice(0, MAX_N_PAGES_EMBED); @@ -93,7 +99,7 @@ export async function runWebSearch( appendUpdate("Extracting relevant information"); const topKClosestParagraphs = 8; const texts = paragraphChunks.map(({ text }) => text); - const indices = await findSimilarSentences(prompt, texts, { + const indices = await findSimilarSentences(embeddingModel, prompt, texts, { topK: topKClosestParagraphs, }); webSearch.context = indices.map((idx) => texts[idx]).join(""); diff --git a/src/lib/server/websearch/sentenceSimilarity.ts b/src/lib/server/websearch/sentenceSimilarity.ts deleted file mode 100644 index a877f8e0cd6..00000000000 --- a/src/lib/server/websearch/sentenceSimilarity.ts +++ /dev/null @@ -1,52 +0,0 @@ -import type { Tensor, Pipeline } from "@xenova/transformers"; -import { pipeline, dot } from "@xenova/transformers"; - -// see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34 -function innerProduct(tensor1: Tensor, tensor2: Tensor) { - return 1.0 - dot(tensor1.data, tensor2.data); -} - -// Use the Singleton pattern to enable lazy construction of the pipeline. -class PipelineSingleton { - static modelId = "Xenova/gte-small"; - static instance: Promise | null = null; - static async getInstance() { - if (this.instance === null) { - this.instance = pipeline("feature-extraction", this.modelId); - } - return this.instance; - } -} - -// see https://huggingface.co/thenlper/gte-small/blob/d8e2604cadbeeda029847d19759d219e0ce2e6d8/README.md?code=true#L2625 -export const MAX_SEQ_LEN = 512 as const; - -export async function findSimilarSentences( - query: string, - sentences: string[], - { topK = 5 }: { topK: number } -) { - const input = [query, ...sentences]; - - const extractor = await PipelineSingleton.getInstance(); - const output: Tensor = await extractor(input, { pooling: "mean", normalize: true }); - - const queryTensor: Tensor = output[0]; - const sentencesTensor: Tensor = output.slice([1, input.length - 1]); - - const distancesFromQuery: { distance: number; index: number }[] = [...sentencesTensor].map( - (sentenceTensor: Tensor, index: number) => { - return { - distance: innerProduct(queryTensor, sentenceTensor), - index: index, - }; - } - ); - - distancesFromQuery.sort((a, b) => { - return a.distance - b.distance; - }); - - // Return the indexes of the closest topK sentences - return distancesFromQuery.slice(0, topK).map((item) => item.index); -} diff --git a/src/lib/types/Conversation.ts b/src/lib/types/Conversation.ts index 5788ce63fd8..665a688f6b4 100644 --- a/src/lib/types/Conversation.ts +++ b/src/lib/types/Conversation.ts @@ -10,6 +10,7 @@ export interface Conversation extends Timestamps { userId?: User["_id"]; model: string; + embeddingModel: string; title: string; messages: Message[]; diff --git a/src/lib/types/EmbeddingEndpoints.ts b/src/lib/types/EmbeddingEndpoints.ts new file mode 100644 index 00000000000..57cd425c578 --- /dev/null +++ b/src/lib/types/EmbeddingEndpoints.ts @@ -0,0 +1,41 @@ +import { z } from "zod"; +import { + embeddingEndpointTei, + embeddingEndpointTeiParametersSchema, +} from "$lib/server/embeddingEndpoints/tei/embeddingEndpoints"; +import { + embeddingEndpointTransformersJS, + embeddingEndpointTransformersJSParametersSchema, +} from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints"; + +// parameters passed when generating text +interface EmbeddingEndpointParameters { + inputs: string[]; +} + +export type Embedding = number[]; + +// type signature for the endpoint +export type EmbeddingEndpoint = (params: EmbeddingEndpointParameters) => Promise; + +export const embeddingEndpointSchema = z.discriminatedUnion("type", [ + embeddingEndpointTeiParametersSchema, + embeddingEndpointTransformersJSParametersSchema, +]); + +type EmbeddingEndpointTypeOptions = z.infer["type"]; + +// generator function that takes in type discrimantor value for defining the endpoint and return the endpoint +export type EmbeddingEndpointGenerator = ( + inputs: Extract, { type: T }> +) => EmbeddingEndpoint | Promise; + +// list of all endpoint generators +export const embeddingEndpoints: { + [Key in EmbeddingEndpointTypeOptions]: EmbeddingEndpointGenerator; +} = { + tei: embeddingEndpointTei, + transformersjs: embeddingEndpointTransformersJS, +}; + +export default embeddingEndpoints; diff --git a/src/lib/types/SharedConversation.ts b/src/lib/types/SharedConversation.ts index 8571f2c3f3a..1996bcc6ff9 100644 --- a/src/lib/types/SharedConversation.ts +++ b/src/lib/types/SharedConversation.ts @@ -7,6 +7,8 @@ export interface SharedConversation extends Timestamps { hash: string; model: string; + embeddingModel: string; + title: string; messages: Message[]; preprompt?: string; diff --git a/src/routes/conversation/+server.ts b/src/routes/conversation/+server.ts index 6452e985d67..2870eddd1bc 100644 --- a/src/routes/conversation/+server.ts +++ b/src/routes/conversation/+server.ts @@ -6,6 +6,7 @@ import { base } from "$app/paths"; import { z } from "zod"; import type { Message } from "$lib/types/Message"; import { models, validateModel } from "$lib/server/models"; +import { defaultEmbeddingModel } from "$lib/server/embeddingModels"; export const POST: RequestHandler = async ({ locals, request }) => { const body = await request.text(); @@ -22,6 +23,7 @@ export const POST: RequestHandler = async ({ locals, request }) => { .parse(JSON.parse(body)); let preprompt = values.preprompt; + let embeddingModel: string; if (values.fromShare) { const conversation = await collections.sharedConversations.findOne({ @@ -35,6 +37,7 @@ export const POST: RequestHandler = async ({ locals, request }) => { title = conversation.title; messages = conversation.messages; values.model = conversation.model; + embeddingModel = conversation.embeddingModel; preprompt = conversation.preprompt; } @@ -44,6 +47,8 @@ export const POST: RequestHandler = async ({ locals, request }) => { throw error(400, "Invalid model"); } + embeddingModel ??= model.embeddingModel ?? defaultEmbeddingModel.name; + if (model.unlisted) { throw error(400, "Can't start a conversation with an unlisted model"); } @@ -59,6 +64,7 @@ export const POST: RequestHandler = async ({ locals, request }) => { preprompt: preprompt === model?.preprompt ? model?.preprompt : preprompt, createdAt: new Date(), updatedAt: new Date(), + embeddingModel: embeddingModel, ...(locals.user ? { userId: locals.user._id } : { sessionId: locals.sessionId }), ...(values.fromShare ? { meta: { fromShareId: values.fromShare } } : {}), }); diff --git a/src/routes/conversation/[id]/+page.svelte b/src/routes/conversation/[id]/+page.svelte index 363d14d6176..ba00e9757a9 100644 --- a/src/routes/conversation/[id]/+page.svelte +++ b/src/routes/conversation/[id]/+page.svelte @@ -173,6 +173,7 @@ inputs.forEach(async (el: string) => { try { const update = JSON.parse(el) as MessageUpdate; + if (update.type === "finalAnswer") { finalAnswer = update.text; reader.cancel(); @@ -225,7 +226,7 @@ }); } - // reset the websearchmessages + // reset the websearchMessages webSearchMessages = []; await invalidate(UrlDependency.ConversationList); diff --git a/src/routes/conversation/[id]/share/+server.ts b/src/routes/conversation/[id]/share/+server.ts index e3f81222180..4877de755ad 100644 --- a/src/routes/conversation/[id]/share/+server.ts +++ b/src/routes/conversation/[id]/share/+server.ts @@ -38,6 +38,7 @@ export async function POST({ params, url, locals }) { updatedAt: new Date(), title: conversation.title, model: conversation.model, + embeddingModel: conversation.embeddingModel, preprompt: conversation.preprompt, }; diff --git a/src/routes/login/callback/updateUser.spec.ts b/src/routes/login/callback/updateUser.spec.ts index 54229914571..fefaf8b0f5a 100644 --- a/src/routes/login/callback/updateUser.spec.ts +++ b/src/routes/login/callback/updateUser.spec.ts @@ -6,6 +6,7 @@ import { ObjectId } from "mongodb"; import { DEFAULT_SETTINGS } from "$lib/types/Settings"; import { defaultModel } from "$lib/server/models"; import { findUser } from "$lib/server/auth"; +import { defaultEmbeddingModel } from "$lib/server/embeddingModels"; const userData = { preferred_username: "new-username", @@ -46,6 +47,7 @@ const insertRandomConversations = async (count: number) => { title: "random title", messages: [], model: defaultModel.id, + embeddingModel: defaultEmbeddingModel.id, createdAt: new Date(), updatedAt: new Date(), sessionId: locals.sessionId,