From 13382875e99e8c97f4574d86eca07cac3be9edfc Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Sat, 15 Jun 2024 12:16:50 -0400 Subject: [PATCH] [Obs AI Assistant] Expose recall function as API (#185058) Exposes a `POST /internal/observability_ai_assistant/chat/recall` endpoint for [Investigate UI ](https://github.com/elastic/kibana/pull/183293). It is mostly just moving stuff around, some small refactorings and a new way to generate short ids. Previously we were using indexes for scoring suggestions, we are now generating a short but unique id (ie 4-5 chars) which generates a fairly unique token which strengthens the relationship between the id and the object but still allows for quick output. LLMs are slow to generate UUIDs, but indexes are very generic and the LLM might not pay a lot of attention to it. --- .../common/index.ts | 4 +- .../common/types.ts | 7 +- .../concatenate_chat_completion_chunks.ts | 3 +- .../common/utils/short_id_table.test.ts | 48 ++++ .../common/utils/short_id_table.ts | 56 +++++ ...throw_serialized_chat_completion_errors.ts | 2 +- .../common/utils/until_aborted.ts | 24 ++ .../public/components/assistant_avatar.tsx | 4 +- .../public/hooks/use_abortable_async.ts | 6 +- .../public/index.ts | 25 +- .../public/mock.tsx | 7 +- .../public/service/complete.test.ts | 18 +- .../public/service/complete.ts | 30 +-- .../public/service/create_chat_service.ts | 176 +++++++------- .../public/storybook_mock.tsx | 4 +- .../public/types.ts | 23 +- .../utils/create_screen_context_action.ts | 6 +- .../server/functions/context.ts | 226 ++---------------- .../get_relevant_field_names.ts | 32 ++- .../server/functions/index.ts | 8 +- .../server/routes/chat/route.ts | 200 ++++++++++------ .../client/adapters/process_openai_stream.ts | 34 ++- .../server/service/client/index.test.ts | 23 +- .../server/service/client/index.ts | 10 +- .../client/operators/continue_conversation.ts | 20 +- .../service/knowledge_base_service/index.ts | 8 +- 
...t_system_message_from_instructions.test.ts | 13 +- .../get_system_message_from_instructions.ts | 31 ++- .../recall}/parse_suggestion_scores.test.ts | 28 +-- .../recall}/parse_suggestion_scores.ts | 8 +- .../server/utils/recall/recall_and_score.ts | 89 +++++++ .../utils/recall/retrieve_suggestions.ts | 24 ++ .../server/utils/recall/score_suggestions.ts | 164 +++++++++++++ .../server/utils/recall/types.ts | 10 + .../public/functions/visualize_esql.test.tsx | 35 +-- 35 files changed, 897 insertions(+), 509 deletions(-) create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/short_id_table.test.ts create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/short_id_table.ts create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/until_aborted.ts rename x-pack/plugins/observability_solution/observability_ai_assistant/server/{functions => utils/recall}/parse_suggestion_scores.test.ts (71%) rename x-pack/plugins/observability_solution/observability_ai_assistant/server/{functions => utils/recall}/parse_suggestion_scores.ts (77%) create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/recall_and_score.ts create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/retrieve_suggestions.ts create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/score_suggestions.ts create mode 100644 x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/types.ts diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/index.ts index e29aa4c2e1bc9..cfb4987862535 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/common/index.ts +++ 
b/x-pack/plugins/observability_solution/observability_ai_assistant/common/index.ts @@ -8,7 +8,7 @@ export type { Message, Conversation, KnowledgeBaseEntry } from './types'; export type { ConversationCreateRequest } from './types'; export { KnowledgeBaseEntryRole, MessageRole } from './types'; -export type { FunctionDefinition } from './functions/types'; +export type { FunctionDefinition, CompatibleJSONSchema } from './functions/types'; export { FunctionVisibility } from './functions/function_visibility'; export { VISUALIZE_ESQL_USER_INTENTIONS, @@ -49,3 +49,5 @@ export { concatenateChatCompletionChunks } from './utils/concatenate_chat_comple export { DEFAULT_LANGUAGE_OPTION, LANGUAGE_OPTIONS } from './ui_settings/language_options'; export { isSupportedConnectorType } from './connectors'; + +export { ShortIdTable } from './utils/short_id_table'; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/types.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/types.ts index ea6c754193341..bd1a284b0d363 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/common/types.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/types.ts @@ -95,6 +95,7 @@ export interface KnowledgeBaseEntry { export interface UserInstruction { doc_id: string; text: string; + system?: boolean; } export type UserInstructionOrPlainText = string | UserInstruction; @@ -109,7 +110,7 @@ export interface ObservabilityAIAssistantScreenContextRequest { actions?: Array<{ name: string; description: string; parameters?: CompatibleJSONSchema }>; } -export type ScreenContextActionRespondFunction = ({}: { +export type ScreenContextActionRespondFunction = ({}: { args: TArguments; signal: AbortSignal; connectorId: string; @@ -117,7 +118,7 @@ export type ScreenContextActionRespondFunction = ({} messages: Message[]; }) => Promise; -export interface ScreenContextActionDefinition { +export interface 
ScreenContextActionDefinition { name: string; description: string; parameters?: CompatibleJSONSchema; @@ -137,6 +138,6 @@ export interface ObservabilityAIAssistantScreenContext { description: string; value: any; }>; - actions?: ScreenContextActionDefinition[]; + actions?: Array>; starterPrompts?: StarterPrompt[]; } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/concatenate_chat_completion_chunks.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/concatenate_chat_completion_chunks.ts index 8686ff93afb34..bead0974b91a3 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/concatenate_chat_completion_chunks.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/concatenate_chat_completion_chunks.ts @@ -31,6 +31,7 @@ export const concatenateChatCompletionChunks = acc.message.content += message.content ?? ''; acc.message.function_call.name += message.function_call?.name ?? ''; acc.message.function_call.arguments += message.function_call?.arguments ?? ''; + return cloneDeep(acc); }, { @@ -43,6 +44,6 @@ export const concatenateChatCompletionChunks = }, role: MessageRole.Assistant, }, - } + } as ConcatenatedMessage ) ); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/short_id_table.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/short_id_table.test.ts new file mode 100644 index 0000000000000..784cf67530652 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/short_id_table.test.ts @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +import { ShortIdTable } from './short_id_table'; + +describe('shortIdTable', () => { + it('generates at least 10k unique ids consistently', () => { + const ids = new Set(); + + const table = new ShortIdTable(); + + let i = 10_000; + while (i--) { + const id = table.take(String(i)); + ids.add(id); + } + + expect(ids.size).toBe(10_000); + }); + + it('returns the original id based on the generated id', () => { + const table = new ShortIdTable(); + + const idsByOriginal = new Map(); + + let i = 100; + while (i--) { + const id = table.take(String(i)); + idsByOriginal.set(String(i), id); + } + + expect(idsByOriginal.size).toBe(100); + + expect(() => { + Array.from(idsByOriginal.entries()).forEach(([originalId, shortId]) => { + const returnedOriginalId = table.lookup(shortId); + if (returnedOriginalId !== originalId) { + throw Error( + `Expected shortId ${shortId} to return ${originalId}, but ${returnedOriginalId} was returned instead` + ); + } + }); + }).not.toThrow(); + }); +}); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/short_id_table.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/short_id_table.ts new file mode 100644 index 0000000000000..30049452ddf51 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/short_id_table.ts @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +const ALPHABET = 'abcdefghijklmnopqrstuvwxyz'; + +function generateShortId(size: number): string { + let id = ''; + let i = size; + while (i--) { + const index = Math.floor(Math.random() * ALPHABET.length); + id += ALPHABET[index]; + } + return id; +} + +const MAX_ATTEMPTS_AT_LENGTH = 100; + +export class ShortIdTable { + private byShortId: Map = new Map(); + private byOriginalId: Map = new Map(); + + constructor() {} + + take(originalId: string) { + if (this.byOriginalId.has(originalId)) { + return this.byOriginalId.get(originalId)!; + } + + let uniqueId: string | undefined; + let attemptsAtLength = 0; + let length = 4; + while (!uniqueId) { + const nextId = generateShortId(length); + attemptsAtLength++; + if (!this.byShortId.has(nextId)) { + uniqueId = nextId; + } else if (attemptsAtLength >= MAX_ATTEMPTS_AT_LENGTH) { + attemptsAtLength = 0; + length++; + } + } + + this.byShortId.set(uniqueId, originalId); + this.byOriginalId.set(originalId, uniqueId); + + return uniqueId; + } + + lookup(shortId: string) { + return this.byShortId.get(shortId); + } +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/throw_serialized_chat_completion_errors.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/throw_serialized_chat_completion_errors.ts index 2c23109a1bac0..e137a4cce1f75 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/throw_serialized_chat_completion_errors.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/throw_serialized_chat_completion_errors.ts @@ -21,7 +21,7 @@ export function throwSerializedChatCompletionErrors< return (source$) => source$.pipe( tap((event) => { - // de-serialise error + // de-serialize error if (event.type === StreamingChatResponseEventType.ChatCompletionError) { const code = event.error.code ?? 
ChatCompletionErrorCode.InternalError; const message = event.error.message; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/until_aborted.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/until_aborted.ts new file mode 100644 index 0000000000000..d5e3ff9e18bd4 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/common/utils/until_aborted.ts @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { Observable, OperatorFunction, takeUntil } from 'rxjs'; +import { AbortError } from '@kbn/kibana-utils-plugin/common'; + +export function untilAborted(signal: AbortSignal): OperatorFunction { + return (source$) => { + const signal$ = new Observable((subscriber) => { + if (signal.aborted) { + subscriber.error(new AbortError()); + } + signal.addEventListener('abort', () => { + subscriber.error(new AbortError()); + }); + }); + + return source$.pipe(takeUntil(signal$)); + }; +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/components/assistant_avatar.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant/public/components/assistant_avatar.tsx index 64ac351bad0a4..c9b0b21e70bcd 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/components/assistant_avatar.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/components/assistant_avatar.tsx @@ -10,6 +10,7 @@ export interface AssistantAvatarProps { size?: keyof typeof sizeMap; children?: ReactNode; css?: React.SVGProps['css']; + className?: string; } export const sizeMap = { @@ -20,7 +21,7 @@ export const sizeMap = { xs: 16, }; -export function AssistantAvatar({ size = 's', css }: 
AssistantAvatarProps) { +export function AssistantAvatar({ size = 's', css, className }: AssistantAvatarProps) { const sizePx = sizeMap[size]; return ( diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/hooks/use_abortable_async.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/hooks/use_abortable_async.ts index afd776dc13990..433ca877b0f62 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/hooks/use_abortable_async.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/hooks/use_abortable_async.ts @@ -40,6 +40,7 @@ export function useAbortableAsync( if (clearValueOnNext) { setValue(undefined); + setError(undefined); } try { @@ -47,7 +48,10 @@ export function useAbortableAsync( if (isPromise(response)) { setLoading(true); response - .then(setValue) + .then((nextValue) => { + setError(undefined); + setValue(nextValue); + }) .catch((err) => { setValue(undefined); setError(err); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/index.ts index 7f0c1f8bb4c09..2e604b59fc7ab 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/index.ts @@ -5,7 +5,6 @@ * 2.0. 
*/ import type { PluginInitializer, PluginInitializerContext } from '@kbn/core/public'; -export type { CompatibleJSONSchema } from '../common/functions/types'; import { ObservabilityAIAssistantPlugin } from './plugin'; import type { @@ -18,6 +17,7 @@ import type { ObservabilityAIAssistantChatService, RegisterRenderFunctionDefinition, RenderFunction, + DiscoveredDataset, } from './types'; export type { @@ -27,6 +27,7 @@ export type { ObservabilityAIAssistantChatService, RegisterRenderFunctionDefinition, RenderFunction, + DiscoveredDataset, }; export { aiAssistantCapabilities } from '../common/capabilities'; @@ -59,15 +60,27 @@ export { VISUALIZE_ESQL_USER_INTENTIONS, } from '../common/functions/visualize_esql'; -export { isSupportedConnectorType } from '../common'; -export { FunctionVisibility } from '../common'; +export { + isSupportedConnectorType, + FunctionVisibility, + MessageRole, + KnowledgeBaseEntryRole, + concatenateChatCompletionChunks, + StreamingChatResponseEventType, +} from '../common'; +export type { + CompatibleJSONSchema, + Conversation, + Message, + KnowledgeBaseEntry, + FunctionDefinition, + ChatCompletionChunkEvent, + ShortIdTable, +} from '../common'; export type { TelemetryEventTypeWithPayload } from './analytics'; export { ObservabilityAIAssistantTelemetryEventType } from './analytics/telemetry_event_type'; -export type { Conversation, Message, KnowledgeBaseEntry } from '../common'; -export { MessageRole, KnowledgeBaseEntryRole } from '../common'; - export { createFunctionRequestMessage } from '../common/utils/create_function_request_message'; export { createFunctionResponseMessage } from '../common/utils/create_function_response_message'; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/mock.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant/public/mock.tsx index 4775ad1b551b1..31907f54c49bf 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/mock.tsx +++ 
b/x-pack/plugins/observability_solution/observability_ai_assistant/public/mock.tsx @@ -8,7 +8,10 @@ import { i18n } from '@kbn/i18n'; import { noop } from 'lodash'; import React from 'react'; import { Observable, of } from 'rxjs'; -import type { StreamingChatResponseEventWithoutError } from '../common/conversation_complete'; +import type { + ChatCompletionChunkEvent, + StreamingChatResponseEventWithoutError, +} from '../common/conversation_complete'; import { MessageRole, ScreenContextActionDefinition } from '../common/types'; import type { ObservabilityAIAssistantAPIClient } from './api'; import type { @@ -21,7 +24,7 @@ import { buildFunctionElasticsearch, buildFunctionServiceSummary } from './utils export const mockChatService: ObservabilityAIAssistantChatService = { sendAnalyticsEvent: noop, - chat: (options) => new Observable(), + chat: (options) => new Observable(), complete: (options) => new Observable(), getFunctions: () => [buildFunctionElasticsearch(), buildFunctionServiceSummary()], renderFunction: (name) => ( diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts index 421770cf415c7..4665f9b7b486b 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.test.ts @@ -15,6 +15,8 @@ import { ChatCompletionError, MessageAddEvent, createInternalServerError, + createConversationNotFoundError, + StreamingChatResponseEventWithoutError, } from '../../common'; import type { ObservabilityAIAssistantChatService } from '../types'; import { complete } from './complete'; @@ -45,7 +47,7 @@ const messages: Message[] = [ const createLlmResponse = ( chunks: Array<{ content: string; function_call?: { name: string; arguments: string } }> -): StreamingChatResponseEvent[] => { +): 
StreamingChatResponseEventWithoutError[] => { const id = v4(); const message = chunks.reduce( (prev, current) => { @@ -61,7 +63,7 @@ const createLlmResponse = ( } ); - const events: StreamingChatResponseEvent[] = [ + const events: StreamingChatResponseEventWithoutError[] = [ ...chunks.map((msg) => ({ id, message: msg, @@ -108,20 +110,12 @@ describe('complete', () => { describe('when an error is emitted', () => { beforeEach(() => { - requestCallback.mockImplementation(() => - of({ - type: StreamingChatResponseEventType.ChatCompletionError, - error: { - message: 'Not found', - code: ChatCompletionErrorCode.NotFoundError, - }, - }) - ); + requestCallback.mockImplementation(() => throwError(() => createConversationNotFoundError())); }); it('the observable errors out', async () => { await expect(async () => await lastValueFrom(callComplete())).rejects.toThrowError( - 'Not found' + 'Conversation not found' ); await expect(async () => await lastValueFrom(callComplete())).rejects.toBeInstanceOf( diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.ts index 8d9efd033bad7..90a8f16639ed6 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/complete.ts @@ -20,19 +20,16 @@ import { import { MessageRole, StreamingChatResponseEventType, - type BufferFlushEvent, type ConversationCreateEvent, type ConversationUpdateEvent, type Message, type MessageAddEvent, - type StreamingChatResponseEvent, type StreamingChatResponseEventWithoutError, } from '../../common'; -import { ObservabilityAIAssistantScreenContext } from '../../common/types'; +import type { ObservabilityAIAssistantScreenContext } from '../../common/types'; import { createFunctionResponseMessage } from 
'../../common/utils/create_function_response_message'; -import { throwSerializedChatCompletionErrors } from '../../common/utils/throw_serialized_chat_completion_errors'; import type { ObservabilityAIAssistantAPIClientRequestParamsOf } from '../api'; -import { ObservabilityAIAssistantChatService } from '../types'; +import type { ObservabilityAIAssistantChatService } from '../types'; import { createPublicFunctionResponseError } from '../utils/create_function_response_error'; export function complete( @@ -46,20 +43,14 @@ export function complete( disableFunctions, signal, responseLanguage, + instructions, }: { client: Pick; getScreenContexts: () => ObservabilityAIAssistantScreenContext[]; - connectorId: string; - conversationId?: string; - messages: Message[]; - persist: boolean; - disableFunctions: boolean; - signal: AbortSignal; - responseLanguage: string; - }, + } & Parameters[0], requestCallback: ( params: ObservabilityAIAssistantAPIClientRequestParamsOf<'POST /internal/observability_ai_assistant/chat/complete'> - ) => Observable + ) => Observable ): Observable { return new Observable((subscriber) => { const screenContexts = getScreenContexts(); @@ -75,16 +66,10 @@ export function complete( screenContexts, conversationId, responseLanguage, + instructions, }, }, - }).pipe( - filter( - (event): event is StreamingChatResponseEvent => - event.type !== StreamingChatResponseEventType.BufferFlush - ), - throwSerializedChatCompletionErrors(), - shareReplay() - ); + }).pipe(shareReplay()); const messages$ = response$.pipe( filter( @@ -148,6 +133,7 @@ export function complete( persist, responseLanguage, disableFunctions, + instructions, }, requestCallback ).subscribe(subscriber); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts index 45fa95aa72a17..db3c8b1f5bbf3 100644 --- 
a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_chat_service.ts @@ -6,10 +6,9 @@ */ import type { AnalyticsServiceStart, HttpResponse } from '@kbn/core/public'; -import { AbortError } from '@kbn/kibana-utils-plugin/common'; import type { IncomingMessage } from 'http'; -import { pick } from 'lodash'; import { + catchError, concatMap, delay, filter, @@ -17,27 +16,30 @@ import { map, Observable, of, + OperatorFunction, scan, shareReplay, switchMap, + throwError, timestamp, } from 'rxjs'; -import { Message, MessageRole } from '../../common'; +import { ChatCompletionChunkEvent, Message, MessageRole } from '../../common'; import { - type BufferFlushEvent, StreamingChatResponseEventType, - type StreamingChatResponseEventWithoutError, + type BufferFlushEvent, type StreamingChatResponseEvent, + type StreamingChatResponseEventWithoutError, } from '../../common/conversation_complete'; -import { - FunctionRegistry, - FunctionResponse, - FunctionVisibility, -} from '../../common/functions/types'; +import { FunctionRegistry, FunctionResponse } from '../../common/functions/types'; import { filterFunctionDefinitions } from '../../common/utils/filter_function_definitions'; import { throwSerializedChatCompletionErrors } from '../../common/utils/throw_serialized_chat_completion_errors'; +import { untilAborted } from '../../common/utils/until_aborted'; import { sendEvent } from '../analytics'; -import type { ObservabilityAIAssistantAPIClient } from '../api'; +import type { + ObservabilityAIAssistantAPIClient, + ObservabilityAIAssistantAPIClientRequestParamsOf, + ObservabilityAIAssistantAPIEndpoint, +} from '../api'; import type { ChatRegistrationRenderFunction, ObservabilityAIAssistantChatService, @@ -91,6 +93,45 @@ function toObservable(response: HttpResponse) { ); } +function serialize( + signal: AbortSignal +): OperatorFunction { + 
return (source$) => + source$.pipe( + catchError((error) => { + if ( + 'response' in error && + 'json' in error.response && + typeof error.response.json === 'function' + ) { + const responseBodyPromise = (error.response as HttpResponse['response'])!.json(); + + return from( + responseBodyPromise.then((body: { message?: string }) => { + if (body) { + error.body = body; + if (body.message) { + error.message = body.message; + } + } + throw error; + }) + ); + } + return throwError(() => error); + }), + switchMap((readable) => toObservable(readable as HttpResponse)), + map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent), + filter( + (line): line is Exclude => + line.type !== StreamingChatResponseEventType.BufferFlush + ), + throwSerializedChatCompletionErrors(), + untilAborted(signal), + shareReplay() + ); +} + export async function createChatService({ analytics, signal: setupAbortSignal, @@ -130,73 +171,39 @@ export async function createChatService({ }); }; - const client: Pick = { - chat(name: string, { connectorId, messages, function: callFunctions = 'auto', signal }) { - return new Observable((subscriber) => { - const functions = getFunctions().filter((fn) => { - const visibility = fn.visibility ?? FunctionVisibility.All; - - return ( - visibility === FunctionVisibility.All || visibility === FunctionVisibility.AssistantOnly - ); - }); + function callStreamingApi( + endpoint: TEndpoint, + options: { + signal: AbortSignal; + } & ObservabilityAIAssistantAPIClientRequestParamsOf + ): Observable { + return from( + apiClient(endpoint, { + ...options, + asResponse: true, + rawResponse: true, + }) + ).pipe(serialize(options.signal)); + } - apiClient('POST /internal/observability_ai_assistant/chat', { - params: { - body: { - name, - messages, - connectorId, - functions: - callFunctions === 'none' - ? 
[] - : functions.map((fn) => pick(fn, 'name', 'description', 'parameters')), - }, + const client: Pick = { + chat(name: string, { connectorId, messages, functionCall, functions, signal }) { + return callStreamingApi('POST /internal/observability_ai_assistant/chat', { + params: { + body: { + name, + messages, + connectorId, + functionCall, + functions: functions ?? [], }, - signal, - asResponse: true, - rawResponse: true, - }) - .then((_response) => { - const response = _response as unknown as HttpResponse; - - const subscription = toObservable(response) - .pipe( - map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent), - filter( - (line): line is StreamingChatResponseEvent => - line.type !== StreamingChatResponseEventType.BufferFlush && - line.type !== StreamingChatResponseEventType.TokenCount - ), - throwSerializedChatCompletionErrors() - ) - .subscribe(subscriber); - - // if the request is aborted, convert that into state as well - signal.addEventListener('abort', () => { - subscriber.error(new AbortError()); - subscription.unsubscribe(); - }); - }) - .catch(async (err) => { - if ('response' in err) { - const body = await (err.response as HttpResponse['response'])?.json(); - err.body = body; - if (body.message) { - err.message = body.message; - } - } - throw err; - }) - .catch((err) => { - subscriber.error(err); - }); - - return subscriber; + }, + signal, }).pipe( - // make sure the request is only triggered once, - // even with multiple subscribers - shareReplay() + filter( + (line): line is ChatCompletionChunkEvent => + line.type === StreamingChatResponseEventType.ChatCompletionChunk + ) ); }, complete({ @@ -208,6 +215,7 @@ export async function createChatService({ disableFunctions, signal, responseLanguage, + instructions, }) { return complete( { @@ -220,21 +228,13 @@ export async function createChatService({ signal, client, responseLanguage, + instructions, }, ({ params }) => { - return from( - apiClient('POST 
/internal/observability_ai_assistant/chat/complete', { - params, - signal, - asResponse: true, - rawResponse: true, - }) - ).pipe( - map((_response) => toObservable(_response as unknown as HttpResponse)), - switchMap((response$) => response$), - map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent), - shareReplay() - ); + return callStreamingApi('POST /internal/observability_ai_assistant/chat/complete', { + params, + signal, + }); } ); }, diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/storybook_mock.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant/public/storybook_mock.tsx index 6cad5a52ed2f8..d3b52f6803621 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/storybook_mock.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/storybook_mock.tsx @@ -8,7 +8,7 @@ import { i18n } from '@kbn/i18n'; import { noop } from 'lodash'; import React from 'react'; import { Observable, of } from 'rxjs'; -import { MessageRole } from '.'; +import { ChatCompletionChunkEvent, MessageRole } from '.'; import type { StreamingChatResponseEventWithoutError } from '../common/conversation_complete'; import type { ObservabilityAIAssistantAPIClient } from './api'; import type { ObservabilityAIAssistantChatService, ObservabilityAIAssistantService } from './types'; @@ -16,7 +16,7 @@ import { buildFunctionElasticsearch, buildFunctionServiceSummary } from './utils export const createStorybookChatService = (): ObservabilityAIAssistantChatService => ({ sendAnalyticsEvent: () => {}, - chat: (options) => new Observable(), + chat: (options) => new Observable(), complete: (options) => new Observable(), getFunctions: () => [buildFunctionElasticsearch(), buildFunctionServiceSummary()], renderFunction: (name) => ( diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/types.ts 
b/x-pack/plugins/observability_solution/observability_ai_assistant/public/types.ts index bfafbc4772462..8480af2e02327 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/types.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/types.ts @@ -9,6 +9,7 @@ import type { LicensingPluginStart } from '@kbn/licensing-plugin/public'; import type { SecurityPluginSetup, SecurityPluginStart } from '@kbn/security-plugin/public'; import type { Observable } from 'rxjs'; import type { + ChatCompletionChunkEvent, MessageAddEvent, StreamingChatResponseEventWithoutError, } from '../common/conversation_complete'; @@ -17,6 +18,7 @@ import type { Message, ObservabilityAIAssistantScreenContext, PendingMessage, + UserInstructionOrPlainText, } from '../common/types'; import type { TelemetryEventTypeWithPayload } from './analytics'; import type { ObservabilityAIAssistantAPIClient } from './api'; @@ -34,6 +36,13 @@ import { createScreenContextAction } from './utils/create_screen_context_action' export type { PendingMessage }; +export interface DiscoveredDataset { + title: string; + description: string; + indexPatterns: string[]; + columns: unknown[]; +} + export interface ObservabilityAIAssistantChatService { sendAnalyticsEvent: (event: TelemetryEventTypeWithPayload) => void; chat: ( @@ -41,19 +50,25 @@ export interface ObservabilityAIAssistantChatService { options: { messages: Message[]; connectorId: string; - function?: 'none' | 'auto'; + functions?: Array>; + functionCall?: string; signal: AbortSignal; } - ) => Observable; + ) => Observable; complete: (options: { getScreenContexts: () => ObservabilityAIAssistantScreenContext[]; conversationId?: string; connectorId: string; messages: Message[]; persist: boolean; - disableFunctions: boolean; + disableFunctions: + | boolean + | { + except: string[]; + }; signal: AbortSignal; - responseLanguage: string; + responseLanguage?: string; + instructions?: UserInstructionOrPlainText[]; }) 
=> Observable; getFunctions: (options?: { contexts?: string[]; filter?: string }) => FunctionDefinition[]; hasFunction: (name: string) => boolean; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/utils/create_screen_context_action.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/utils/create_screen_context_action.ts index 3dbc4dbaf36f0..fcd6e8dd7bb80 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/utils/create_screen_context_action.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/utils/create_screen_context_action.ts @@ -18,11 +18,11 @@ type ReturnOf, - TResponse = ReturnOf + TRespondFunction extends ScreenContextActionRespondFunction> >( definition: TActionDefinition, - respond: ScreenContextActionRespondFunction -): ScreenContextActionDefinition { + respond: TRespondFunction +): ScreenContextActionDefinition> { return { ...definition, respond, diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts index baf006844c516..c7df07dc18d2f 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts @@ -5,24 +5,16 @@ * 2.0. 
*/ -import { decodeOrThrow, jsonRt } from '@kbn/io-ts-utils'; -import { Logger } from '@kbn/logging'; import type { Serializable } from '@kbn/utility-types'; -import dedent from 'dedent'; import { encode } from 'gpt-tokenizer'; -import * as t from 'io-ts'; -import { compact, last, omit } from 'lodash'; -import { lastValueFrom, Observable } from 'rxjs'; +import { compact, last } from 'lodash'; +import { Observable } from 'rxjs'; import { FunctionRegistrationParameters } from '.'; import { MessageAddEvent } from '../../common/conversation_complete'; import { FunctionVisibility } from '../../common/functions/types'; -import { MessageRole, type Message } from '../../common/types'; -import { concatenateChatCompletionChunks } from '../../common/utils/concatenate_chat_completion_chunks'; +import { MessageRole } from '../../common/types'; import { createFunctionResponseMessage } from '../../common/utils/create_function_response_message'; -import { RecallRanking, RecallRankingEventType } from '../analytics/recall_ranking'; -import type { ObservabilityAIAssistantClient } from '../service/client'; -import { FunctionCallChatFunction } from '../service/types'; -import { parseSuggestionScores } from './parse_suggestion_scores'; +import { recallAndScore } from '../utils/recall/recall_and_score'; const MAX_TOKEN_COUNT_FOR_DATA_ON_SCREEN = 1000; @@ -70,55 +62,26 @@ export function registerContextFunction({ messages.filter((message) => message.message.role === MessageRole.User) ); - const userPrompt = userMessage?.message.content; - const queries = [{ text: userPrompt, boost: 3 }, { text: screenDescription }].filter( - ({ text }) => text - ) as Array<{ text: string; boost?: number }>; - - const suggestions = await retrieveSuggestions({ client, queries }); - if (suggestions.length === 0) { - return { content }; - } - - try { - const { relevantDocuments, scores } = await scoreSuggestions({ + const userPrompt = userMessage?.message.content!; + + const { scores, relevantDocuments, 
suggestions } = await recallAndScore({ + recall: client.recall, + chat, + logger: resources.logger, + userPrompt, + context: screenDescription, + messages, + signal, + analytics, + }); + + return { + content: { ...content, learnings: relevantDocuments as unknown as Serializable }, + data: { + scores, suggestions, - screenDescription, - userPrompt, - messages, - chat, - signal, - logger: resources.logger, - }); - - analytics.reportEvent(RecallRankingEventType, { - prompt: queries.map((query) => query.text).join('|'), - scoredDocuments: suggestions.map((suggestion) => { - const llmScore = scores.find((score) => score.id === suggestion.id); - return { - content: suggestion.text, - elserScore: suggestion.score ?? -1, - llmScore: llmScore ? llmScore.score : -1, - }; - }), - }); - - return { - content: { ...content, learnings: relevantDocuments as unknown as Serializable }, - data: { - scores, - suggestions, - }, - }; - } catch (error) { - return { - content: { ...content, learnings: suggestions.slice(0, 5) }, - data: { - error, - suggestions, - }, - }; - } + }, + }; } return new Observable((subscriber) => { @@ -141,146 +104,3 @@ export function registerContextFunction({ } ); } - -async function retrieveSuggestions({ - queries, - client, -}: { - queries: Array<{ text: string; boost?: number }>; - client: ObservabilityAIAssistantClient; -}) { - const recallResponse = await client.recall({ - queries, - }); - - return recallResponse.entries.map((entry) => omit(entry, 'labels', 'is_correction')); -} - -const scoreFunctionRequestRt = t.type({ - message: t.type({ - function_call: t.type({ - name: t.literal('score'), - arguments: t.string, - }), - }), -}); - -const scoreFunctionArgumentsRt = t.type({ - scores: t.string, -}); - -async function scoreSuggestions({ - suggestions, - messages, - userPrompt, - screenDescription, - chat, - signal, - logger, -}: { - suggestions: Awaited>; - messages: Message[]; - userPrompt: string | undefined; - screenDescription: string; - chat: 
FunctionCallChatFunction; - signal: AbortSignal; - logger: Logger; -}) { - const indexedSuggestions = suggestions.map((suggestion, index) => ({ - ...omit(suggestion, 'score'), // To not bias the LLM - id: index, - })); - - const newUserMessageContent = - dedent(`Given the following question, score the documents that are relevant to the question. on a scale from 0 to 7, - 0 being completely irrelevant, and 7 being extremely relevant. Information is relevant to the question if it helps in - answering the question. Judge it according to the following criteria: - - - The document is relevant to the question, and the rest of the conversation - - The document has information relevant to the question that is not mentioned, - or more detailed than what is available in the conversation - - The document has a high amount of information relevant to the question compared to other documents - - The document contains new information not mentioned before in the conversation - - Question: - ${userPrompt} - - Screen description: - ${screenDescription} - - Documents: - ${JSON.stringify(indexedSuggestions, null, 2)}`); - - const newUserMessage: Message = { - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.User, - content: newUserMessageContent, - }, - }; - - const scoreFunction = { - name: 'score', - description: - 'Use this function to score documents based on how relevant they are to the conversation.', - parameters: { - type: 'object', - properties: { - scores: { - description: `The document IDs and their scores, as CSV. 
Example: - - my_id,7 - my_other_id,3 - my_third_id,4 - `, - type: 'string', - }, - }, - required: ['score'], - } as const, - contexts: ['core'], - }; - - const response = await lastValueFrom( - chat('score_suggestions', { - messages: [...messages.slice(0, -2), newUserMessage], - functions: [scoreFunction], - functionCall: 'score', - signal, - }).pipe(concatenateChatCompletionChunks()) - ); - - const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response); - const { scores: scoresAsString } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))( - scoreFunctionRequest.message.function_call.arguments - ); - - const scores = parseSuggestionScores(scoresAsString).map(({ index, score }) => { - return { - id: suggestions[index].id, - score, - }; - }); - - if (scores.length === 0) { - // seemingly invalid or no scores, return all - return { relevantDocuments: suggestions, scores: [] }; - } - - const suggestionIds = suggestions.map((document) => document.id); - - const relevantDocumentIds = scores - .filter((document) => suggestionIds.includes(document.id)) // Remove hallucinated documents - .filter((document) => document.score > 4) - .sort((a, b) => b.score - a.score) - .slice(0, 5) - .map((document) => document.id); - - const relevantDocuments = suggestions.filter((suggestion) => - relevantDocumentIds.includes(suggestion.id) - ); - - logger.debug(`Relevant documents: ${JSON.stringify(relevantDocuments, null, 2)}`); - - return { relevantDocuments, scores }; -} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts index 2f32731ac3f2d..557f09784c7f9 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts +++ 
b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/get_dataset_info/get_relevant_field_names.ts @@ -9,7 +9,7 @@ import type { ElasticsearchClient, SavedObjectsClientContract } from '@kbn/core/ import type { DataViewsServerPluginStart } from '@kbn/data-views-plugin/server'; import { castArray, chunk, groupBy, uniq } from 'lodash'; import { lastValueFrom } from 'rxjs'; -import { MessageRole, type Message } from '../../../common'; +import { MessageRole, ShortIdTable, type Message } from '../../../common'; import { concatenateChatCompletionChunks } from '../../../common/utils/concatenate_chat_completion_chunks'; import { FunctionCallChatFunction } from '../../service/types'; @@ -87,8 +87,10 @@ export async function getRelevantFieldNames({ const groupedFields = groupBy(allFields, (field) => field.name); + const shortIdTable = new ShortIdTable(); + const relevantFields = await Promise.all( - chunk(fieldNames, 500).map(async (fieldsInChunk) => { + chunk(fieldNames, 250).map(async (fieldsInChunk) => { const chunkResponse$ = ( await chat('get_relevant_dataset_names', { signal, @@ -112,29 +114,31 @@ export async function getRelevantFieldNames({ role: MessageRole.User, content: `This is the list: - ${fieldsInChunk.join('\n')}`, + ${fieldsInChunk + .map((field) => JSON.stringify({ field, id: shortIdTable.take(field) })) + .join('\n')}`, }, }, ], functions: [ { - name: 'fields', - description: 'The fields you consider relevant to the conversation', + name: 'select_relevant_fields', + description: 'The IDs of the fields you consider relevant to the conversation', parameters: { type: 'object', properties: { - fields: { + fieldIds: { type: 'array', items: { type: 'string', }, }, }, - required: ['fields'], + required: ['fieldIds'], } as const, }, ], - functionCall: 'fields', + functionCall: 'select_relevant_fields', }) ).pipe(concatenateChatCompletionChunks()); @@ -143,10 +147,16 @@ export async function getRelevantFieldNames({ return 
chunkResponse.message?.function_call?.arguments ? ( JSON.parse(chunkResponse.message.function_call.arguments) as { - fields: string[]; + fieldIds: string[]; } - ).fields - .filter((field) => fieldsInChunk.includes(field)) + ).fieldIds + .map((fieldId) => { + const fieldName = shortIdTable.lookup(fieldId); + return fieldName ?? fieldId; + }) + .filter((fieldName) => { + return fieldsInChunk.includes(fieldName); + }) .map((field) => { const fieldDescriptors = groupedFields[field]; return `${field}:${fieldDescriptors.map((descriptor) => descriptor.type).join(',')}`; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/index.ts index 4cf8147d31c71..5b16b79bd9980 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/index.ts @@ -51,6 +51,9 @@ export const registerFunctions: RegistrationCallback = async ({ Note that ES|QL (the Elasticsearch Query Language which is a new piped language) is the preferred query language. + If you want to call a function or tool, only call it a single time per message. Wait until the function has been executed and its results + returned to you, before executing the same tool or another tool again if needed. + DO NOT UNDER ANY CIRCUMSTANCES USE ES|QL syntax (\`service.name == "foo"\`) with "kqlFilter" (\`service.name:"foo"\`). 
The user is able to change the language which they want you to reply in on the settings page of the AI Assistant for Observability, which can be found in the ${ @@ -63,7 +66,10 @@ export const registerFunctions: RegistrationCallback = async ({ functions.registerInstruction(({ availableFunctionNames }) => { const instructions: string[] = []; - if (availableFunctionNames.includes(GET_DATASET_INFO_FUNCTION_NAME)) { + if ( + availableFunctionNames.includes(QUERY_FUNCTION_NAME) && + availableFunctionNames.includes(GET_DATASET_INFO_FUNCTION_NAME) + ) { instructions.push(`You MUST use the "${GET_DATASET_INFO_FUNCTION_NAME}" ${ functions.hasFunction('get_apm_dataset_info') ? 'or the get_apm_dataset_info' : '' } function before calling the "${QUERY_FUNCTION_NAME}" or the "changes" functions. diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts index 41d3a6eaea5ce..f1758c1583f71 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/chat/route.ts @@ -6,20 +6,22 @@ */ import { notImplemented } from '@hapi/boom'; import { toBooleanRt } from '@kbn/io-ts-utils'; +import { context as otelContext } from '@opentelemetry/api'; import * as t from 'io-ts'; +import { from, map } from 'rxjs'; import { Readable } from 'stream'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; -import { KibanaRequest } from '@kbn/core/server'; -import { context as otelContext } from '@opentelemetry/api'; import { aiAssistantSimulatedFunctionCalling } from '../..'; +import { createFunctionResponseMessage } from '../../../common/utils/create_function_response_message'; +import { withoutTokenCountEvents } from '../../../common/utils/without_token_count_events'; +import { LangTracer } 
from '../../service/client/instrumentation/lang_tracer'; import { flushBuffer } from '../../service/util/flush_buffer'; import { observableIntoOpenAIStream } from '../../service/util/observable_into_openai_stream'; import { observableIntoStream } from '../../service/util/observable_into_stream'; +import { withAssistantSpan } from '../../service/util/with_assistant_span'; +import { recallAndScore } from '../../utils/recall/recall_and_score'; import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route'; -import { screenContextRt, messageRt, functionRt } from '../runtime_types'; +import { functionRt, messageRt, screenContextRt } from '../runtime_types'; import { ObservabilityAIAssistantRouteHandlerResources } from '../types'; -import { withAssistantSpan } from '../../service/util/with_assistant_span'; -import { LangTracer } from '../../service/client/instrumentation/lang_tracer'; const chatCompleteBaseRt = t.type({ body: t.intersection([ @@ -32,14 +34,24 @@ const chatCompleteBaseRt = t.type({ conversationId: t.string, title: t.string, responseLanguage: t.string, - disableFunctions: toBooleanRt, + disableFunctions: t.union([ + toBooleanRt, + t.type({ + except: t.array(t.string), + }), + ]), instructions: t.array( t.union([ t.string, - t.type({ - doc_id: t.string, - text: t.string, - }), + t.intersection([ + t.type({ + doc_id: t.string, + text: t.string, + }), + t.partial({ + system: t.boolean, + }), + ]), ]) ), }), @@ -67,17 +79,17 @@ const chatCompletePublicRt = t.intersection([ }), ]); -async function guardAgainstInvalidConnector({ - actions, +async function initializeChatRequest({ + context, request, - connectorId, -}: { - actions: ActionsPluginStart; - request: KibanaRequest; - connectorId: string; -}) { - return withAssistantSpan('guard_against_invalid_connector', async () => { - const actionsClient = await actions.getActionsClientWithRequest(request); + plugins: { cloud, actions }, + params: { + body: { connectorId 
}, + }, + service, +}: ObservabilityAIAssistantRouteHandlerResources & { params: { body: { connectorId: string } } }) { + await withAssistantSpan('guard_against_invalid_connector', async () => { + const actionsClient = await (await actions.start()).getActionsClientWithRequest(request); const connector = await actionsClient.get({ id: connectorId, @@ -86,6 +98,29 @@ async function guardAgainstInvalidConnector({ return connector; }); + + const [client, cloudStart, simulateFunctionCalling] = await Promise.all([ + service.getClient({ request }), + cloud?.start(), + (await context.core).uiSettings.client.get(aiAssistantSimulatedFunctionCalling), + ]); + + if (!client) { + throw notImplemented(); + } + + const controller = new AbortController(); + + request.events.aborted$.subscribe(() => { + controller.abort(); + }); + + return { + client, + isCloudEnabled: Boolean(cloudStart?.isCloudEnabled), + simulateFunctionCalling, + signal: controller.signal, + }; } const chatRoute = createObservabilityAIAssistantServerRoute({ @@ -107,38 +142,20 @@ const chatRoute = createObservabilityAIAssistantServerRoute({ ]), }), handler: async (resources): Promise => { - const { request, params, service, context, plugins } = resources; + const { params } = resources; const { body: { name, messages, connectorId, functions, functionCall }, } = params; - await guardAgainstInvalidConnector({ - actions: await plugins.actions.start(), - request, - connectorId, - }); - - const [client, cloudStart, simulateFunctionCalling] = await Promise.all([ - service.getClient({ request }), - resources.plugins.cloud?.start(), - (await context.core).uiSettings.client.get(aiAssistantSimulatedFunctionCalling), - ]); - - if (!client) { - throw notImplemented(); - } - - const controller = new AbortController(); - - request.events.aborted$.subscribe(() => { - controller.abort(); - }); + const { client, simulateFunctionCalling, signal, isCloudEnabled } = await initializeChatRequest( + resources + ); const response$ = 
client.chat(name, { messages, connectorId, - signal: controller.signal, + signal, ...(functions.length ? { functions, @@ -149,7 +166,65 @@ const chatRoute = createObservabilityAIAssistantServerRoute({ tracer: new LangTracer(otelContext.active()), }); - return observableIntoStream(response$.pipe(flushBuffer(!!cloudStart?.isCloudEnabled))); + return observableIntoStream(response$.pipe(flushBuffer(isCloudEnabled))); + }, +}); + +const chatRecallRoute = createObservabilityAIAssistantServerRoute({ + endpoint: 'POST /internal/observability_ai_assistant/chat/recall', + options: { + tags: ['access:ai_assistant'], + }, + params: t.type({ + body: t.type({ + prompt: t.string, + context: t.string, + connectorId: t.string, + }), + }), + handler: async (resources): Promise => { + const { client, simulateFunctionCalling, signal, isCloudEnabled } = await initializeChatRequest( + resources + ); + + const { connectorId, prompt, context } = resources.params.body; + + const response$ = from( + recallAndScore({ + analytics: (await resources.context.core).coreStart.analytics, + chat: (name, params) => + client + .chat(name, { + ...params, + connectorId, + simulateFunctionCalling, + signal, + tracer: new LangTracer(otelContext.active()), + }) + .pipe(withoutTokenCountEvents()), + context, + logger: resources.logger, + messages: [], + userPrompt: prompt, + recall: client.recall, + signal, + }) + ).pipe( + map(({ scores, suggestions, relevantDocuments }) => { + return createFunctionResponseMessage({ + name: 'context', + data: { + suggestions, + scores, + }, + content: { + relevantDocuments, + }, + }); + }) + ); + + return observableIntoStream(response$.pipe(flushBuffer(isCloudEnabled))); }, }); @@ -158,7 +233,7 @@ async function chatComplete( params: t.TypeOf; } ) { - const { request, params, service, plugins } = resources; + const { params, service } = resources; const { body: { @@ -174,32 +249,12 @@ async function chatComplete( }, } = params; - await guardAgainstInvalidConnector({ - 
actions: await plugins.actions.start(), - request, - connectorId, - }); - - const [client, cloudStart, simulateFunctionCalling] = await Promise.all([ - service.getClient({ request }), - resources.plugins.cloud?.start() || Promise.resolve(undefined), - ( - await resources.context.core - ).uiSettings.client.get(aiAssistantSimulatedFunctionCalling), - ]); - - if (!client) { - throw notImplemented(); - } - - const controller = new AbortController(); - - request.events.aborted$.subscribe(() => { - controller.abort(); - }); + const { client, isCloudEnabled, signal, simulateFunctionCalling } = await initializeChatRequest( + resources + ); const functionClient = await service.getFunctionClient({ - signal: controller.signal, + signal, resources, client, screenContexts, @@ -211,7 +266,7 @@ async function chatComplete( conversationId, title, persist, - signal: controller.signal, + signal, functionClient, responseLanguage, instructions, @@ -219,7 +274,7 @@ async function chatComplete( disableFunctions, }); - return response$.pipe(flushBuffer(!!cloudStart?.isCloudEnabled)); + return response$.pipe(flushBuffer(isCloudEnabled)); } const chatCompleteRoute = createObservabilityAIAssistantServerRoute({ @@ -271,6 +326,7 @@ const publicChatCompleteRoute = createObservabilityAIAssistantServerRoute({ export const chatRoutes = { ...chatRoute, + ...chatRecallRoute, ...chatCompleteRoute, ...publicChatCompleteRoute, }; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/process_openai_stream.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/process_openai_stream.ts index 908042770ea2d..59dbd24451c09 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/process_openai_stream.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/adapters/process_openai_stream.ts @@ -5,7 +5,7 @@ * 2.0. 
*/ import { encode } from 'gpt-tokenizer'; -import { first, sum } from 'lodash'; +import { first, memoize, sum } from 'lodash'; import OpenAI from 'openai'; import { filter, map, Observable, tap } from 'rxjs'; import { v4 } from 'uuid'; @@ -51,6 +51,14 @@ export function processOpenAiStream({ }); } + const warnForToolCall = memoize( + (toolCall: OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta.ToolCall) => { + logger.warn(`More tools than 1 were called: ${JSON.stringify(toolCall)}`); + }, + (toolCall: OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta.ToolCall) => + toolCall.index + ); + const parsed$ = source.pipe( filter((line) => !!line && line !== '[DONE]'), map( @@ -76,7 +84,16 @@ export function processOpenAiStream({ firstChoice?.delta.content, firstChoice?.delta.function_call?.name, firstChoice?.delta.function_call?.arguments, - ].map((val) => encode(val || '').length) || 0 + ...(firstChoice?.delta.tool_calls?.flatMap((toolCall) => { + return [ + toolCall.function?.name, + toolCall.function?.arguments, + toolCall.id, + toolCall.index, + toolCall.type, + ]; + }) ?? []), + ].map((val) => encode(val?.toString() ?? '').length) || 0 ); }), filter( @@ -85,8 +102,17 @@ export function processOpenAiStream({ ), map((chunk): ChatCompletionChunkEvent => { const delta = chunk.choices[0].delta; - if (delta.tool_calls && delta.tool_calls.length > 1) { - logger.warn(`More tools than 1 were called: ${JSON.stringify(delta.tool_calls)}`); + if (delta.tool_calls && (delta.tool_calls.length > 1 || delta.tool_calls[0].index > 0)) { + delta.tool_calls.forEach((toolCall) => { + warnForToolCall(toolCall); + }); + return { + id, + type: StreamingChatResponseEventType.ChatCompletionChunk, + message: { + content: delta.content ?? 
'', + }, + }; } const functionCall: Omit | undefined = diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts index 0349d597b7ba0..46d72d303f7e3 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts @@ -27,6 +27,7 @@ import { createFunctionResponseMessage } from '../../../common/utils/create_func import { CONTEXT_FUNCTION_NAME } from '../../functions/context'; import { ChatFunctionClient } from '../chat_function_client'; import type { KnowledgeBaseService } from '../knowledge_base_service'; +import { USER_INSTRUCTIONS_HEADER } from '../util/get_system_message_from_instructions'; import { observableIntoStream } from '../util/observable_into_stream'; import { CreateChatCompletionResponseChunk } from './adapters/process_openai_stream'; @@ -34,7 +35,7 @@ type ChunkDelta = CreateChatCompletionResponseChunk['choices'][number]['delta']; type LlmSimulator = ReturnType; -const EXPECTED_STORED_SYSTEM_MESSAGE = `system\n\nWhat follows is a set of instructions provided by the user, please abide by them as long as they don't conflict with anything you've been told so far:\n\nYou MUST respond in the users preferred language which is: English.`; +const EXPECTED_STORED_SYSTEM_MESSAGE = `system\n\n${USER_INSTRUCTIONS_HEADER}\n\nYou MUST respond in the users preferred language which is: English.`; const nextTick = () => { return new Promise(process.nextTick); @@ -368,8 +369,8 @@ describe('Observability AI Assistant client', () => { last_updated: expect.any(String), token_count: { completion: 1, - prompt: 78, - total: 79, + prompt: 84, + total: 85, }, }, type: StreamingChatResponseEventType.ConversationCreate, @@ -425,8 +426,8 @@ describe('Observability AI 
Assistant client', () => { last_updated: expect.any(String), token_count: { completion: 6, - prompt: 262, - total: 268, + prompt: 268, + total: 274, }, }, type: StreamingChatResponseEventType.ConversationCreate, @@ -443,8 +444,8 @@ describe('Observability AI Assistant client', () => { title: 'An auto-generated title', token_count: { completion: 6, - prompt: 262, - total: 268, + prompt: 268, + total: 274, }, }, labels: {}, @@ -574,8 +575,8 @@ describe('Observability AI Assistant client', () => { last_updated: expect.any(String), token_count: { completion: 2, - prompt: 156, - total: 158, + prompt: 162, + total: 164, }, }, type: StreamingChatResponseEventType.ConversationUpdate, @@ -593,8 +594,8 @@ describe('Observability AI Assistant client', () => { title: 'My stored conversation', token_count: { completion: 2, - prompt: 156, - total: 158, + prompt: 162, + total: 164, }, }, labels: {}, diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts index 9739a59125011..dacd52648a6b8 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts @@ -45,7 +45,7 @@ import { } from '../../../common/conversation_complete'; import { CompatibleJSONSchema } from '../../../common/functions/types'; import { - UserInstruction, + UserInstructionOrPlainText, type Conversation, type ConversationCreateRequest, type ConversationUpdateRequest, @@ -170,9 +170,13 @@ export class ObservabilityAIAssistantClient { title?: string; isPublic?: boolean; kibanaPublicUrl?: string; - instructions?: Array; + instructions?: UserInstructionOrPlainText[]; simulateFunctionCalling?: boolean; - disableFunctions?: boolean; + disableFunctions?: + | boolean + | { + except: string[]; + }; }): Observable> => { return new 
LangTracer(context.active()).startActiveSpan( 'complete', diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/continue_conversation.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/continue_conversation.ts index 2ab26cb4799ae..83d9bf37e7efb 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/continue_conversation.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/continue_conversation.ts @@ -133,13 +133,17 @@ function getFunctionDefinitions({ }: { functionClient: ChatFunctionClient; functionLimitExceeded: boolean; - disableFunctions: boolean; + disableFunctions: + | boolean + | { + except: string[]; + }; }) { - if (functionLimitExceeded || disableFunctions) { + if (functionLimitExceeded || disableFunctions === true) { return []; } - const systemFunctions = functionClient + let systemFunctions = functionClient .getFunctions() .map((fn) => fn.definition) .filter( @@ -148,6 +152,10 @@ function getFunctionDefinitions({ [FunctionVisibility.AssistantOnly, FunctionVisibility.All].includes(def.visibility) ); + if (typeof disableFunctions === 'object') { + systemFunctions = systemFunctions.filter((fn) => disableFunctions.except.includes(fn.name)); + } + const actions = functionClient.getActions(); const allDefinitions = systemFunctions @@ -177,7 +185,11 @@ export function continueConversation({ requestInstructions: Array; userInstructions: UserInstruction[]; logger: Logger; - disableFunctions: boolean; + disableFunctions: + | boolean + | { + except: string[]; + }; tracer: LangTracer; }): Observable { let nextFunctionCallsLeft = functionCallsLeft; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts 
b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts index 7c504aa43c38c..67cf8bcd000a9 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts @@ -309,7 +309,7 @@ export class KnowledgeBaseService { user?: { name: string }; modelId: string; }): Promise { - const query = { + const esQuery = { bool: { should: queries.map(({ text, boost = 1 }) => ({ text_expansion: { @@ -334,7 +334,7 @@ export class KnowledgeBaseService { Pick >({ index: [this.dependencies.resources.aliases.kb], - query, + query: esQuery, size: 20, _source: { includes: ['text', 'is_correction', 'labels'], @@ -481,7 +481,9 @@ export class KnowledgeBaseService { }): Promise<{ entries: RecalledEntry[]; }> => { - this.dependencies.logger.debug(`Recalling entries from KB for queries: "${queries}"`); + this.dependencies.logger.debug( + `Recalling entries from KB for queries: "${JSON.stringify(queries)}"` + ); const modelId = await this.dependencies.getModelId(); const [documentsFromKb, documentsFromConnectors] = await Promise.all([ diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/get_system_message_from_instructions.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/get_system_message_from_instructions.test.ts index 99a2c34bc33d7..93594fc520998 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/get_system_message_from_instructions.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/get_system_message_from_instructions.test.ts @@ -4,7 +4,10 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ -import { getSystemMessageFromInstructions } from './get_system_message_from_instructions'; +import { + getSystemMessageFromInstructions, + USER_INSTRUCTIONS_HEADER, +} from './get_system_message_from_instructions'; describe('getSystemMessageFromInstructions', () => { it('handles plain instructions', () => { @@ -42,9 +45,7 @@ describe('getSystemMessageFromInstructions', () => { requestInstructions: [{ doc_id: 'second', text: 'second_request' }], availableFunctionNames: [], }) - ).toEqual( - `first\n\nWhat follows is a set of instructions provided by the user, please abide by them as long as they don't conflict with anything you've been told so far:\n\nsecond_request` - ); + ).toEqual(`first\n\n${USER_INSTRUCTIONS_HEADER}\n\nsecond_request`); }); it('includes kb instructions if there is no request instruction', () => { @@ -55,9 +56,7 @@ describe('getSystemMessageFromInstructions', () => { requestInstructions: [], availableFunctionNames: [], }) - ).toEqual( - `first\n\nWhat follows is a set of instructions provided by the user, please abide by them as long as they don't conflict with anything you've been told so far:\n\nsecond_kb` - ); + ).toEqual(`first\n\n${USER_INSTRUCTIONS_HEADER}\n\nsecond_kb`); }); it('handles undefined values', () => { diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/get_system_message_from_instructions.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/get_system_message_from_instructions.ts index ece79b9f78485..759ff07125b95 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/get_system_message_from_instructions.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/util/get_system_message_from_instructions.ts @@ -5,12 +5,19 @@ * 2.0. 
*/ -import { compact } from 'lodash'; +import { compact, partition } from 'lodash'; import { v4 } from 'uuid'; -import { UserInstruction } from '../../../common/types'; +import { UserInstruction, UserInstructionOrPlainText } from '../../../common/types'; import { withTokenBudget } from '../../../common/utils/with_token_budget'; import { RegisteredInstruction } from '../types'; +export const USER_INSTRUCTIONS_HEADER = `## User instructions + +What follows is a set of instructions provided by the user, please abide by them +as long as they don't conflict with anything you've been told so far: + +`; + export function getSystemMessageFromInstructions({ registeredInstructions, userInstructions, @@ -19,7 +26,7 @@ export function getSystemMessageFromInstructions({ }: { registeredInstructions: RegisteredInstruction[]; userInstructions: UserInstruction[]; - requestInstructions: Array; + requestInstructions: UserInstructionOrPlainText[]; availableFunctionNames: string[]; }): string { const allRegisteredInstructions = compact( @@ -32,10 +39,17 @@ export function getSystemMessageFromInstructions({ ); const requestInstructionsWithId = requestInstructions.map((instruction) => - typeof instruction === 'string' ? { doc_id: v4(), text: instruction } : instruction + typeof instruction === 'string' + ? 
{ doc_id: v4(), text: instruction, system: false } + : instruction + ); + + const [requestSystemInstructions, requestUserInstructionsWithId] = partition( + requestInstructionsWithId, + (instruction) => instruction.system === true ); - const requestOverrideIds = requestInstructionsWithId.map((instruction) => instruction.doc_id); + const requestOverrideIds = requestUserInstructionsWithId.map((instruction) => instruction.doc_id); // all request instructions, and those from the KB that are not defined as a request instruction const allUserInstructions = requestInstructionsWithId.concat( @@ -45,12 +59,9 @@ export function getSystemMessageFromInstructions({ const instructionsWithinBudget = withTokenBudget(allUserInstructions, 1000); return [ - ...allRegisteredInstructions, + ...allRegisteredInstructions.concat(requestSystemInstructions), ...(instructionsWithinBudget.length - ? [ - `What follows is a set of instructions provided by the user, please abide by them as long as they don't conflict with anything you've been told so far:`, - ...instructionsWithinBudget, - ] + ? 
[USER_INSTRUCTIONS_HEADER, ...instructionsWithinBudget] : []), ] .map((instruction) => { diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/parse_suggestion_scores.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/parse_suggestion_scores.test.ts similarity index 71% rename from x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/parse_suggestion_scores.test.ts rename to x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/parse_suggestion_scores.test.ts index 7b62cf21af65b..abeeda3c37657 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/parse_suggestion_scores.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/parse_suggestion_scores.test.ts @@ -12,56 +12,56 @@ describe('parseSuggestionScores', () => { expect( parseSuggestionScores( dedent( - `0,1 - 2,7 - 3,10` + `my-id,1 + my-other-id,7 + my-another-id,10` ) ) ).toEqual([ { - index: 0, + id: 'my-id', score: 1, }, { - index: 2, + id: 'my-other-id', score: 7, }, { - index: 3, + id: 'my-another-id', score: 10, }, ]); }); it('parses semi-colons as separators', () => { - expect(parseSuggestionScores(`0,1;2,7;3,10`)).toEqual([ + expect(parseSuggestionScores(`idone,1;idtwo,7;idthree,10`)).toEqual([ { - index: 0, + id: 'idone', score: 1, }, { - index: 2, + id: 'idtwo', score: 7, }, { - index: 3, + id: 'idthree', score: 10, }, ]); }); it('parses spaces as separators', () => { - expect(parseSuggestionScores(`0,1 2,7 3,10`)).toEqual([ + expect(parseSuggestionScores(`a,1 b,7 c,10`)).toEqual([ { - index: 0, + id: 'a', score: 1, }, { - index: 2, + id: 'b', score: 7, }, { - index: 3, + id: 'c', score: 10, }, ]); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/parse_suggestion_scores.ts 
b/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/parse_suggestion_scores.ts similarity index 77% rename from x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/parse_suggestion_scores.ts rename to x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/parse_suggestion_scores.ts index 9fa39bf1233b5..464504bed85a8 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/parse_suggestion_scores.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/parse_suggestion_scores.ts @@ -8,15 +8,15 @@ export function parseSuggestionScores(scoresAsString: string) { // make sure that spaces, semi-colons etc work as separators as well const scores = scoresAsString - .replace(/[^0-9,]/g, ' ') + .replace(/[^0-9a-zA-Z\-_,]/g, ' ') .trim() .split(/\s+/) .map((pair) => { - const [index, score] = pair.split(',').map((str) => parseInt(str, 10)); + const [id, score] = pair.split(',').map((str) => str.trim()); return { - index, - score, + id, + score: parseInt(score, 10), }; }); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/recall_and_score.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/recall_and_score.ts new file mode 100644 index 0000000000000..8885ff7e1d7a2 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/recall_and_score.ts @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import type { Logger } from '@kbn/logging'; +import { AnalyticsServiceStart } from '@kbn/core/server'; +import type { Message } from '../../../common'; +import type { ObservabilityAIAssistantClient } from '../../service/client'; +import type { FunctionCallChatFunction } from '../../service/types'; +import { retrieveSuggestions } from './retrieve_suggestions'; +import { scoreSuggestions } from './score_suggestions'; +import type { RetrievedSuggestion } from './types'; +import { RecallRanking, RecallRankingEventType } from '../../analytics/recall_ranking'; + +export async function recallAndScore({ + recall, + chat, + analytics, + userPrompt, + context, + messages, + logger, + signal, +}: { + recall: ObservabilityAIAssistantClient['recall']; + chat: FunctionCallChatFunction; + analytics: AnalyticsServiceStart; + userPrompt: string; + context: string; + messages: Message[]; + logger: Logger; + signal: AbortSignal; +}): Promise<{ + relevantDocuments?: RetrievedSuggestion[]; + scores?: Array<{ id: string; score: number }>; + suggestions: RetrievedSuggestion[]; +}> { + const queries = [ + { text: userPrompt, boost: 3 }, + { text: context, boost: 1 }, + ].filter((query) => query.text.trim()); + + const suggestions = await retrieveSuggestions({ + recall, + queries, + }); + + if (!suggestions.length) { + return { + relevantDocuments: [], + scores: [], + suggestions: [], + }; + } + + try { + const { scores, relevantDocuments } = await scoreSuggestions({ + suggestions, + logger, + messages, + userPrompt, + context, + signal, + chat, + }); + + analytics.reportEvent(RecallRankingEventType, { + prompt: queries.map((query) => query.text).join('\n\n'), + scoredDocuments: suggestions.map((suggestion) => { + const llmScore = scores.find((score) => score.id === suggestion.id); + return { + content: suggestion.text, + elserScore: suggestion.score ?? -1, + llmScore: llmScore ? 
llmScore.score : -1, + }; + }), + }); + + return { scores, relevantDocuments, suggestions }; + } catch (error) { + logger.error(`Error scoring documents: ${error.message}`, { error }); + return { + suggestions: suggestions.slice(0, 5), + }; + } +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/retrieve_suggestions.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/retrieve_suggestions.ts new file mode 100644 index 0000000000000..3c680229cd5d2 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/retrieve_suggestions.ts @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { omit } from 'lodash'; +import { ObservabilityAIAssistantClient } from '../../service/client'; +import { RetrievedSuggestion } from './types'; + +export async function retrieveSuggestions({ + queries, + recall, +}: { + queries: Array<{ text: string; boost?: number }>; + recall: ObservabilityAIAssistantClient['recall']; +}): Promise { + const recallResponse = await recall({ + queries, + }); + + return recallResponse.entries.map((entry) => omit(entry, 'labels', 'is_correction')); +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/score_suggestions.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/score_suggestions.ts new file mode 100644 index 0000000000000..b6a16d6329aec --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/score_suggestions.ts @@ -0,0 +1,164 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import * as t from 'io-ts'; +import { omit } from 'lodash'; +import { Logger } from '@kbn/logging'; +import dedent from 'dedent'; +import { lastValueFrom } from 'rxjs'; +import { decodeOrThrow, jsonRt } from '@kbn/io-ts-utils'; +import { concatenateChatCompletionChunks, Message, MessageRole } from '../../../common'; +import type { FunctionCallChatFunction } from '../../service/types'; +import type { RetrievedSuggestion } from './types'; +import { parseSuggestionScores } from './parse_suggestion_scores'; +import { ShortIdTable } from '../../../common/utils/short_id_table'; + +const scoreFunctionRequestRt = t.type({ + message: t.type({ + function_call: t.type({ + name: t.literal('score'), + arguments: t.string, + }), + }), +}); + +const scoreFunctionArgumentsRt = t.type({ + scores: t.string, +}); + +export async function scoreSuggestions({ + suggestions, + messages, + userPrompt, + context, + chat, + signal, + logger, +}: { + suggestions: RetrievedSuggestion[]; + messages: Message[]; + userPrompt: string; + context: string; + chat: FunctionCallChatFunction; + signal: AbortSignal; + logger: Logger; +}): Promise<{ + relevantDocuments: RetrievedSuggestion[]; + scores: Array<{ id: string; score: number }>; +}> { + const shortIdTable = new ShortIdTable(); + + const suggestionsWithShortId = suggestions.map((suggestion) => ({ + ...omit(suggestion, 'score', 'id'), // To not bias the LLM + originalId: suggestion.id, + shortId: shortIdTable.take(suggestion.id), + })); + + const newUserMessageContent = + dedent(`Given the following question, score the documents that are relevant to the question. on a scale from 0 to 7, + 0 being completely irrelevant, and 7 being extremely relevant. Information is relevant to the question if it helps in + answering the question. 
Judge it according to the following criteria: + + - The document is relevant to the question, and the rest of the conversation + - The document has information relevant to the question that is not mentioned, + or more detailed than what is available in the conversation + - The document has a high amount of information relevant to the question compared to other documents + - The document contains new information not mentioned before in the conversation + + User prompt: + ${userPrompt} + + Context: + ${context} + + Documents: + ${JSON.stringify( + suggestionsWithShortId.map((suggestion) => ({ + id: suggestion.shortId, + content: suggestion.text, + })), + null, + 2 + )}`); + + const newUserMessage: Message = { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: newUserMessageContent, + }, + }; + + const scoreFunction = { + name: 'score', + description: + 'Use this function to score documents based on how relevant they are to the conversation.', + parameters: { + type: 'object', + properties: { + scores: { + description: `The document IDs and their scores, as CSV. 
Example: + + my_id,7 + my_other_id,3 + my_third_id,4 + `, + type: 'string', + }, + }, + required: ['scores'], + } as const, + }; + + const response = await lastValueFrom( + chat('score_suggestions', { + messages: [...messages.slice(0, -2), newUserMessage], + functions: [scoreFunction], + functionCall: 'score', + signal, + }).pipe(concatenateChatCompletionChunks()) + ); + + const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response); + const { scores: scoresAsString } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))( + scoreFunctionRequest.message.function_call.arguments + ); + + const scores = parseSuggestionScores(scoresAsString).map(({ id, score }) => { + const originalSuggestion = suggestionsWithShortId.find( + (suggestion) => suggestion.shortId === id + ); + return { + originalId: originalSuggestion?.originalId, + score, + }; + }); + + if (scores.length === 0) { + // seemingly invalid or no scores, return all + return { relevantDocuments: suggestions, scores: [] }; + } + + const suggestionIds = suggestions.map((document) => document.id); + + const relevantDocumentIds = scores + .filter((document) => suggestionIds.includes(document.originalId ??
'')) // Remove hallucinated documents + .filter((document) => document.score > 4) + .sort((a, b) => b.score - a.score) + .slice(0, 5) + .map((document) => document.originalId); + + const relevantDocuments = suggestions.filter((suggestion) => + relevantDocumentIds.includes(suggestion.id) + ); + + logger.debug(`Relevant documents: ${JSON.stringify(relevantDocuments, null, 2)}`); + + return { + relevantDocuments, + scores: scores.map((score) => ({ id: score.originalId!, score: score.score })), + }; +} diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/types.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/types.ts new file mode 100644 index 0000000000000..3774df64c1ee1 --- /dev/null +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/utils/recall/types.ts @@ -0,0 +1,10 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { RecalledEntry } from '../../service/knowledge_base_service'; + +export type RetrievedSuggestion = Omit; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/functions/visualize_esql.test.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/functions/visualize_esql.test.tsx index 836b8f6ef7f93..620fbbc2ab166 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/functions/visualize_esql.test.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/functions/visualize_esql.test.tsx @@ -5,7 +5,7 @@ * 2.0. 
*/ import React from 'react'; -import { render, screen, waitFor } from '@testing-library/react'; +import { render, screen, waitFor, act } from '@testing-library/react'; import userEvent from '@testing-library/user-event'; import type { DatatableColumn } from '@kbn/expressions-plugin/common'; import type { LensPublicStart } from '@kbn/lens-plugin/public'; @@ -142,7 +142,8 @@ describe('VisualizeESQL', () => { }), }; renderComponent({}, lensService, undefined, ['There is an error mate']); - await waitFor(() => expect(screen.findByTestId('observabilityAiAssistantErrorsList'))); + + expect(await screen.findByTestId('observabilityAiAssistantErrorsList')).toBeInTheDocument(); }); it('should not display the table on first render', async () => { @@ -153,15 +154,16 @@ describe('VisualizeESQL', () => { suggestions: jest.fn(), }), }; + renderComponent({}, lensService); - // the button to render a table should be present - await waitFor(() => - expect(screen.findByTestId('observabilityAiAssistantLensESQLDisplayTableButton')) - ); - await waitFor(() => - expect(screen.queryByTestId('observabilityAiAssistantESQLDataGrid')).not.toBeInTheDocument() - ); + expect( + await screen.findByTestId('observabilityAiAssistantLensESQLDisplayTableButton') + ).toBeInTheDocument(); + + expect( + await screen.queryByTestId('observabilityAiAssistantESQLDataGrid') + ).not.toBeInTheDocument(); }); it('should display the table when user clicks the table button', async () => { @@ -172,11 +174,16 @@ describe('VisualizeESQL', () => { suggestions: jest.fn(), }), }; + renderComponent({}, lensService); - await waitFor(() => { - userEvent.click(screen.getByTestId('observabilityAiAssistantLensESQLDisplayTableButton')); - expect(screen.findByTestId('observabilityAiAssistantESQLDataGrid')); + + await act(async () => { + userEvent.click( + await screen.findByTestId('observabilityAiAssistantLensESQLDisplayTableButton') + ); }); + + expect(await 
screen.findByTestId('observabilityAiAssistantESQLDataGrid')).toBeInTheDocument(); }); it('should render the ESQLDataGrid if Lens returns a table', async () => { @@ -195,8 +202,6 @@ describe('VisualizeESQL', () => { }, lensService ); - await waitFor(() => { - expect(screen.findByTestId('observabilityAiAssistantESQLDataGrid')); - }); + expect(await screen.findByTestId('observabilityAiAssistantESQLDataGrid')).toBeInTheDocument(); }); });