From fc997b1544203e5b49e18252327555296491b74c Mon Sep 17 00:00:00 2001
From: Dario Gieselaar
Date: Fri, 22 Dec 2023 13:00:17 +0100
Subject: [PATCH] [Obs AI Assistant] Include `search-*` when recalling documents (#173710)

Include `search-*` indices when recalling documents from the knowledge
base. General approach:

- Use the current user, not the internal user; the latter will almost
  never have access to `search-*`.
- Use `_field_caps` to look for `sparse_vector` field types (a small
  sketch of this step follows the diffstat below).
- `ml.inference.` is a hard-coded prefix, so we can strip that and
  `_expanded.predicted_value` to get the original field name.
- Only include documents that have the same model ID as we are using
  for our regular recalls.
- If the request fails for whatever reason (which is fine: users might
  not have access to `search-*`), ignore the failure and log it at
  debug level.
- Serialize the entire document; non-vectorized metadata can also be
  important for the LLM to make decisions.
- Sort all documents (kb + `search-*`) by score and return the first 20.
- Count the number of tokens and send no more than 4000 tokens to the
  LLM, to keep response time down; drop the remaining documents and log
  how many were dropped.

---
 .../server/functions/index.ts                |   1 +
 .../server/functions/recall.ts               |  16 +-
 .../server/service/client/index.test.ts      |  39 ++-
 .../server/service/client/index.ts           |  26 +-
 .../server/service/index.ts                  |   5 +-
 .../service/knowledge_base_service/index.ts  | 233 ++++++++++++++----
 6 files changed, 244 insertions(+), 76 deletions(-)
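For reviewers, a minimal standalone sketch of the field-discovery step described above: find all `sparse_vector` fields under the hard-coded `ml.inference.` prefix and strip the prefix and `_expanded.predicted_value` suffix to recover the original field names. The `getConnectorVectorFields` name and the bare `Client` wiring are illustrative only; the actual change inlines this in the knowledge base service below.

```ts
import { Client } from '@elastic/elasticsearch';

const ML_INFERENCE_PREFIX = 'ml.inference.';

// Hypothetical helper mirroring the approach in this PR: discover which
// connector fields have ELSER-style sparse_vector expansions available.
async function getConnectorVectorFields(client: Client): Promise<string[]> {
  const fieldCaps = await client.fieldCaps({
    index: 'search-*',
    fields: `${ML_INFERENCE_PREFIX}*`,
    allow_no_indices: true, // don't fail when no connector indices exist
    types: ['sparse_vector'], // only sparse vector expansions are usable here
  });

  // e.g. 'ml.inference.title_expanded.predicted_value' -> 'title'
  return Object.keys(fieldCaps.fields).map((field) =>
    field.replace('_expanded.predicted_value', '').replace(ML_INFERENCE_PREFIX, '')
  );
}
```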
diff --git a/x-pack/plugins/observability_ai_assistant/server/functions/index.ts b/x-pack/plugins/observability_ai_assistant/server/functions/index.ts
index b25e69c53689c..12075a56942f6 100644
--- a/x-pack/plugins/observability_ai_assistant/server/functions/index.ts
+++ b/x-pack/plugins/observability_ai_assistant/server/functions/index.ts
@@ -34,6 +34,7 @@ export const registerFunctions: ChatRegistrationFunction = async ({
     resources,
     signal,
   };
+
   return client.getKnowledgeBaseStatus().then((response) => {
     const isReady = response.ready;

diff --git a/x-pack/plugins/observability_ai_assistant/server/functions/recall.ts b/x-pack/plugins/observability_ai_assistant/server/functions/recall.ts
index 5b6de5b0cc6f1..0624b2f64f970 100644
--- a/x-pack/plugins/observability_ai_assistant/server/functions/recall.ts
+++ b/x-pack/plugins/observability_ai_assistant/server/functions/recall.ts
@@ -11,20 +11,18 @@ import dedent from 'dedent';
 import * as t from 'io-ts';
 import { last, omit } from 'lodash';
 import { lastValueFrom } from 'rxjs';
+import { FunctionRegistrationParameters } from '.';
 import { MessageRole, type Message } from '../../common/types';
 import { concatenateOpenAiChunks } from '../../common/utils/concatenate_openai_chunks';
 import { processOpenAiStream } from '../../common/utils/process_openai_stream';
 import type { ObservabilityAIAssistantClient } from '../service/client';
-import type { RegisterFunction } from '../service/types';
 import { streamIntoObservable } from '../service/util/stream_into_observable';

 export function registerRecallFunction({
   client,
   registerFunction,
-}: {
-  client: ObservabilityAIAssistantClient;
-  registerFunction: RegisterFunction;
-}) {
+  resources,
+}: FunctionRegistrationParameters) {
   registerFunction(
     {
       name: 'recall',
@@ -99,6 +97,10 @@ export function registerRecallFunction({
         queries,
       });

+      resources.logger.debug(`Received ${suggestions.length} suggestions`);
+
+      resources.logger.debug(JSON.stringify(suggestions, null, 2));
+
       if (suggestions.length === 0) {
         return {
           content: [] as unknown as Serializable,
@@ -115,6 +117,9 @@ export function registerRecallFunction({
         signal,
       });

+      resources.logger.debug(`Received ${relevantDocuments.length} relevant documents`);
+      resources.logger.debug(JSON.stringify(relevantDocuments, null, 2));
+
       return {
         content: relevantDocuments as unknown as Serializable,
       };
@@ -254,7 +259,6 @@ async function scoreSuggestions({
       })
     ).pipe(processOpenAiStream(), concatenateOpenAiChunks())
   );
-
   const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);

   const { scores } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
     scoreFunctionRequest.message.function_call.arguments
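The new debug statements above serialize potentially large suggestion and document arrays on every call. An optional pattern (not part of this change) that skips the `JSON.stringify` work entirely when debug logging is off, assuming the Kibana `Logger`'s `isLevelEnabled` check available in recent versions:

```ts
import type { Logger } from '@kbn/logging';

// Hypothetical wrapper: only build the (potentially large) JSON payload
// when the debug level is actually enabled for this logger.
function debugDump(logger: Logger, label: string, value: unknown) {
  if (!logger.isLevelEnabled('debug')) {
    return;
  }
  logger.debug(`${label}: ${JSON.stringify(value, null, 2)}`);
}
```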
diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts
index 0349e5ec899f9..7cffaa64d16d3 100644
--- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts
+++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts
@@ -77,12 +77,23 @@ describe('Observability AI Assistant service', () => {
     execute: jest.fn(),
   } as any;

-  const esClientMock: DeeplyMockedKeys<ElasticsearchClient> = {
+  const internalUserEsClientMock: DeeplyMockedKeys<ElasticsearchClient> = {
     search: jest.fn(),
     index: jest.fn(),
     update: jest.fn(),
   } as any;

+  const currentUserEsClientMock: DeeplyMockedKeys<ElasticsearchClient> = {
+    search: jest.fn().mockResolvedValue({
+      hits: {
+        hits: [],
+      },
+    }),
+    fieldCaps: jest.fn().mockResolvedValue({
+      fields: [],
+    }),
+  } as any;
+
   const knowledgeBaseServiceMock: DeeplyMockedKeys<KnowledgeBaseService> = {
     recall: jest.fn(),
   } as any;
@@ -91,6 +102,7 @@
     log: jest.fn(),
     error: jest.fn(),
     debug: jest.fn(),
+    trace: jest.fn(),
   } as any;

   const functionClientMock: DeeplyMockedKeys<ChatFunctionClient> = {
@@ -108,7 +120,10 @@
     return new ObservabilityAIAssistantClient({
       actionsClient: actionsClientMock,
-      esClient: esClientMock,
+      esClient: {
+        asInternalUser: internalUserEsClientMock,
+        asCurrentUser: currentUserEsClientMock,
+      },
       knowledgeBaseService: knowledgeBaseServiceMock,
       logger: loggerMock,
       namespace: 'default',
@@ -334,7 +349,7 @@
         type: StreamingChatResponseEventType.ConversationCreate,
       });

-      expect(esClientMock.index).toHaveBeenCalledWith({
+      expect(internalUserEsClientMock.index).toHaveBeenCalledWith({
         index: '.kibana-observability-ai-assistant-conversations',
         refresh: true,
         document: {
@@ -386,7 +401,7 @@
     });
   });

-  describe('when completig a conversation with an initial conversation id', () => {
+  describe('when completing a conversation with an initial conversation id', () => {
     let stream: Readable;

     let dataHandler: jest.Mock;
@@ -402,7 +417,7 @@
         };
       });

-      esClientMock.search.mockImplementation(async () => {
+      internalUserEsClientMock.search.mockImplementation(async () => {
         return {
           hits: {
             hits: [
@@ -430,7 +445,7 @@
         } as any;
       });

-      esClientMock.update.mockImplementationOnce(async () => {
+      internalUserEsClientMock.update.mockImplementationOnce(async () => {
         return {} as any;
       });

@@ -464,7 +479,7 @@
         type: StreamingChatResponseEventType.ConversationUpdate,
       });

-      expect(esClientMock.update).toHaveBeenCalledWith({
+      expect(internalUserEsClientMock.update).toHaveBeenCalledWith({
         refresh: true,
         index: '.kibana-observability-ai-assistant-conversations',
         id: 'my-es-document-id',
@@ -573,8 +588,8 @@
     });

     it('does not create or update the conversation', async () => {
-      expect(esClientMock.index).not.toHaveBeenCalled();
-      expect(esClientMock.update).not.toHaveBeenCalled();
+      expect(internalUserEsClientMock.index).not.toHaveBeenCalled();
+      expect(internalUserEsClientMock.update).not.toHaveBeenCalled();
     });
   });

@@ -816,9 +831,11 @@
         },
       });

-      expect(esClientMock.index).toHaveBeenCalled();
+      expect(internalUserEsClientMock.index).toHaveBeenCalled();

-      expect((esClientMock.index.mock.lastCall![0] as any).document.messages).toEqual([
+      expect(
+        (internalUserEsClientMock.index.mock.lastCall![0] as any).document.messages
+      ).toEqual([
         {
           '@timestamp': expect.any(String),
           message: {
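Building on the mocks above, a sketch of how the privilege split could be asserted: conversation persistence must go through the internal user, while `search-*` recall must go through the request-scoped current user. The test body and the `recall` call shape are simplified and illustrative, not part of this change.

```ts
// Hypothetical test: connector recall uses the current user's client only.
it('recalls from connector indices as the current user', async () => {
  await client.recall({ queries: ['lorem ipsum'] }); // simplified call shape

  expect(currentUserEsClientMock.search).toHaveBeenCalledWith(
    expect.objectContaining({ index: 'search-*' })
  );
  // The internal user never touches connector indices.
  expect(internalUserEsClientMock.search).not.toHaveBeenCalledWith(
    expect.objectContaining({ index: 'search-*' })
  );
});
```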
diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts
index 9423977428d66..fafb7606a2769 100644
--- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts
+++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts
@@ -53,7 +53,10 @@ export class ObservabilityAIAssistantClient {
   private readonly dependencies: {
     actionsClient: PublicMethodsOf<ActionsClient>;
     namespace: string;
-    esClient: ElasticsearchClient;
+    esClient: {
+      asInternalUser: ElasticsearchClient;
+      asCurrentUser: ElasticsearchClient;
+    };
     resources: ObservabilityAIAssistantResourceNames;
     logger: Logger;
     user: {
@@ -67,7 +70,7 @@ export class ObservabilityAIAssistantClient {
   private getConversationWithMetaFields = async (
     conversationId: string
   ): Promise<SearchHit<Conversation> | undefined> => {
-    const response = await this.dependencies.esClient.search({
+    const response = await this.dependencies.esClient.asInternalUser.search({
       index: this.dependencies.resources.aliases.conversations,
       query: {
         bool: {
@@ -113,7 +116,7 @@ export class ObservabilityAIAssistantClient {
       throw notFound();
     }

-    await this.dependencies.esClient.delete({
+    await this.dependencies.esClient.asInternalUser.delete({
       id: conversation._id,
       index: conversation._index,
       refresh: true,
@@ -407,7 +410,7 @@ export class ObservabilityAIAssistantClient {
     };

     this.dependencies.logger.debug(`Sending conversation to connector`);
-    this.dependencies.logger.debug(JSON.stringify(request, null, 2));
+    this.dependencies.logger.trace(JSON.stringify(request, null, 2));

     const executeResult = await this.dependencies.actionsClient.execute({
       actionId: connectorId,
@@ -428,17 +431,15 @@
       ? (executeResult.data as Readable)
       : (executeResult.data as CreateChatCompletionResponse);

-    if (response instanceof PassThrough) {
-      signal.addEventListener('abort', () => {
-        response.end();
-      });
+    if (response instanceof Readable) {
+      signal.addEventListener('abort', () => response.destroy());
     }

     return response as any;
   };

   find = async (options?: { query?: string }): Promise<{ conversations: Conversation[] }> => {
-    const response = await this.dependencies.esClient.search({
+    const response = await this.dependencies.esClient.asInternalUser.search({
       index: this.dependencies.resources.aliases.conversations,
       allow_no_indices: true,
       query: {
@@ -475,7 +476,7 @@ export class ObservabilityAIAssistantClient {
       this.getConversationUpdateValues(new Date().toISOString())
     );

-    await this.dependencies.esClient.update({
+    await this.dependencies.esClient.asInternalUser.update({
       id: document._id,
       index: document._index,
       doc: updatedConversation,
@@ -547,7 +548,7 @@ export class ObservabilityAIAssistantClient {
       this.getConversationUpdateValues(new Date().toISOString())
     );

-    await this.dependencies.esClient.update({
+    await this.dependencies.esClient.asInternalUser.update({
       id: document._id,
       index: document._index,
       doc: { conversation: { title } },
@@ -570,7 +571,7 @@ export class ObservabilityAIAssistantClient {
       this.getConversationUpdateValues(now)
     );

-    await this.dependencies.esClient.index({
+    await this.dependencies.esClient.asInternalUser.index({
       index: this.dependencies.resources.aliases.conversations,
       document: createdConversation,
       refresh: true,
@@ -591,6 +592,7 @@ export class ObservabilityAIAssistantClient {
       user: this.dependencies.user,
       queries,
       contexts,
+      asCurrentUser: this.dependencies.esClient.asCurrentUser,
     });
   };

diff --git a/x-pack/plugins/observability_ai_assistant/server/service/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/index.ts
index 1068f1bd90cc5..3d999b090f9cf 100644
--- a/x-pack/plugins/observability_ai_assistant/server/service/index.ts
+++ b/x-pack/plugins/observability_ai_assistant/server/service/index.ts
@@ -277,7 +277,10 @@ export class ObservabilityAIAssistantService {
     return new ObservabilityAIAssistantClient({
       actionsClient: await plugins.actions.getActionsClientWithRequest(request),
       namespace: spaceId,
-      esClient: coreStart.elasticsearch.client.asInternalUser,
+      esClient: {
+        asInternalUser: coreStart.elasticsearch.client.asInternalUser,
+        asCurrentUser: coreStart.elasticsearch.client.asScoped(request).asCurrentUser,
+      },
      resources: this.resourceNames,
      logger: this.logger,
      user: {
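For context, the client-construction change above boils down to this shape. A hypothetical standalone helper (the real change inlines it in the service): the internal user writes assistant-owned conversation indices, while the request-scoped current user is used for anything that must respect the end user's privileges, like `search-*`.

```ts
import type { CoreStart, KibanaRequest } from '@kbn/core/server';

// Hypothetical helper showing the two client flavors handed to the
// ObservabilityAIAssistantClient: system-level access for conversations,
// user-level access for connector indices.
function getEsClients(coreStart: CoreStart, request: KibanaRequest) {
  return {
    asInternalUser: coreStart.elasticsearch.client.asInternalUser,
    asCurrentUser: coreStart.elasticsearch.client.asScoped(request).asCurrentUser,
  };
}
```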
diff --git a/x-pack/plugins/observability_ai_assistant/server/service/knowledge_base_service/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/knowledge_base_service/index.ts
index dd44675c800ca..e4c4efb168d03 100644
--- a/x-pack/plugins/observability_ai_assistant/server/service/knowledge_base_service/index.ts
+++ b/x-pack/plugins/observability_ai_assistant/server/service/knowledge_base_service/index.ts
@@ -12,7 +12,8 @@ import type { Logger } from '@kbn/logging';
 import type { TaskManagerStartContract } from '@kbn/task-manager-plugin/server';
 import pLimit from 'p-limit';
 import pRetry from 'p-retry';
-import { map } from 'lodash';
+import { map, orderBy } from 'lodash';
+import { encode } from 'gpt-tokenizer';
 import {
   ELSER_MODEL_ID,
   INDEX_QUEUED_DOCUMENTS_TASK_ID,
@@ -34,8 +35,8 @@ export interface RecalledEntry {
   id: string;
   text: string;
   score: number | null;
-  is_correction: boolean;
-  labels: Record<string, string>;
+  is_correction?: boolean;
+  labels?: Record<string, string>;
 }

 function isAlreadyExistsError(error: Error) {
@@ -291,64 +292,204 @@ export class KnowledgeBaseService {
     }
   };

+  private async recallFromKnowledgeBase({
+    queries,
+    contexts,
+    namespace,
+    user,
+    modelId,
+  }: {
+    queries: string[];
+    contexts?: string[];
+    namespace: string;
+    user: { name: string };
+    modelId: string;
+  }): Promise<RecalledEntry[]> {
+    const query = {
+      bool: {
+        should: queries.map((text) => ({
+          text_expansion: {
+            'ml.tokens': {
+              model_text: text,
+              model_id: modelId,
+            },
+          } as unknown as QueryDslTextExpansionQuery,
+        })),
+        filter: [
+          ...getAccessQuery({
+            user,
+            namespace,
+          }),
+          ...getCategoryQuery({ contexts }),
+        ],
+      },
+    };
+
+    const response = await this.dependencies.esClient.search<
+      Pick<RecalledEntry, 'text' | 'is_correction' | 'labels'>
+    >({
+      index: [this.dependencies.resources.aliases.kb],
+      query,
+      size: 20,
+      _source: {
+        includes: ['text', 'is_correction', 'labels'],
+      },
+    });
+
+    return response.hits.hits.map((hit) => ({
+      ...hit._source!,
+      score: hit._score!,
+      id: hit._id,
+    }));
+  }
+
+  private async recallFromConnectors({
+    queries,
+    asCurrentUser,
+    modelId,
+  }: {
+    queries: string[];
+    asCurrentUser: ElasticsearchClient;
+    modelId: string;
+  }): Promise<RecalledEntry[]> {
+    const ML_INFERENCE_PREFIX = 'ml.inference.';
+
+    const fieldCaps = await asCurrentUser.fieldCaps({
+      index: 'search-*',
+      fields: `${ML_INFERENCE_PREFIX}*`,
+      allow_no_indices: true,
+      types: ['sparse_vector'],
+      filters: '-metadata,-parent',
+    });
+
+    const fieldsWithVectors = Object.keys(fieldCaps.fields).map((field) =>
+      field.replace('_expanded.predicted_value', '').replace(ML_INFERENCE_PREFIX, '')
+    );
+
+    if (!fieldsWithVectors.length) {
+      return [];
+    }
+
+    const esQueries = fieldsWithVectors.flatMap((field) => {
+      const vectorField = `${ML_INFERENCE_PREFIX}${field}_expanded.predicted_value`;
+      const modelField = `${ML_INFERENCE_PREFIX}${field}_expanded.model_id`;
+
+      return queries.map((query) => {
+        return {
+          bool: {
+            should: [
+              {
+                text_expansion: {
+                  [vectorField]: {
+                    model_text: query,
+                    model_id: modelId,
+                  },
+                } as unknown as QueryDslTextExpansionQuery,
+              },
+            ],
+            filter: [
+              {
+                term: {
+                  [modelField]: modelId,
+                },
+              },
+            ],
+          },
+        };
+      });
+    });
+
+    const response = await asCurrentUser.search({
+      index: 'search-*',
+      query: {
+        bool: {
+          should: esQueries,
+        },
+      },
+      size: 20,
+      _source: {
+        exclude: ['_*', 'ml*'],
+      },
+    });
+
+    return response.hits.hits.map((hit) => ({
+      text: JSON.stringify(hit._source),
+      score: hit._score!,
+      is_correction: false,
+      id: hit._id,
+    }));
+  }
+
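Each `(field, query)` pair in `recallFromConnectors` above expands into one `bool` clause. For a hypothetical connector index with a vectorized `title` field, the generated clause looks roughly like the following (the query text and `.elser_model_1` model ID are illustrative):

```ts
// Roughly the clause produced for field `title` and one user query.
const clause = {
  bool: {
    should: [
      {
        text_expansion: {
          'ml.inference.title_expanded.predicted_value': {
            model_text: 'how do I configure an APM agent?',
            model_id: '.elser_model_1',
          },
        },
      },
    ],
    // Only match documents embedded with the same model we query with;
    // scores against vectors from a different model would be meaningless.
    filter: [{ term: { 'ml.inference.title_expanded.model_id': '.elser_model_1' } }],
  },
};
```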
   recall = async ({
     user,
     queries,
     contexts,
     namespace,
+    asCurrentUser,
   }: {
     queries: string[];
     contexts?: string[];
     user: { name: string };
     namespace: string;
+    asCurrentUser: ElasticsearchClient;
   }): Promise<{
     entries: RecalledEntry[];
   }> => {
-    try {
-      const query = {
-        bool: {
-          should: queries.map((text) => ({
-            text_expansion: {
-              'ml.tokens': {
-                model_text: text,
-                model_id: ELSER_MODEL_ID,
-              },
-            } as unknown as QueryDslTextExpansionQuery,
-          })),
-          filter: [
-            ...getAccessQuery({
-              user,
-              namespace,
-            }),
-            ...getCategoryQuery({ contexts }),
-          ],
-        },
-      };
-
-      const response = await this.dependencies.esClient.search<
-        Pick<RecalledEntry, 'text' | 'is_correction' | 'labels'>
-      >({
-        index: [this.dependencies.resources.aliases.kb],
-        query,
-        size: 20,
-        _source: {
-          includes: ['text', 'is_correction', 'labels'],
-        },
-      });
-
-      return {
-        entries: response.hits.hits.map((hit) => ({
-          ...hit._source!,
-          score: hit._score!,
-          id: hit._id,
-        })),
-      };
-    } catch (error) {
-      if (isAlreadyExistsError(error)) {
-        throwKnowledgeBaseNotReady(error.body);
+    const modelId = ELSER_MODEL_ID;
+
+    const [documentsFromKb, documentsFromConnectors] = await Promise.all([
+      this.recallFromKnowledgeBase({
+        user,
+        queries,
+        contexts,
+        namespace,
+        modelId,
+      }).catch((error) => {
+        if (isAlreadyExistsError(error)) {
+          throwKnowledgeBaseNotReady(error.body);
+        }
+        throw error;
+      }),
+      this.recallFromConnectors({
+        asCurrentUser,
+        queries,
+        modelId,
+      }).catch((error) => {
+        this.dependencies.logger.debug('Error getting data from search indices');
+        this.dependencies.logger.debug(error);
+        return [];
+      }),
+    ]);
+
+    const sortedEntries = orderBy(
+      documentsFromKb.concat(documentsFromConnectors),
+      'score',
+      'desc'
+    ).slice(0, 20);
+
+    const MAX_TOKENS = 4000;
+
+    let tokenCount = 0;
+
+    const returnedEntries: RecalledEntry[] = [];
+
+    for (const entry of sortedEntries) {
+      returnedEntries.push(entry);
+      tokenCount += encode(entry.text).length;
+      if (tokenCount >= MAX_TOKENS) {
+        break;
       }
-      throw error;
     }
+
+    if (returnedEntries.length < sortedEntries.length) {
+      this.dependencies.logger.debug(
+        `Dropped ${sortedEntries.length - returnedEntries.length} entries because of token limit`
+      );
+    }
+
+    return {
+      entries: returnedEntries,
+    };
   };

   getEntries = async ({
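The token budgeting inside `recall` above, isolated into a standalone sketch with a hypothetical `capByTokens` helper. Note the design choice it makes explicit: the entry that crosses the budget is still included, so at least one entry is always returned and the total can slightly exceed the cap.

```ts
import { encode } from 'gpt-tokenizer';

// Simplified version of the budgeting loop in `recall`: keep the
// highest-scoring entries until the running token count reaches the budget.
function capByTokens<T extends { text: string }>(entries: T[], maxTokens = 4000): T[] {
  const kept: T[] = [];
  let tokenCount = 0;

  for (const entry of entries) {
    kept.push(entry); // include the entry that crosses the limit
    tokenCount += encode(entry.text).length;
    if (tokenCount >= maxTokens) {
      break;
    }
  }

  return kept;
}
```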