Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Obs AI Assistant] Add ES function API test #187465

Merged
merged 11 commits into from
Jul 6, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import type { FunctionDefinition } from '../../common/functions/types';

export function buildFunction(): FunctionDefinition {
export function buildFunctionElasticsearch(): FunctionDefinition {
return {
name: 'elasticsearch',
description: 'Call Elasticsearch APIs on behalf of the user',
Expand All @@ -30,8 +30,6 @@ export function buildFunction(): FunctionDefinition {
};
}

export const buildFunctionElasticsearch = buildFunction;

export function buildFunctionServiceSummary(): FunctionDefinition {
return {
name: 'get_service_summary',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@

import type { FunctionRegistrationParameters } from '.';

export const ELASTICSEARCH_FUNCTION_NAME = 'elasticsearch';

export function registerElasticsearchFunction({
functions,
resources,
}: FunctionRegistrationParameters) {
functions.registerFunction(
{
name: 'elasticsearch',
name: ELASTICSEARCH_FUNCTION_NAME,
description:
'Call Elasticsearch APIs on behalf of the user. Make sure the request body is valid for the API that you are using. Only call this function when the user has explicitly requested it.',
descriptionForUser: 'Call Elasticsearch APIs on behalf of the user',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

import { findLastIndex } from 'lodash';
import { findLastIndex, last } from 'lodash';
import { Message, MessageAddEvent, MessageRole } from '../../../common';
import { createFunctionRequestMessage } from '../../../common/utils/create_function_request_message';
import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
Expand All @@ -22,11 +22,10 @@ export function getContextFunctionRequestIfNeeded(
.slice(indexOfLastUserMessage)
.some((message) => message.message.name === CONTEXT_FUNCTION_NAME);

if (hasContextSinceLastUserMessage) {
const isLastMessageFunctionRequest = !!last(messages)?.message.function_call?.name;
if (hasContextSinceLastUserMessage || isLastMessageFunctionRequest) {
Copy link
Member Author

@sorenlouv sorenlouv Jul 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dgieselaar Does this change look ok to you? Before we'd inject the context request immediately after other function requests effectively overwriting them. This changes fixes that but I'm not sure if it causes problems to context function request injection in some cases

return undefined;
}

return createFunctionRequestMessage({
name: CONTEXT_FUNCTION_NAME,
});
return createFunctionRequestMessage({ name: CONTEXT_FUNCTION_NAME });
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import type { ElasticsearchClient, IUiSettingsClient } from '@kbn/core/server';
import type { Logger } from '@kbn/logging';
import type { PublicMethodsOf } from '@kbn/utility-types';
import { SpanKind, context } from '@opentelemetry/api';
import { merge, omit } from 'lodash';
import { last, merge, omit } from 'lodash';
import {
catchError,
combineLatest,
Expand Down Expand Up @@ -334,13 +334,12 @@ export class ObservabilityAIAssistantClient {
const initialMessagesWithAddedMessages =
messagesWithUpdatedSystemMessage.concat(addedMessages);

const lastMessage =
initialMessagesWithAddedMessages[initialMessagesWithAddedMessages.length - 1];
const lastMessage = last(initialMessagesWithAddedMessages);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use utility for improved readability


// if a function request is at the very end, close the stream to consumer
// without persisting or updating the conversation. we need to wait
// on the function response to have a valid conversation
const isFunctionRequest = lastMessage.message.function_call?.name;
const isFunctionRequest = !!lastMessage?.message.function_call?.name;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cast to bool


if (!persist || isFunctionRequest) {
return of();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import { Logger } from '@kbn/logging';
import { decode, encode } from 'gpt-tokenizer';
import { pick, take } from 'lodash';
import { last, pick, take } from 'lodash';
import {
catchError,
concat,
Expand Down Expand Up @@ -212,10 +212,8 @@ export function continueConversation({
initialMessages
);

const lastMessage =
messagesWithUpdatedSystemMessage[messagesWithUpdatedSystemMessage.length - 1].message;

const isUserMessage = lastMessage.role === MessageRole.User;
const lastMessage = last(messagesWithUpdatedSystemMessage)?.message;
const isUserMessage = lastMessage?.role === MessageRole.User;

return executeNextStep().pipe(handleEvents());

Expand All @@ -233,7 +231,7 @@ export function continueConversation({
}).pipe(emitWithConcatenatedMessage(), catchFunctionNotFoundError(functionLimitExceeded));
}

const functionCallName = lastMessage.function_call?.name;
const functionCallName = lastMessage?.function_call?.name;

if (!functionCallName) {
// reply from the LLM without a function request,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import { runAndValidateEsqlQuery } from './validate_esql_query';
import { INLINE_ESQL_QUERY_REGEX } from './constants';

export const QUERY_FUNCTION_NAME = 'query';
export const EXECUTE_QUERY_NAME = 'execute_query';

const readFile = promisify(Fs.readFile);
const readdir = promisify(Fs.readdir);
Expand Down Expand Up @@ -89,13 +90,13 @@ export function registerQueryFunction({ functions, resources }: FunctionRegistra
even if it has been called before.

When the "visualize_query" function has been called, a visualization has been displayed to the user. DO NOT UNDER ANY CIRCUMSTANCES follow up a "visualize_query" function call with your own visualization attempt.
If the "execute_query" function has been called, summarize these results for the user. The user does not see a visualization in this case.`
If the "${EXECUTE_QUERY_NAME}" function has been called, summarize these results for the user. The user does not see a visualization in this case.`
: undefined
);

functions.registerFunction(
{
name: 'execute_query',
name: EXECUTE_QUERY_NAME,
visibility: FunctionVisibility.UserOnly,
description: 'Display the results of an ES|QL query.',
parameters: {
Expand Down Expand Up @@ -365,7 +366,7 @@ export function registerQueryFunction({ functions, resources }: FunctionRegistra
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
content: `Answer the user's question that was previously asked ("${abbreviatedUserQuestion}...") using the attached documentation. Take into account any previous errors from the \`execute_query\` or \`visualize_query\` function.
content: `Answer the user's question that was previously asked ("${abbreviatedUserQuestion}...") using the attached documentation. Take into account any previous errors from the \`${EXECUTE_QUERY_NAME}\` or \`visualize_query\` function.

Format any ES|QL query as follows:
\`\`\`esql
Expand Down Expand Up @@ -449,7 +450,7 @@ export function registerQueryFunction({ functions, resources }: FunctionRegistra
functionCall = undefined;
} else if (args.intention === VisualizeESQLUserIntention.executeAndReturnResults) {
functionCall = {
name: 'execute_query',
name: EXECUTE_QUERY_NAME,
arguments: JSON.stringify({ query: esqlQuery }),
trigger: MessageRole.Assistant as const,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ import { UrlObject } from 'url';
import { ObservabilityAIAssistantFtrConfigName } from '../configs';
import { getApmSynthtraceEsClient } from './create_synthtrace_client';
import { InheritedFtrProviderContext, InheritedServices } from './ftr_provider_context';
import {
getScopedApiClient,
ObservabilityAIAssistantAPIClient,
} from './observability_ai_assistant_api_client';
import { getScopedApiClient } from './observability_ai_assistant_api_client';
import { editorUser, viewerUser } from './users/users';

export interface ObservabilityAIAssistantFtrConfig {
Expand All @@ -24,20 +21,13 @@ export interface ObservabilityAIAssistantFtrConfig {

export type CreateTestConfig = ReturnType<typeof createTestConfig>;

export interface CreateTest {
testFiles: string[];
servers: any;
services: InheritedServices & {
observabilityAIAssistantAPIClient: () => Promise<{
adminUser: ObservabilityAIAssistantAPIClient;
viewerUser: ObservabilityAIAssistantAPIClient;
editorUser: ObservabilityAIAssistantAPIClient;
}>;
};
junit: { reportName: string };
esTestCluster: any;
kbnTestServer: any;
}
export type CreateTest = ReturnType<typeof createObservabilityAIAssistantAPIConfig>;

export type ObservabilityAIAssistantAPIClient = Awaited<
ReturnType<CreateTest['services']['observabilityAIAssistantAPIClient']>
>;

export type ObservabilityAIAssistantServices = Awaited<ReturnType<CreateTestConfig>>['services'];

export function createObservabilityAIAssistantAPIConfig({
config,
Expand All @@ -49,14 +39,15 @@ export function createObservabilityAIAssistantAPIConfig({
license: 'basic' | 'trial';
name: string;
kibanaConfig?: Record<string, any>;
}): Omit<CreateTest, 'testFiles'> {
}) {
const services = config.get('services') as InheritedServices;
const servers = config.get('servers');
const kibanaServer = servers.kibana as UrlObject;
const apmSynthtraceKibanaClient = services.apmSynthtraceKibanaClient();
const allConfigs = config.getAll() as Record<string, any>;

const createTest: Omit<CreateTest, 'testFiles'> = {
...config.getAll(),
return {
...allConfigs,
servers,
services: {
...services,
Expand Down Expand Up @@ -89,13 +80,9 @@ export function createObservabilityAIAssistantAPIConfig({
],
},
};

return createTest;
}

export function createTestConfig(
config: ObservabilityAIAssistantFtrConfig
): ({ readConfigFile }: FtrConfigProviderContext) => Promise<CreateTest> {
export function createTestConfig(config: ObservabilityAIAssistantFtrConfig) {
const { license, name, kibanaConfig } = config;

return async ({ readConfigFile }: FtrConfigProviderContext) => {
Expand All @@ -114,5 +101,3 @@ export function createTestConfig(
};
};
}

export type ObservabilityAIAssistantServices = Awaited<ReturnType<CreateTestConfig>>['services'];
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ export class LlmProxy {
waitForIntercept: () => Promise<LlmResponseSimulator>;
}
: {
waitAndComplete: () => Promise<void>;
completeAfterIntercept: () => Promise<void>;
} {
const waitForInterceptPromise = Promise.race([
new Promise<LlmResponseSimulator>((outerResolve) => {
Expand Down Expand Up @@ -162,7 +162,7 @@ export class LlmProxy {
: responseChunks.split(' ').map((token, i) => (i === 0 ? token : ` ${token}`));

return {
waitAndComplete: async () => {
completeAfterIntercept: async () => {
const simulator = await waitForInterceptPromise;
for (const chunk of parsedChunks) {
await simulator.next(chunk);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,11 +414,11 @@ export default function ApiTest({ getService }: FtrProviderContext) {
},
},
])
.waitAndComplete();
.completeAfterIntercept();

proxy
.intercept('conversation', (body) => !isFunctionTitleRequest(body), 'Good morning, sir!')
.waitAndComplete();
.completeAfterIntercept();

const createResponse = await observabilityAIAssistantAPIClient
.editorUser({
Expand Down Expand Up @@ -450,7 +450,7 @@ export default function ApiTest({ getService }: FtrProviderContext) {

proxy
.intercept('conversation', (body) => !isFunctionTitleRequest(body), 'Good night, sir!')
.waitAndComplete();
.completeAfterIntercept();

const updatedResponse = await observabilityAIAssistantAPIClient
.editorUser({
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { MessageAddEvent, MessageRole } from '@kbn/observability-ai-assistant-plugin/common';
import expect from '@kbn/expect';
import { apm, timerange } from '@kbn/apm-synthtrace-client';
import { ApmSynthtraceEsClient } from '@kbn/apm-synthtrace';
import { ELASTICSEARCH_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/functions/elasticsearch';
import { LlmProxy } from '../../../common/create_llm_proxy';
import { FtrProviderContext } from '../../../common/ftr_provider_context';
import {
createLLMProxyConnector,
deleteLLMProxyConnector,
getMessageAddedEvents,
invokeChatCompleteWithFunctionRequest,
} from './helpers';

export default function ApiTest({ getService }: FtrProviderContext) {
const supertest = getService('supertest');
const log = getService('log');
const apmSynthtraceEsClient = getService('apmSynthtraceEsClient');
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantAPIClient');

describe('when calling elasticsearch', () => {
let proxy: LlmProxy;
let connectorId: string;
let events: MessageAddEvent[];

before(async () => {
({ connectorId, proxy } = await createLLMProxyConnector({ log, supertest }));
await generateApmData(apmSynthtraceEsClient);

const responseBody = await invokeChatCompleteWithFunctionRequest({
connectorId,
observabilityAIAssistantAPIClient,
functionCall: {
name: ELASTICSEARCH_FUNCTION_NAME,
trigger: MessageRole.User,
arguments: JSON.stringify({
method: 'POST',
path: 'traces*/_search',
body: {
size: 0,
aggs: {
services: {
terms: {
field: 'service.name',
},
},
},
},
}),
},
});

await proxy.waitForAllInterceptorsSettled();

events = getMessageAddedEvents(responseBody);
});

after(async () => {
await deleteLLMProxyConnector({ supertest, connectorId, proxy });
await apmSynthtraceEsClient.clean();
});

it('returns elasticsearch function response', async () => {
const esFunctionResponse = events[0];
const parsedEsResponse = JSON.parse(esFunctionResponse.message.message.content!).response;

expect(esFunctionResponse.message.message.name).to.be('elasticsearch');
expect(parsedEsResponse.hits.total.value).to.be(15);
expect(parsedEsResponse.aggregations.services.buckets).to.eql([
{ key: 'java-backend', doc_count: 15 },
]);
expect(events.length).to.be(2);
});
});
}

export async function generateApmData(apmSynthtraceEsClient: ApmSynthtraceEsClient) {
const serviceA = apm
.service({ name: 'java-backend', environment: 'production', agentName: 'java' })
.instance('a');

const events = timerange('now-15m', 'now')
.interval('1m')
.rate(1)
.generator((timestamp) => {
return serviceA.transaction({ transactionName: 'tx' }).timestamp(timestamp).duration(1000);
});

return apmSynthtraceEsClient.index(events);
}
Loading