diff --git a/server/src/handlers/api/v1/openai/chat.handler.ts b/server/src/handlers/api/v1/openai/chat.handler.ts
index 4a883080..48055ff7 100644
--- a/server/src/handlers/api/v1/openai/chat.handler.ts
+++ b/server/src/handlers/api/v1/openai/chat.handler.ts
@@ -1,5 +1,5 @@
 import type { FastifyRequest, FastifyReply } from "fastify";
-import type { OpenaiRequestType } from "./type"
+import type { OpenaiRequestType } from "./type";
 import { getModelInfo } from "../../../../utils/get-model-info";
 import { embeddings } from "../../../../utils/embeddings";
 import { Document } from "langchain/document";
@@ -8,36 +8,44 @@ import { DialoqbaseHybridRetrival } from "../../../../utils/hybrid";
 import { DialoqbaseVectorStore } from "../../../../utils/store";
 import { createChatModel } from "../bot/playground/chat.service";
 import { createChain } from "../../../../chain";
-import { openaiNonStreamResponse, openaiStreamResponse } from "./openai-response";
+import {
+  openaiNonStreamResponse,
+  openaiStreamResponse,
+} from "./openai-response";
 import { groupOpenAiMessages } from "./other";
 import { nextTick } from "../../../../utils/nextTick";
 
-
 export const createChatCompletionHandler = async (
   request: FastifyRequest,
   reply: FastifyReply
 ) => {
   try {
-    const {
-      model,
-      messages
-    } = request.body;
+    const { model, messages } = request.body;
 
     const prisma = request.server.prisma;
+    // Optional retrieval filter: the first "knowledge_base" tool that carries
+    // a non-empty `value` narrows retrieval to those ids; otherwise the list
+    // is empty and retrieval is not filtered by it. `value` is optional in
+    // the request type and schema, so it is null-checked here rather than
+    // dereferenced (which would throw and end up as a 500 response).
+    const kb = request.body?.tools?.find(
+      (e) => e.type === "knowledge_base" && (e.value?.length ?? 0) > 0
+    );
+    const knowledge_base_ids: string[] = kb?.value ?? [];
 
     const bot = await prisma.bot.findFirst({
       where: {
         OR: [
           {
-            id: model
+            id: model,
           },
           {
-            publicId: model
-          }
+            publicId: model,
+          },
         ],
         user_id: request.user.is_admin ? undefined : request.user.user_id,
       },
-    })
+    });
 
     if (!bot) {
       return reply.status(404).send({
@@ -45,12 +53,11 @@ export const createChatCompletionHandler = async (
           message: "Bot not found",
           type: "not_found",
           param: "model",
-          code: "bot_not_found"
-        }
+          code: "bot_not_found",
+        },
       });
     }
 
-
     const embeddingInfo = await getModelInfo({
       prisma,
       model: bot.embedding,
@@ -63,12 +70,11 @@ export const createChatCompletionHandler = async (
           message: "Embedding not found",
           type: "not_found",
           param: "embedding",
-          code: "embedding_not_found"
-        }
+          code: "embedding_not_found",
+        },
       });
     }
 
-
     const embeddingModel = embeddings(
       embeddingInfo.model_provider!.toLowerCase(),
       embeddingInfo.model_id,
@@ -87,8 +93,8 @@ export const createChatCompletionHandler = async (
           message: "Model not found",
           type: "not_found",
           param: "model",
-          code: "model_not_found"
-        }
+          code: "model_not_found",
+        },
       });
     }
 
@@ -100,6 +106,7 @@ export const createChatCompletionHandler = async (
       retriever = new DialoqbaseHybridRetrival(embeddingModel, {
         botId: bot.id,
         sourceId: null,
+        knowledge_base_ids,
         callbacks: [
           {
             handleRetrieverEnd(documents) {
@@ -114,11 +121,12 @@ export const createChatCompletionHandler = async (
         {
           botId: bot.id,
           sourceId: null,
+          knowledge_base_ids,
+
         }
       );
 
-      retriever = vectorstore.asRetriever({
-      });
+      retriever = vectorstore.asRetriever({});
     }
 
     const streamedModel = createChatModel(
@@ -140,48 +148,37 @@ export const createChatCompletionHandler = async (
     if (!request.body.stream) {
       const res = await chain.invoke({
         question: messages[messages.length - 1].content,
-        chat_history: groupOpenAiMessages(
-          messages
-        ),
-      })
-
+        chat_history: groupOpenAiMessages(messages),
+      });
 
-      return reply.status(200).send(openaiNonStreamResponse(
-        res,
-        bot.name
-      ))
+      return reply.status(200).send(openaiNonStreamResponse(res, bot.name));
     }
 
     const stream = await chain.stream({
       question: messages[messages.length - 1].content,
-      chat_history: groupOpenAiMessages(
-        messages
-      ),
-    })
+      chat_history: groupOpenAiMessages(messages),
+    });
 
     reply.raw.setHeader("Content-Type", "text/event-stream");
     for await (const token of stream) {
       reply.sse({
-        data: openaiStreamResponse(
-          token || "",
-          bot.name
-        )
+        data: openaiStreamResponse(token || "", bot.name),
       });
     }
     reply.sse({
-      data: "[DONE]\n\n"
-    })
+      data: "[DONE]\n\n",
+    });
     await nextTick();
     return reply.raw.end();
   } catch (error) {
-    console.log(error)
+    console.log(error);
     return reply.status(500).send({
       error: {
         message: error.message,
         type: "internal_server_error",
         param: null,
-        code: "internal_server_error"
-      }
+        code: "internal_server_error",
+      },
     });
   }
-}
\ No newline at end of file
+};
diff --git a/server/src/handlers/api/v1/openai/type.ts b/server/src/handlers/api/v1/openai/type.ts
index 8505c5c5..cafc2d4e 100644
--- a/server/src/handlers/api/v1/openai/type.ts
+++ b/server/src/handlers/api/v1/openai/type.ts
@@ -7,5 +7,9 @@ export interface OpenaiRequestType {
     model: string;
     stream: boolean;
     temperature: number;
+    tools?: {
+      type?: "knowledge_base",
+      value?: string[]
+    }[]
   }
 }
\ No newline at end of file
diff --git a/server/src/schema/api/v1/openai/index.ts b/server/src/schema/api/v1/openai/index.ts
index cc4bb9d9..76e57809 100644
--- a/server/src/schema/api/v1/openai/index.ts
+++ b/server/src/schema/api/v1/openai/index.ts
@@ -30,6 +30,24 @@ export const createChatCompletionSchema: FastifySchema = {
       },
       temperature: {
         type: "number"
+      },
+      tools: {
+        type: "array",
+        items: {
+          type: "object",
+          required: ["type"],
+          properties: {
+            type: {
+              type: "string"
+            },
+            value: {
+              type: "array",
+              items: {
+                type: "string"
+              }
+            }
+          }
+        }
       }
     }
   }
diff --git a/server/src/utils/hybrid.ts b/server/src/utils/hybrid.ts
index 5dc028df..28c2cec0 100644
--- a/server/src/utils/hybrid.ts
+++ b/server/src/utils/hybrid.ts
@@ -1,5 +1,5 @@
 import { Document } from "langchain/document";
-import { PrismaClient } from "@prisma/client";
+import { Prisma, PrismaClient } from "@prisma/client";
 import { Embeddings } from
"langchain/embeddings/base";
 import { BaseRetriever, BaseRetrieverInput } from "@langchain/core/retrievers";
 import { CallbackManagerForRetrieverRun, Callbacks } from "langchain/callbacks";
@@ -8,6 +8,7 @@ const prisma = new PrismaClient();
 export interface DialoqbaseLibArgs extends BaseRetrieverInput {
   botId: string;
   sourceId: string | null;
+  knowledge_base_ids?: string[];
 }
 
 interface SearchEmbeddingsResponse {
@@ -30,14 +31,54 @@ export class DialoqbaseHybridRetrival extends BaseRetriever {
   embeddings: Embeddings;
   similarityK = 5;
   keywordK = 4;
+  knowledge_base_ids: string[];
 
   constructor(embeddings: Embeddings, args: DialoqbaseLibArgs) {
     super(args);
     this.botId = args.botId;
     this.sourceId = args.sourceId;
     this.embeddings = embeddings;
+    this.knowledge_base_ids = args.knowledge_base_ids || [];
+  }
+
+  /**
+   * Nearest-neighbour search restricted to the selected knowledge bases.
+   *
+   * Reads "BotDocument" directly and orders rows by the pgvector `<=>`
+   * distance to the query embedding. The ids are caller-supplied (the chat
+   * handler takes them from the request's `tools` entry), so the query is
+   * also pinned to this retriever's bot: filtering on "sourceId" alone
+   * would return documents of any bot whose source ids the caller passes.
+   * NOTE(review): assumes "BotDocument" has a "botId" column, as the botId
+   * argument of similarity_search_v2 suggests; confirm against the schema.
+   *
+   * @param query embedding of the search text
+   * @param k maximum number of rows to return
+   * @param knowledgeBaseIds "sourceId" values to search; must not be empty
+   *   (`Prisma.join` throws on an empty list)
+   * @returns rows with `distance` to `query`, closest first
+   */
+  async similaritySearchWithSelectedKBs(
+    query: number[],
+    k: number,
+    knowledgeBaseIds: string[]
+  ) {
+    const vector = `[${query?.join(",")}]`;
+    const results = await prisma.$queryRaw`
+      SELECT "sourceId", "content", "metadata",
+      (embedding <=> ${vector}::vector) AS distance
+      FROM "BotDocument"
+      WHERE "botId" = ${this.botId} AND "sourceId" IN (${Prisma.join(knowledgeBaseIds)})
+      ORDER BY distance ASC
+      LIMIT ${k}
+    `;
+    return results as {
+      sourceId: string;
+      content: string;
+      metadata: object;
+      distance: number;
+    }[];
   }
-
   protected async similaritySearch(
     query: string,
     k: number,
@@ -53,20 +94,32 @@ export class DialoqbaseHybridRetrival extends BaseRetriever {
         id: bot_id,
       },
     });
-    const data = await prisma.$queryRaw`
-      SELECT * FROM "similarity_search_v2"(query_embedding := ${vector}::vector, botId := ${bot_id}::text,match_count := ${k}::int)
-    `;
+    let result: (number | Document)[][];
+    const match_count = botInfo?.noOfDocumentsToRetrieve || k;
+
+    if (this.knowledge_base_ids && this.knowledge_base_ids.length > 0) {
+      const data = await this.similaritySearchWithSelectedKBs(embeddedQuery, match_count, this.knowledge_base_ids);
+      result = data.map((resp) => [
+        new Document({
+          metadata: resp.metadata,
+          pageContent: resp.content,
+        }),
+        1 - resp.distance,
+      ]);
+    } else {
+      const data = await prisma.$queryRaw`
+        SELECT * FROM "similarity_search_v2"(query_embedding := ${vector}::vector, botId := ${bot_id}::text,match_count := ${match_count}::int)
+      `;
+      result = (data as SearchEmbeddingsResponse[]).map((resp) => [
+        new Document({
+          metadata: resp.metadata,
+          pageContent: resp.content,
+        }),
+        resp.similarity,
+      ]);
+    }
+
 
-    const result: [Document, number, number][] = (
-      data as SearchEmbeddingsResponse[]
-    ).map((resp) => [
-      new Document({
-        metadata: resp.metadata,
-        pageContent: resp.content,
-      }),
-      resp.similarity * 10,
-      resp.id,
-    ]);
     let internetSearchResults = [];
     if (botInfo.internetSearchEnabled) {
       internetSearchResults = await searchInternet(this.embeddings, {
diff --git a/server/src/utils/store.ts b/server/src/utils/store.ts
index 2fd0987e..5f3ea75e 100644
--- a/server/src/utils/store.ts
+++ b/server/src/utils/store.ts
@@ -1,5 +1,5 @@
 import { Document } from "@langchain/core/documents";
-import { PrismaClient } from "@prisma/client";
+import { Prisma, PrismaClient } from "@prisma/client";
 import { Embeddings } from "@langchain/core/embeddings";
 import { VectorStore } from "@langchain/core/vectorstores";
 import { Callbacks } from "langchain/callbacks";
@@ -8,6 +8,7 @@ const prisma = new PrismaClient();
 export interface DialoqbaseLibArgs {
   botId: string;
   sourceId: string | null;
+  knowledge_base_ids?: string[];
 }
 export function removeUUID(filename: string) {
   return filename.replace(/^\w{8}-\w{4}-\w{4}-\w{4}-\w{12}-/, "");
@@ -24,12 +25,14 @@ export class DialoqbaseVectorStore extends VectorStore {
   botId: string;
   sourceId: string | null;
   declare embeddings: Embeddings;
+  knowledge_base_ids: string[];
 
   constructor(embeddings: Embeddings, args: DialoqbaseLibArgs) {
     super(embeddings, args);
     this.botId = args.botId;
     this.sourceId = args.sourceId;
     this.embeddings = embeddings;
+    this.knowledge_base_ids = args.knowledge_base_ids
|| [];
   }
   async addVectors(vectors: number[][], documents: Document[]): Promise {
     const rows = vectors.map((embedding, idx) => ({
@@ -94,16 +97,56 @@ export class DialoqbaseVectorStore extends VectorStore {
     return instance;
   }
+
+  /**
+   * Nearest-neighbour search restricted to the selected knowledge bases.
+   *
+   * Reads "BotDocument" directly and orders rows by the pgvector `<=>`
+   * distance to the query embedding. The ids are caller-supplied (the chat
+   * handler takes them from the request's `tools` entry), so the query is
+   * also pinned to this store's bot: filtering on "sourceId" alone would
+   * return documents of any bot whose source ids the caller passes.
+   * NOTE(review): assumes "BotDocument" has a "botId" column, as the botId
+   * argument of similarity_search_v2 suggests; confirm against the schema.
+   *
+   * @param query embedding of the search text
+   * @param k maximum number of rows to return
+   * @param knowledgeBaseIds "sourceId" values to search; must not be empty
+   *   (`Prisma.join` throws on an empty list)
+   * @returns rows with `distance` to `query`, closest first
+   */
+  async similaritySearchWithSelectedKBs(
+    query: number[],
+    k: number,
+    knowledgeBaseIds: string[]
+  ) {
+    const vector = `[${query?.join(",")}]`;
+    const results = await prisma.$queryRaw`
+      SELECT "sourceId", "content", "metadata",
+      (embedding <=> ${vector}::vector) AS distance
+      FROM "BotDocument"
+      WHERE "botId" = ${this.botId} AND "sourceId" IN (${Prisma.join(knowledgeBaseIds)})
+      ORDER BY distance ASC
+      LIMIT ${k}
+    `;
+    return results as {
+      sourceId: string;
+      content: string;
+      metadata: object;
+      distance: number;
+    }[];
+  }
+
 
   async similaritySearchVectorWithScore(
     query: number[],
     k: number,
     filter?: this["FilterType"] | undefined,
-    originalQuery?: string | undefined
+    originalQuery?: string | undefined,
   ): Promise<[Document>, number][]> {
     if (!query) {
       return [];
     }
-    const vector = `[${query?.join(",")}]`;
+
     const bot_id = this.botId;
 
     const botInfo = await prisma.bot.findFirst({
@@ -117,17 +160,30 @@ export class DialoqbaseVectorStore extends VectorStore {
     const semanticSearchSimilarityScore =
       botInfo?.semanticSearchSimilarityScore || "none";
 
-    const data = await prisma.$queryRaw`
-      SELECT * FROM "similarity_search_v2"(query_embedding := ${vector}::vector, botId := ${bot_id}::text,match_count := ${match_count}::int)
-    `;
+    let result: (number | Document)[][];
 
-    const result = (data as SearchEmbeddingsResponse[]).map((resp) => [
-      new Document({
-        metadata: resp.metadata,
-        pageContent: resp.content,
-      }),
-      resp.similarity,
-    ]);
+    if (this.knowledge_base_ids && this.knowledge_base_ids.length > 0) {
+      const data = await this.similaritySearchWithSelectedKBs(query, match_count, this.knowledge_base_ids);
+      result = data.map((resp) => [
+        new Document({
+          metadata: resp.metadata,
+          pageContent: resp.content,
+        }),
+        1 - resp.distance,
+      ]);
+    } else {
+      const vector = `[${query?.join(",")}]`;
+      const data = await prisma.$queryRaw`
+        SELECT * FROM "similarity_search_v2"(query_embedding := ${vector}::vector, botId := ${bot_id}::text,match_count := ${match_count}::int)
+      `;
+      result = (data as SearchEmbeddingsResponse[]).map((resp) => [
+        new Document({
+          metadata: resp.metadata,
+          pageContent: resp.content,
+        }),
+        resp.similarity,
+      ]);
+    }
 
     let internetSearchResults = [];
     if (botInfo.internetSearchEnabled) {