Skip to content

Commit

Permalink
[Obs AI Assistant] Add ES function API test (elastic#187465)
Browse files Browse the repository at this point in the history
Related to elastic#180787
  • Loading branch information
sorenlouv authored Jul 6, 2024
1 parent 8ccd7b3 commit 4504088
Show file tree
Hide file tree
Showing 15 changed files with 375 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import type { FunctionDefinition } from '../../common/functions/types';

export function buildFunction(): FunctionDefinition {
export function buildFunctionElasticsearch(): FunctionDefinition {
return {
name: 'elasticsearch',
description: 'Call Elasticsearch APIs on behalf of the user',
Expand All @@ -30,8 +30,6 @@ export function buildFunction(): FunctionDefinition {
};
}

export const buildFunctionElasticsearch = buildFunction;

export function buildFunctionServiceSummary(): FunctionDefinition {
return {
name: 'get_service_summary',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@

import type { FunctionRegistrationParameters } from '.';

export const ELASTICSEARCH_FUNCTION_NAME = 'elasticsearch';

export function registerElasticsearchFunction({
functions,
resources,
}: FunctionRegistrationParameters) {
functions.registerFunction(
{
name: 'elasticsearch',
name: ELASTICSEARCH_FUNCTION_NAME,
description:
'Call Elasticsearch APIs on behalf of the user. Make sure the request body is valid for the API that you are using. Only call this function when the user has explicitly requested it.',
descriptionForUser: 'Call Elasticsearch APIs on behalf of the user',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

import { findLastIndex } from 'lodash';
import { findLastIndex, last } from 'lodash';
import { Message, MessageAddEvent, MessageRole } from '../../../common';
import { createFunctionRequestMessage } from '../../../common/utils/create_function_request_message';
import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
Expand All @@ -22,11 +22,10 @@ export function getContextFunctionRequestIfNeeded(
.slice(indexOfLastUserMessage)
.some((message) => message.message.name === CONTEXT_FUNCTION_NAME);

if (hasContextSinceLastUserMessage) {
const isLastMessageFunctionRequest = !!last(messages)?.message.function_call?.name;
if (hasContextSinceLastUserMessage || isLastMessageFunctionRequest) {
return undefined;
}

return createFunctionRequestMessage({
name: CONTEXT_FUNCTION_NAME,
});
return createFunctionRequestMessage({ name: CONTEXT_FUNCTION_NAME });
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import type { ElasticsearchClient, IUiSettingsClient } from '@kbn/core/server';
import type { Logger } from '@kbn/logging';
import type { PublicMethodsOf } from '@kbn/utility-types';
import { SpanKind, context } from '@opentelemetry/api';
import { merge, omit } from 'lodash';
import { last, merge, omit } from 'lodash';
import {
catchError,
combineLatest,
Expand Down Expand Up @@ -334,13 +334,12 @@ export class ObservabilityAIAssistantClient {
const initialMessagesWithAddedMessages =
messagesWithUpdatedSystemMessage.concat(addedMessages);

const lastMessage =
initialMessagesWithAddedMessages[initialMessagesWithAddedMessages.length - 1];
const lastMessage = last(initialMessagesWithAddedMessages);

// if a function request is at the very end, close the stream to consumer
// without persisting or updating the conversation. we need to wait
// on the function response to have a valid conversation
const isFunctionRequest = lastMessage.message.function_call?.name;
const isFunctionRequest = !!lastMessage?.message.function_call?.name;

if (!persist || isFunctionRequest) {
return of();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import { Logger } from '@kbn/logging';
import { decode, encode } from 'gpt-tokenizer';
import { pick, take } from 'lodash';
import { last, pick, take } from 'lodash';
import {
catchError,
concat,
Expand Down Expand Up @@ -212,10 +212,8 @@ export function continueConversation({
initialMessages
);

const lastMessage =
messagesWithUpdatedSystemMessage[messagesWithUpdatedSystemMessage.length - 1].message;

const isUserMessage = lastMessage.role === MessageRole.User;
const lastMessage = last(messagesWithUpdatedSystemMessage)?.message;
const isUserMessage = lastMessage?.role === MessageRole.User;

return executeNextStep().pipe(handleEvents());

Expand All @@ -233,7 +231,7 @@ export function continueConversation({
}).pipe(emitWithConcatenatedMessage(), catchFunctionNotFoundError(functionLimitExceeded));
}

const functionCallName = lastMessage.function_call?.name;
const functionCallName = lastMessage?.function_call?.name;

if (!functionCallName) {
// reply from the LLM without a function request,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import { runAndValidateEsqlQuery } from './validate_esql_query';
import { INLINE_ESQL_QUERY_REGEX } from './constants';

export const QUERY_FUNCTION_NAME = 'query';
export const EXECUTE_QUERY_NAME = 'execute_query';

const readFile = promisify(Fs.readFile);
const readdir = promisify(Fs.readdir);
Expand Down Expand Up @@ -89,13 +90,13 @@ export function registerQueryFunction({ functions, resources }: FunctionRegistra
even if it has been called before.
When the "visualize_query" function has been called, a visualization has been displayed to the user. DO NOT UNDER ANY CIRCUMSTANCES follow up a "visualize_query" function call with your own visualization attempt.
If the "execute_query" function has been called, summarize these results for the user. The user does not see a visualization in this case.`
If the "${EXECUTE_QUERY_NAME}" function has been called, summarize these results for the user. The user does not see a visualization in this case.`
: undefined
);

functions.registerFunction(
{
name: 'execute_query',
name: EXECUTE_QUERY_NAME,
visibility: FunctionVisibility.UserOnly,
description: 'Display the results of an ES|QL query.',
parameters: {
Expand Down Expand Up @@ -365,7 +366,7 @@ export function registerQueryFunction({ functions, resources }: FunctionRegistra
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
content: `Answer the user's question that was previously asked ("${abbreviatedUserQuestion}...") using the attached documentation. Take into account any previous errors from the \`execute_query\` or \`visualize_query\` function.
content: `Answer the user's question that was previously asked ("${abbreviatedUserQuestion}...") using the attached documentation. Take into account any previous errors from the \`${EXECUTE_QUERY_NAME}\` or \`visualize_query\` function.
Format any ES|QL query as follows:
\`\`\`esql
Expand Down Expand Up @@ -449,7 +450,7 @@ export function registerQueryFunction({ functions, resources }: FunctionRegistra
functionCall = undefined;
} else if (args.intention === VisualizeESQLUserIntention.executeAndReturnResults) {
functionCall = {
name: 'execute_query',
name: EXECUTE_QUERY_NAME,
arguments: JSON.stringify({ query: esqlQuery }),
trigger: MessageRole.Assistant as const,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ import { UrlObject } from 'url';
import { ObservabilityAIAssistantFtrConfigName } from '../configs';
import { getApmSynthtraceEsClient } from './create_synthtrace_client';
import { InheritedFtrProviderContext, InheritedServices } from './ftr_provider_context';
import {
getScopedApiClient,
ObservabilityAIAssistantAPIClient,
} from './observability_ai_assistant_api_client';
import { getScopedApiClient } from './observability_ai_assistant_api_client';
import { editorUser, viewerUser } from './users/users';

export interface ObservabilityAIAssistantFtrConfig {
Expand All @@ -24,20 +21,13 @@ export interface ObservabilityAIAssistantFtrConfig {

export type CreateTestConfig = ReturnType<typeof createTestConfig>;

export interface CreateTest {
testFiles: string[];
servers: any;
services: InheritedServices & {
observabilityAIAssistantAPIClient: () => Promise<{
adminUser: ObservabilityAIAssistantAPIClient;
viewerUser: ObservabilityAIAssistantAPIClient;
editorUser: ObservabilityAIAssistantAPIClient;
}>;
};
junit: { reportName: string };
esTestCluster: any;
kbnTestServer: any;
}
export type CreateTest = ReturnType<typeof createObservabilityAIAssistantAPIConfig>;

export type ObservabilityAIAssistantAPIClient = Awaited<
ReturnType<CreateTest['services']['observabilityAIAssistantAPIClient']>
>;

export type ObservabilityAIAssistantServices = Awaited<ReturnType<CreateTestConfig>>['services'];

export function createObservabilityAIAssistantAPIConfig({
config,
Expand All @@ -49,14 +39,15 @@ export function createObservabilityAIAssistantAPIConfig({
license: 'basic' | 'trial';
name: string;
kibanaConfig?: Record<string, any>;
}): Omit<CreateTest, 'testFiles'> {
}) {
const services = config.get('services') as InheritedServices;
const servers = config.get('servers');
const kibanaServer = servers.kibana as UrlObject;
const apmSynthtraceKibanaClient = services.apmSynthtraceKibanaClient();
const allConfigs = config.getAll() as Record<string, any>;

const createTest: Omit<CreateTest, 'testFiles'> = {
...config.getAll(),
return {
...allConfigs,
servers,
services: {
...services,
Expand Down Expand Up @@ -89,13 +80,9 @@ export function createObservabilityAIAssistantAPIConfig({
],
},
};

return createTest;
}

export function createTestConfig(
config: ObservabilityAIAssistantFtrConfig
): ({ readConfigFile }: FtrConfigProviderContext) => Promise<CreateTest> {
export function createTestConfig(config: ObservabilityAIAssistantFtrConfig) {
const { license, name, kibanaConfig } = config;

return async ({ readConfigFile }: FtrConfigProviderContext) => {
Expand All @@ -114,5 +101,3 @@ export function createTestConfig(
};
};
}

export type ObservabilityAIAssistantServices = Awaited<ReturnType<CreateTestConfig>>['services'];
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ export class LlmProxy {
waitForIntercept: () => Promise<LlmResponseSimulator>;
}
: {
waitAndComplete: () => Promise<void>;
completeAfterIntercept: () => Promise<void>;
} {
const waitForInterceptPromise = Promise.race([
new Promise<LlmResponseSimulator>((outerResolve) => {
Expand Down Expand Up @@ -162,7 +162,7 @@ export class LlmProxy {
: responseChunks.split(' ').map((token, i) => (i === 0 ? token : ` ${token}`));

return {
waitAndComplete: async () => {
completeAfterIntercept: async () => {
const simulator = await waitForInterceptPromise;
for (const chunk of parsedChunks) {
await simulator.next(chunk);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,11 +414,11 @@ export default function ApiTest({ getService }: FtrProviderContext) {
},
},
])
.waitAndComplete();
.completeAfterIntercept();

proxy
.intercept('conversation', (body) => !isFunctionTitleRequest(body), 'Good morning, sir!')
.waitAndComplete();
.completeAfterIntercept();

const createResponse = await observabilityAIAssistantAPIClient
.editorUser({
Expand Down Expand Up @@ -450,7 +450,7 @@ export default function ApiTest({ getService }: FtrProviderContext) {

proxy
.intercept('conversation', (body) => !isFunctionTitleRequest(body), 'Good night, sir!')
.waitAndComplete();
.completeAfterIntercept();

const updatedResponse = await observabilityAIAssistantAPIClient
.editorUser({
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { MessageAddEvent, MessageRole } from '@kbn/observability-ai-assistant-plugin/common';
import expect from '@kbn/expect';
import { apm, timerange } from '@kbn/apm-synthtrace-client';
import { ApmSynthtraceEsClient } from '@kbn/apm-synthtrace';
import { ELASTICSEARCH_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/elasticsearch';
import { LlmProxy } from '../../../common/create_llm_proxy';
import { FtrProviderContext } from '../../../common/ftr_provider_context';
import {
createLLMProxyConnector,
deleteLLMProxyConnector,
getMessageAddedEvents,
invokeChatCompleteWithFunctionRequest,
} from './helpers';

export default function ApiTest({ getService }: FtrProviderContext) {
const supertest = getService('supertest');
const log = getService('log');
const apmSynthtraceEsClient = getService('apmSynthtraceEsClient');
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantAPIClient');

describe('when calling elasticsearch', () => {
let proxy: LlmProxy;
let connectorId: string;
let events: MessageAddEvent[];

before(async () => {
({ connectorId, proxy } = await createLLMProxyConnector({ log, supertest }));
await generateApmData(apmSynthtraceEsClient);

const responseBody = await invokeChatCompleteWithFunctionRequest({
connectorId,
observabilityAIAssistantAPIClient,
functionCall: {
name: ELASTICSEARCH_FUNCTION_NAME,
trigger: MessageRole.User,
arguments: JSON.stringify({
method: 'POST',
path: 'traces*/_search',
body: {
size: 0,
aggs: {
services: {
terms: {
field: 'service.name',
},
},
},
},
}),
},
});

await proxy.waitForAllInterceptorsSettled();

events = getMessageAddedEvents(responseBody);
});

after(async () => {
await deleteLLMProxyConnector({ supertest, connectorId, proxy });
await apmSynthtraceEsClient.clean();
});

it('returns elasticsearch function response', async () => {
const esFunctionResponse = events[0];
const parsedEsResponse = JSON.parse(esFunctionResponse.message.message.content!).response;

expect(esFunctionResponse.message.message.name).to.be('elasticsearch');
expect(parsedEsResponse.hits.total.value).to.be(15);
expect(parsedEsResponse.aggregations.services.buckets).to.eql([
{ key: 'java-backend', doc_count: 15 },
]);
expect(events.length).to.be(2);
});
});
}

export async function generateApmData(apmSynthtraceEsClient: ApmSynthtraceEsClient) {
const serviceA = apm
.service({ name: 'java-backend', environment: 'production', agentName: 'java' })
.instance('a');

const events = timerange('now-15m', 'now')
.interval('1m')
.rate(1)
.generator((timestamp) => {
return serviceA.transaction({ transactionName: 'tx' }).timestamp(timestamp).duration(1000);
});

return apmSynthtraceEsClient.index(events);
}
Loading

0 comments on commit 4504088

Please sign in to comment.