diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts index 7232078d2efe8..9e0adc5a94d8f 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/public/service/create_service.ts @@ -72,6 +72,10 @@ export function createService({ return of( createFunctionRequestMessage({ name: 'context', + args: { + queries: [], + categories: [], + }, }), createFunctionResponseMessage({ name: 'context', diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts index baf006844c516..4bc32a2330acd 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/functions/context.ts @@ -40,10 +40,34 @@ export function registerContextFunction({ description: 'This function provides context as to what the user is looking at on their screen, and recalled documents from the knowledge base that matches their query', visibility: FunctionVisibility.Internal, + parameters: { + type: 'object', + properties: { + queries: { + type: 'array', + description: 'The query for the semantic search', + items: { + type: 'string', + }, + }, + categories: { + type: 'array', + description: + 'Categories of internal documentation that you want to search for. By default internal documentation will be excluded. Use `apm` to get internal APM documentation, `lens` to get internal Lens documentation, or both.', + items: { + type: 'string', + enum: ['apm', 'lens'], + }, + }, + }, + required: ['queries', 'categories'], + } as const, }, - async ({ messages, screenContexts, chat }, signal) => { + async ({ arguments: args, messages, screenContexts, chat }, signal) => { const { analytics } = (await resources.context.core).coreStart; + const { queries, categories } = args; + async function getContext() { const screenDescription = compact( screenContexts.map((context) => context.screenDescription) @@ -70,21 +94,30 @@ export function registerContextFunction({ messages.filter((message) => message.message.role === MessageRole.User) ); - const userPrompt = userMessage?.message.content; - const queries = [{ text: userPrompt, boost: 3 }, { text: screenDescription }].filter( - ({ text }) => text - ) as Array<{ text: string; boost?: number }>; + const nonEmptyQueries = compact(queries); + + const queriesOrUserPrompt = nonEmptyQueries.length + ? nonEmptyQueries + : compact([userMessage?.message.content]); + + queriesOrUserPrompt.push(screenDescription); + + const suggestions = await retrieveSuggestions({ + client, + categories, + queries: queriesOrUserPrompt, + }); - const suggestions = await retrieveSuggestions({ client, queries }); if (suggestions.length === 0) { - return { content }; + return { + content, + }; } try { const { relevantDocuments, scores } = await scoreSuggestions({ suggestions, - screenDescription, - userPrompt, + queries: queriesOrUserPrompt, messages, chat, signal, @@ -92,7 +125,7 @@ export function registerContextFunction({ }); analytics.reportEvent(RecallRankingEventType, { - prompt: queries.map((query) => query.text).join('|'), + prompt: queriesOrUserPrompt.join('|'), scoredDocuments: suggestions.map((suggestion) => { const llmScore = scores.find((score) => score.id === suggestion.id); return { @@ -145,12 +178,15 @@ export function registerContextFunction({ async function retrieveSuggestions({ queries, client, + categories, }: { - queries: Array<{ text: string; boost?: number }>; + queries: string[]; client: ObservabilityAIAssistantClient; + categories: Array<'apm' | 'lens'>; }) { const recallResponse = await client.recall({ queries, + categories, }); return recallResponse.entries.map((entry) => omit(entry, 'labels', 'is_correction')); @@ -172,16 +208,14 @@ const scoreFunctionArgumentsRt = t.type({ async function scoreSuggestions({ suggestions, messages, - userPrompt, - screenDescription, + queries, chat, signal, logger, }: { suggestions: Awaited>; messages: Message[]; - userPrompt: string | undefined; - screenDescription: string; + queries: string[]; chat: FunctionCallChatFunction; signal: AbortSignal; logger: Logger; @@ -203,10 +237,7 @@ async function scoreSuggestions({ - The document contains new information not mentioned before in the conversation Question: - ${userPrompt} - - Screen description: - ${screenDescription} + ${queries.join('\n')} Documents: ${JSON.stringify(indexedSuggestions, null, 2)}`); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/functions/route.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/functions/route.ts index 52be33c2a372d..8d509271c1e37 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/functions/route.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/routes/functions/route.ts @@ -65,16 +65,7 @@ const functionRecallRoute = createObservabilityAIAssistantServerRoute({ params: t.type({ body: t.intersection([ t.type({ - queries: t.array( - t.intersection([ - t.type({ - text: t.string, - }), - t.partial({ - boost: t.number, - }), - ]) - ), + queries: t.array(nonEmptyStringRt), }), t.partial({ categories: t.array(t.string), diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts index e5ea0ad0ff829..74cc19d8aa153 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/get_context_function_request_if_needed.ts @@ -28,5 +28,9 @@ export function getContextFunctionRequestIfNeeded( return createFunctionRequestMessage({ name: CONTEXT_FUNCTION_NAME, + args: { + queries: [], + categories: [], + }, }); } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts index 0349d597b7ba0..4ffc8dc926fc7 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts @@ -1232,6 +1232,7 @@ describe('Observability AI Assistant client', () => { role: MessageRole.Assistant, function_call: { name: CONTEXT_FUNCTION_NAME, + arguments: JSON.stringify({ queries: [], categories: [] }), trigger: MessageRole.Assistant, }, }, @@ -1455,6 +1456,7 @@ describe('Observability AI Assistant client', () => { role: MessageRole.Assistant, function_call: { name: CONTEXT_FUNCTION_NAME, + arguments: JSON.stringify({ queries: [], categories: [] }), trigger: MessageRole.Assistant, }, }, diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts index 9739a59125011..803e0e904223e 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts @@ -694,7 +694,7 @@ export class ObservabilityAIAssistantClient { queries, categories, }: { - queries: Array<{ text: string; boost?: number }>; + queries: string[]; categories?: string[]; }): Promise<{ entries: RecalledEntry[] }> => { return this.dependencies.knowledgeBaseService.recall({ @@ -757,9 +757,11 @@ export class ObservabilityAIAssistantClient { }; fetchUserInstructions = async () => { - return this.dependencies.knowledgeBaseService.getUserInstructions( + const userInstructions = await this.dependencies.knowledgeBaseService.getUserInstructions( this.dependencies.namespace, this.dependencies.user ); + + return userInstructions; }; } diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts index 7c504aa43c38c..576fd8dc5552b 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/knowledge_base_service/index.ts @@ -303,7 +303,7 @@ export class KnowledgeBaseService { user, modelId, }: { - queries: Array<{ text: string; boost?: number }>; + queries: string[]; categories?: string[]; namespace: string; user?: { name: string }; @@ -311,12 +311,11 @@ export class KnowledgeBaseService { }): Promise { const query = { bool: { - should: queries.map(({ text, boost = 1 }) => ({ + should: queries.map((text) => ({ text_expansion: { 'ml.tokens': { model_text: text, model_id: modelId, - boost, }, }, })), @@ -386,7 +385,7 @@ export class KnowledgeBaseService { uiSettingsClient, modelId, }: { - queries: Array<{ text: string; boost?: number }>; + queries: string[]; asCurrentUser: ElasticsearchClient; uiSettingsClient: IUiSettingsClient; modelId: string; @@ -415,16 +414,15 @@ export class KnowledgeBaseService { const vectorField = `${ML_INFERENCE_PREFIX}${field}_expanded.predicted_value`; const modelField = `${ML_INFERENCE_PREFIX}${field}_expanded.model_id`; - return queries.map(({ text, boost = 1 }) => { + return queries.map((query) => { return { bool: { should: [ { text_expansion: { [vectorField]: { - model_text: text, + model_text: query, model_id: modelId, - boost, }, }, }, @@ -472,7 +470,7 @@ export class KnowledgeBaseService { asCurrentUser, uiSettingsClient, }: { - queries: Array<{ text: string; boost?: number }>; + queries: string[]; categories?: string[]; user?: { name: string }; namespace: string; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/chat/chat_body.test.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/chat/chat_body.test.tsx index 65ac65264f307..e39bcf5d1891e 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/chat/chat_body.test.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/components/chat/chat_body.test.tsx @@ -40,7 +40,7 @@ describe('', () => { role: 'assistant', function_call: { name: CONTEXT_FUNCTION_NAME, - arguments: '{}', + arguments: '{"queries":[],"categories":[]}', trigger: 'assistant', }, content: '', @@ -88,7 +88,7 @@ describe('', () => { role: 'assistant', function_call: { name: CONTEXT_FUNCTION_NAME, - arguments: '{}', + arguments: '{"queries":[],"categories":[]}', trigger: 'assistant', }, content: '', diff --git a/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts b/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts index eb5ed07d3ea08..01f6e8cdd7bce 100644 --- a/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts +++ b/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts @@ -193,6 +193,7 @@ export default function ApiTest({ getService }: FtrProviderContext) { role: MessageRole.Assistant, function_call: { name: 'context', + arguments: JSON.stringify({ queries: [], categories: [] }), trigger: MessageRole.Assistant, }, }, diff --git a/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts b/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts index f496e42868ac8..ac2fa36f6b0fd 100644 --- a/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts +++ b/x-pack/test/observability_ai_assistant_api_integration/tests/public_complete/public_complete.spec.ts @@ -72,7 +72,6 @@ export default function ApiTest({ getService }: FtrProviderContext) { format, }) .set('kbn-xsrf', 'foo') - .set('elastic-api-version', '2023-10-31') .send({ messages, connectorId, @@ -84,20 +83,13 @@ export default function ApiTest({ getService }: FtrProviderContext) { if (err) { return reject(err); } - if (response.status !== 200) { - return reject(new Error(`${response.status}: ${JSON.stringify(response.body)}`)); - } return resolve(response); }); }); - const [conversationSimulator, titleSimulator] = await Promise.race([ - Promise.all([ - conversationInterceptor.waitForIntercept(), - titleInterceptor.waitForIntercept(), - ]), - // make sure any request failures (like 400s) are properly propagated - responsePromise.then(() => []), + const [conversationSimulator, titleSimulator] = await Promise.all([ + conversationInterceptor.waitForIntercept(), + titleInterceptor.waitForIntercept(), ]); await titleSimulator.status(200); diff --git a/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts b/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts index 3e766877c5bca..b7c33db0a4122 100644 --- a/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts +++ b/x-pack/test/observability_ai_assistant_functional/tests/conversations/index.spec.ts @@ -94,7 +94,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte content: '', function_call: { name: 'context', - arguments: '{}', + arguments: '{"queries":[],"categories":[]}', trigger: MessageRole.Assistant, }, }, @@ -290,6 +290,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte expect(pick(contextRequest.function_call, 'name', 'arguments')).to.eql({ name: 'context', + arguments: JSON.stringify({ queries: [], categories: [] }), }); expect(contextResponse.name).to.eql('context'); @@ -353,6 +354,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte expect(pick(contextRequest.function_call, 'name', 'arguments')).to.eql({ name: 'context', + arguments: JSON.stringify({ queries: [], categories: [] }), }); expect(contextResponse.name).to.eql('context');