-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add embedding models configurable, from both transformers.js and TEI (#…
…646) * Add embedding models configurable, from both Xenova and TEI * fix lint and format * Fix bug in sentenceSimilarity * Batches for TEI using /info route * Fix web search disapear when finish searching * Fix lint and format * Add more options for better embedding model usage * Fixing CR issues * Fix websearch disapear in later PR * Fix lint * Fix more minor code CR * Valiadate embeddingModelName field in model config * Add embeddingModel into shared conversation * Fix lint and format * Add default embedding model, and more readme explanation * Fix minor embedding model readme detailed * Update settings.json * Update README.md Co-authored-by: Mishig <[email protected]> * Update README.md Co-authored-by: Mishig <[email protected]> * Apply suggestions from code review Co-authored-by: Mishig <[email protected]> * Resolved more issues * lint * Fix more issues * Fix format * fix small typo * lint * fix default model * Rn `maxSequenceLength` -> `chunkCharLength` * format * add "authorization" example * format --------- Co-authored-by: Mishig <[email protected]> Co-authored-by: Nathan Sarrazin <[email protected]> Co-authored-by: Mishig Davaadorj <[email protected]>
- Loading branch information
1 parent
69c0464
commit 3a01622
Showing
18 changed files
with
419 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
65 changes: 65 additions & 0 deletions
65
src/lib/server/embeddingEndpoints/tei/embeddingEndpoints.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import { z } from "zod"; | ||
import type { EmbeddingEndpoint, Embedding } from "$lib/types/EmbeddingEndpoints"; | ||
import { chunk } from "$lib/utils/chunk"; | ||
|
||
export const embeddingEndpointTeiParametersSchema = z.object({ | ||
weight: z.number().int().positive().default(1), | ||
model: z.any(), | ||
type: z.literal("tei"), | ||
url: z.string().url(), | ||
authorization: z.string().optional(), | ||
}); | ||
|
||
const getModelInfoByUrl = async (url: string, authorization?: string) => { | ||
const { origin } = new URL(url); | ||
|
||
const response = await fetch(`${origin}/info`, { | ||
headers: { | ||
Accept: "application/json", | ||
"Content-Type": "application/json", | ||
...(authorization ? { Authorization: authorization } : {}), | ||
}, | ||
}); | ||
|
||
const json = await response.json(); | ||
return json; | ||
}; | ||
|
||
export async function embeddingEndpointTei( | ||
input: z.input<typeof embeddingEndpointTeiParametersSchema> | ||
): Promise<EmbeddingEndpoint> { | ||
const { url, model, authorization } = embeddingEndpointTeiParametersSchema.parse(input); | ||
|
||
const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url); | ||
const maxBatchSize = Math.min( | ||
max_client_batch_size, | ||
Math.floor(max_batch_tokens / model.chunkCharLength) | ||
); | ||
|
||
return async ({ inputs }) => { | ||
const { origin } = new URL(url); | ||
|
||
const batchesInputs = chunk(inputs, maxBatchSize); | ||
|
||
const batchesResults = await Promise.all( | ||
batchesInputs.map(async (batchInputs) => { | ||
const response = await fetch(`${origin}/embed`, { | ||
method: "POST", | ||
headers: { | ||
Accept: "application/json", | ||
"Content-Type": "application/json", | ||
...(authorization ? { Authorization: authorization } : {}), | ||
}, | ||
body: JSON.stringify({ inputs: batchInputs, normalize: true, truncate: true }), | ||
}); | ||
|
||
const embeddings: Embedding[] = await response.json(); | ||
return embeddings; | ||
}) | ||
); | ||
|
||
const flatAllEmbeddings = batchesResults.flat(); | ||
|
||
return flatAllEmbeddings; | ||
}; | ||
} |
46 changes: 46 additions & 0 deletions
46
src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import { z } from "zod"; | ||
import type { EmbeddingEndpoint } from "$lib/types/EmbeddingEndpoints"; | ||
import type { Tensor, Pipeline } from "@xenova/transformers"; | ||
import { pipeline } from "@xenova/transformers"; | ||
|
||
export const embeddingEndpointTransformersJSParametersSchema = z.object({ | ||
weight: z.number().int().positive().default(1), | ||
model: z.any(), | ||
type: z.literal("transformersjs"), | ||
}); | ||
|
||
// Use the Singleton pattern to enable lazy construction of the pipeline. | ||
class TransformersJSModelsSingleton { | ||
static instances: Array<[string, Promise<Pipeline>]> = []; | ||
|
||
static async getInstance(modelName: string): Promise<Pipeline> { | ||
const modelPipelineInstance = this.instances.find(([name]) => name === modelName); | ||
|
||
if (modelPipelineInstance) { | ||
const [, modelPipeline] = modelPipelineInstance; | ||
return modelPipeline; | ||
} | ||
|
||
const newModelPipeline = pipeline("feature-extraction", modelName); | ||
this.instances.push([modelName, newModelPipeline]); | ||
|
||
return newModelPipeline; | ||
} | ||
} | ||
|
||
export async function calculateEmbedding(modelName: string, inputs: string[]) { | ||
const extractor = await TransformersJSModelsSingleton.getInstance(modelName); | ||
const output: Tensor = await extractor(inputs, { pooling: "mean", normalize: true }); | ||
|
||
return output.tolist(); | ||
} | ||
|
||
export function embeddingEndpointTransformersJS( | ||
input: z.input<typeof embeddingEndpointTransformersJSParametersSchema> | ||
): EmbeddingEndpoint { | ||
const { model } = embeddingEndpointTransformersJSParametersSchema.parse(input); | ||
|
||
return async ({ inputs }) => { | ||
return calculateEmbedding(model.name, inputs); | ||
}; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import { TEXT_EMBEDDING_MODELS } from "$env/static/private"; | ||
|
||
import { z } from "zod"; | ||
import { sum } from "$lib/utils/sum"; | ||
import { | ||
embeddingEndpoints, | ||
embeddingEndpointSchema, | ||
type EmbeddingEndpoint, | ||
} from "$lib/types/EmbeddingEndpoints"; | ||
import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints"; | ||
|
||
const modelConfig = z.object({ | ||
/** Used as an identifier in DB */ | ||
id: z.string().optional(), | ||
/** Used to link to the model page, and for inference */ | ||
name: z.string().min(1), | ||
displayName: z.string().min(1).optional(), | ||
description: z.string().min(1).optional(), | ||
websiteUrl: z.string().url().optional(), | ||
modelUrl: z.string().url().optional(), | ||
endpoints: z.array(embeddingEndpointSchema).nonempty(), | ||
chunkCharLength: z.number().positive(), | ||
preQuery: z.string().default(""), | ||
prePassage: z.string().default(""), | ||
}); | ||
|
||
// Default embedding model for backward compatibility | ||
const rawEmbeddingModelJSON = | ||
TEXT_EMBEDDING_MODELS || | ||
`[ | ||
{ | ||
"name": "Xenova/gte-small", | ||
"chunkCharLength": 512, | ||
"endpoints": [ | ||
{ "type": "transformersjs" } | ||
] | ||
} | ||
]`; | ||
|
||
const embeddingModelsRaw = z.array(modelConfig).parse(JSON.parse(rawEmbeddingModelJSON)); | ||
|
||
const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => ({ | ||
...m, | ||
id: m.id || m.name, | ||
}); | ||
|
||
const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({ | ||
...m, | ||
getEndpoint: async (): Promise<EmbeddingEndpoint> => { | ||
if (!m.endpoints) { | ||
return embeddingEndpointTransformersJS({ | ||
type: "transformersjs", | ||
weight: 1, | ||
model: m, | ||
}); | ||
} | ||
|
||
const totalWeight = sum(m.endpoints.map((e) => e.weight)); | ||
|
||
let random = Math.random() * totalWeight; | ||
|
||
for (const endpoint of m.endpoints) { | ||
if (random < endpoint.weight) { | ||
const args = { ...endpoint, model: m }; | ||
|
||
switch (args.type) { | ||
case "tei": | ||
return embeddingEndpoints.tei(args); | ||
case "transformersjs": | ||
return embeddingEndpoints.transformersjs(args); | ||
} | ||
} | ||
|
||
random -= endpoint.weight; | ||
} | ||
|
||
throw new Error(`Failed to select embedding endpoint`); | ||
}, | ||
}); | ||
|
||
export const embeddingModels = await Promise.all( | ||
embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint)) | ||
); | ||
|
||
export const defaultEmbeddingModel = embeddingModels[0]; | ||
|
||
const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => { | ||
return z.enum([_models[0][key], ..._models.slice(1).map((m) => m[key])]); | ||
}; | ||
|
||
export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) => { | ||
return validateEmbeddingModel(_models, "id"); | ||
}; | ||
|
||
export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) => { | ||
return validateEmbeddingModel(_models, "name"); | ||
}; | ||
|
||
export type EmbeddingBackendModel = typeof defaultEmbeddingModel; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.