From 4af739c1cefdd2304eaa0e3ba7128eb247c48233 Mon Sep 17 00:00:00 2001
From: Henry
Date: Sat, 31 Aug 2024 00:59:56 +0100
Subject: [PATCH] add custom retriever

---
 .../CustomRetriever/CustomRetriever.ts        | 187 ++++++++++++++++++
 .../CustomRetriever/customRetriever.svg       |   1 +
 2 files changed, 188 insertions(+)
 create mode 100644 packages/components/nodes/retrievers/CustomRetriever/CustomRetriever.ts
 create mode 100644 packages/components/nodes/retrievers/CustomRetriever/customRetriever.svg

diff --git a/packages/components/nodes/retrievers/CustomRetriever/CustomRetriever.ts b/packages/components/nodes/retrievers/CustomRetriever/CustomRetriever.ts
new file mode 100644
index 00000000000..3a459b51d36
--- /dev/null
+++ b/packages/components/nodes/retrievers/CustomRetriever/CustomRetriever.ts
@@ -0,0 +1,187 @@
+import { get } from 'lodash'
+import { Document } from '@langchain/core/documents'
+import { VectorStore, VectorStoreRetriever, VectorStoreRetrieverInput } from '@langchain/core/vectorstores'
+import { INode, INodeData, INodeParams, INodeOutputsValue } from '../../../src/Interface'
+import { handleEscapeCharacters } from '../../../src'
+
+/** Template used when no result format is supplied to the retriever. */
+const defaultReturnFormat = '{{context}}\nSource: {{metadata.source}}'
+
+/**
+ * Node that wraps a vector store in a retriever whose documents are rewritten
+ * with a user-defined template before being returned.
+ */
+class CustomRetriever_Retrievers implements INode {
+    label: string
+    name: string
+    version: number
+    description: string
+    type: string
+    icon: string
+    category: string
+    baseClasses: string[]
+    inputs: INodeParams[]
+    outputs: INodeOutputsValue[]
+
+    constructor() {
+        this.label = 'Custom Retriever'
+        this.name = 'customRetriever'
+        this.version = 1.0
+        this.type = 'CustomRetriever'
+        this.icon = 'customRetriever.svg'
+        this.category = 'Retrievers'
+        this.description = 'Return results based on predefined format'
+        this.baseClasses = [this.type, 'BaseRetriever']
+        this.inputs = [
+            {
+                label: 'Vector Store',
+                name: 'vectorStore',
+                type: 'VectorStore'
+            },
+            {
+                label: 'Query',
+                name: 'query',
+                type: 'string',
+                description: 'Query to retrieve documents from retriever. If not specified, user question will be used',
+                optional: true,
+                acceptVariable: true
+            },
+            {
+                label: 'Result Format',
+                name: 'resultFormat',
+                type: 'string',
+                rows: 4,
+                description:
+                    'Format to return the results in. Use {{context}} to insert the pageContent of the document and {{metadata.key}} to insert metadata values.',
+                default: defaultReturnFormat
+            },
+            {
+                label: 'Top K',
+                name: 'topK',
+                description: 'Number of top results to fetch. Default to vector store topK',
+                placeholder: '4',
+                type: 'number',
+                additionalParams: true,
+                optional: true
+            }
+        ]
+        this.outputs = [
+            {
+                label: 'Custom Retriever',
+                name: 'retriever',
+                baseClasses: this.baseClasses
+            },
+            {
+                label: 'Document',
+                name: 'document',
+                description: 'Array of document objects containing metadata and pageContent',
+                baseClasses: ['Document', 'json']
+            },
+            {
+                label: 'Text',
+                name: 'text',
+                description: 'Concatenated string from pageContent of documents',
+                baseClasses: ['string', 'json']
+            }
+        ]
+    }
+
+    /**
+     * Builds the retriever and returns the value for the selected output.
+     *
+     * @param nodeData - node configuration; `inputs.vectorStore`, `inputs.query`,
+     * `inputs.topK`, `inputs.resultFormat` and `outputs.output` are read
+     * @param input - user question, used as the query when `inputs.query` is empty
+     * @returns the retriever for the `retriever` output (and any unrecognised output),
+     * the formatted documents for `document`, or their pageContent joined with
+     * newlines for `text`
+     */
+    async init(nodeData: INodeData, input: string): Promise<any> {
+        const vectorStore = nodeData.inputs?.vectorStore as VectorStore
+        const query = nodeData.inputs?.query as string
+        const topK = nodeData.inputs?.topK as string
+        const resultFormat = nodeData.inputs?.resultFormat as string
+
+        const output = nodeData.outputs?.output as string
+
+        // A Top K that does not parse as an integer falls back to the store's own k (or 4).
+        const parsedTopK = topK ? parseInt(topK, 10) : NaN
+        const retriever = CustomRetriever.fromVectorStore(vectorStore, {
+            resultFormat,
+            topK: Number.isNaN(parsedTopK) ? (vectorStore as any)?.k ?? 4 : parsedTopK
+        })
+
+        if (output !== 'document' && output !== 'text') return retriever
+
+        const docs = await retriever.getRelevantDocuments(query ? query : input)
+        if (output === 'document') return docs
+
+        let finaltext = ''
+        for (const doc of docs) finaltext += `${doc.pageContent}\n`
+
+        return handleEscapeCharacters(finaltext, false)
+    }
+}
+
+// NOTE(review): type parameters below are reconstructed; confirm against the upstream file.
+type RetrieverInput<V extends VectorStore> = Omit<VectorStoreRetrieverInput<V>, 'k'> & {
+    topK?: number
+    resultFormat?: string
+}
+
+/**
+ * Vector store retriever that returns each hit with its pageContent rendered
+ * through `resultFormat`; the original metadata is passed through unchanged.
+ */
+class CustomRetriever<V extends VectorStore> extends VectorStoreRetriever<V> {
+    resultFormat: string
+    topK = 4
+
+    constructor(input: RetrieverInput<V>) {
+        super(input)
+        this.topK = input.topK ?? this.topK
+        // `||` rather than `??`: an empty template would blank out every document,
+        // so it is treated the same as an omitted one.
+        this.resultFormat = input.resultFormat || defaultReturnFormat
+    }
+
+    /**
+     * Runs a similarity search for `query`, passing `topK` and `filter` to the
+     * vector store, and formats every hit with `resultFormat`.
+     */
+    async getRelevantDocuments(query: string): Promise<Document[]> {
+        const results = await this.vectorStore.similaritySearchWithScore(query, this.topK, this.filter)
+
+        return results.map(
+            ([doc]) =>
+                new Document({
+                    pageContent: formatDocument(this.resultFormat, doc.pageContent, doc.metadata),
+                    metadata: doc.metadata
+                })
+        )
+    }
+
+    static fromVectorStore<V extends VectorStore>(vectorStore: V, options: Omit<RetrieverInput<V>, 'vectorStore'>) {
+        return new this<V>({ ...options, vectorStore })
+    }
+}
+
+/**
+ * Renders `template` for one document in a single pass.
+ *
+ * `{{context}}` becomes `pageContent` and `{{metadata.some.path}}` becomes the
+ * value at that path in `metadata`; a metadata placeholder whose path resolves
+ * to `undefined` is left as written. Substituted text is never re-scanned, so
+ * placeholders or `$` sequences inside the document itself are kept verbatim.
+ */
+function formatDocument(template: string, pageContent: string, metadata: Record<string, any>): string {
+    const placeholderRegex = /{{context}}|{{metadata\.([\w.]+)}}/g
+
+    return template.replace(placeholderRegex, (match: string, path?: string) => {
+        if (path === undefined) return pageContent
+        const value = get(metadata, path)
+        return value !== undefined ? String(value) : match
+    })
+}
+
+module.exports = { nodeClass: CustomRetriever_Retrievers }
diff --git a/packages/components/nodes/retrievers/CustomRetriever/customRetriever.svg b/packages/components/nodes/retrievers/CustomRetriever/customRetriever.svg
new file mode 100644
index 00000000000..aa0e9c3eb48
--- /dev/null
+++ b/packages/components/nodes/retrievers/CustomRetriever/customRetriever.svg
@@ -0,0 +1 @@
+
\ No newline at end of file