Skip to content

Commit

Permalink
[Obs AI Assistant] Include search-* when recalling documents (#173710)
Browse files Browse the repository at this point in the history
Include `search-*` indices when recalling documents from the knowledge
base. General approach:

- use the current user, not the internal user. The latter will ~never
have access to `search-*`
- use `_field_caps` to look for sparse_vector field types
- `ml.inference.` is a hard-coded prefix, so we can strip that and
`_expanded.predicted_value` to get the original field name
- only include documents that have the same model ID as we are using for
our regular recalls
- if the request fails for whatever reason (which is fine, users might
not have access to `search-*`), just ignore it and log it with log level
debug
- we serialize the entire document - some other non-vectorized metadata
can also be important for the LLM to make decisions
- sort all documents (kb + `search-*`) by score and return the first 20
- count the number of tokens; don't send more than 4000 tokens to
the LLM, to keep response time down. Drop the remaining documents on
the floor and log it.
  • Loading branch information
dgieselaar authored Dec 22, 2023
1 parent 31b7380 commit fc997b1
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export const registerFunctions: ChatRegistrationFunction = async ({
resources,
signal,
};

return client.getKnowledgeBaseStatus().then((response) => {
const isReady = response.ready;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,18 @@ import dedent from 'dedent';
import * as t from 'io-ts';
import { last, omit } from 'lodash';
import { lastValueFrom } from 'rxjs';
import { FunctionRegistrationParameters } from '.';
import { MessageRole, type Message } from '../../common/types';
import { concatenateOpenAiChunks } from '../../common/utils/concatenate_openai_chunks';
import { processOpenAiStream } from '../../common/utils/process_openai_stream';
import type { ObservabilityAIAssistantClient } from '../service/client';
import type { RegisterFunction } from '../service/types';
import { streamIntoObservable } from '../service/util/stream_into_observable';

export function registerRecallFunction({
client,
registerFunction,
}: {
client: ObservabilityAIAssistantClient;
registerFunction: RegisterFunction;
}) {
resources,
}: FunctionRegistrationParameters) {
registerFunction(
{
name: 'recall',
Expand Down Expand Up @@ -99,6 +97,10 @@ export function registerRecallFunction({
queries,
});

resources.logger.debug(`Received ${suggestions.length} suggestions`);

resources.logger.debug(JSON.stringify(suggestions, null, 2));

if (suggestions.length === 0) {
return {
content: [] as unknown as Serializable,
Expand All @@ -115,6 +117,9 @@ export function registerRecallFunction({
signal,
});

resources.logger.debug(`Received ${relevantDocuments.length} relevant documents`);
resources.logger.debug(JSON.stringify(relevantDocuments, null, 2));

return {
content: relevantDocuments as unknown as Serializable,
};
Expand Down Expand Up @@ -254,7 +259,6 @@ async function scoreSuggestions({
})
).pipe(processOpenAiStream(), concatenateOpenAiChunks())
);

const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);
const { scores } = decodeOrThrow(jsonRt.pipe(scoreFunctionArgumentsRt))(
scoreFunctionRequest.message.function_call.arguments
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,23 @@ describe('Observability AI Assistant service', () => {
execute: jest.fn(),
} as any;

const esClientMock: DeeplyMockedKeys<ElasticsearchClient> = {
const internalUserEsClientMock: DeeplyMockedKeys<ElasticsearchClient> = {
search: jest.fn(),
index: jest.fn(),
update: jest.fn(),
} as any;

const currentUserEsClientMock: DeeplyMockedKeys<ElasticsearchClient> = {
search: jest.fn().mockResolvedValue({
hits: {
hits: [],
},
}),
fieldCaps: jest.fn().mockResolvedValue({
fields: [],
}),
} as any;

const knowledgeBaseServiceMock: DeeplyMockedKeys<KnowledgeBaseService> = {
recall: jest.fn(),
} as any;
Expand All @@ -91,6 +102,7 @@ describe('Observability AI Assistant service', () => {
log: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
trace: jest.fn(),
} as any;

const functionClientMock: DeeplyMockedKeys<ChatFunctionClient> = {
Expand All @@ -108,7 +120,10 @@ describe('Observability AI Assistant service', () => {

return new ObservabilityAIAssistantClient({
actionsClient: actionsClientMock,
esClient: esClientMock,
esClient: {
asInternalUser: internalUserEsClientMock,
asCurrentUser: currentUserEsClientMock,
},
knowledgeBaseService: knowledgeBaseServiceMock,
logger: loggerMock,
namespace: 'default',
Expand Down Expand Up @@ -334,7 +349,7 @@ describe('Observability AI Assistant service', () => {
type: StreamingChatResponseEventType.ConversationCreate,
});

expect(esClientMock.index).toHaveBeenCalledWith({
expect(internalUserEsClientMock.index).toHaveBeenCalledWith({
index: '.kibana-observability-ai-assistant-conversations',
refresh: true,
document: {
Expand Down Expand Up @@ -386,7 +401,7 @@ describe('Observability AI Assistant service', () => {
});
});

describe('when completig a conversation with an initial conversation id', () => {
describe('when completing a conversation with an initial conversation id', () => {
let stream: Readable;

let dataHandler: jest.Mock;
Expand All @@ -402,7 +417,7 @@ describe('Observability AI Assistant service', () => {
};
});

esClientMock.search.mockImplementation(async () => {
internalUserEsClientMock.search.mockImplementation(async () => {
return {
hits: {
hits: [
Expand Down Expand Up @@ -430,7 +445,7 @@ describe('Observability AI Assistant service', () => {
} as any;
});

esClientMock.update.mockImplementationOnce(async () => {
internalUserEsClientMock.update.mockImplementationOnce(async () => {
return {} as any;
});

Expand Down Expand Up @@ -464,7 +479,7 @@ describe('Observability AI Assistant service', () => {
type: StreamingChatResponseEventType.ConversationUpdate,
});

expect(esClientMock.update).toHaveBeenCalledWith({
expect(internalUserEsClientMock.update).toHaveBeenCalledWith({
refresh: true,
index: '.kibana-observability-ai-assistant-conversations',
id: 'my-es-document-id',
Expand Down Expand Up @@ -573,8 +588,8 @@ describe('Observability AI Assistant service', () => {
});

it('does not create or update the conversation', async () => {
expect(esClientMock.index).not.toHaveBeenCalled();
expect(esClientMock.update).not.toHaveBeenCalled();
expect(internalUserEsClientMock.index).not.toHaveBeenCalled();
expect(internalUserEsClientMock.update).not.toHaveBeenCalled();
});
});

Expand Down Expand Up @@ -816,9 +831,11 @@ describe('Observability AI Assistant service', () => {
},
});

expect(esClientMock.index).toHaveBeenCalled();
expect(internalUserEsClientMock.index).toHaveBeenCalled();

expect((esClientMock.index.mock.lastCall![0] as any).document.messages).toEqual([
expect(
(internalUserEsClientMock.index.mock.lastCall![0] as any).document.messages
).toEqual([
{
'@timestamp': expect.any(String),
message: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ export class ObservabilityAIAssistantClient {
private readonly dependencies: {
actionsClient: PublicMethodsOf<ActionsClient>;
namespace: string;
esClient: ElasticsearchClient;
esClient: {
asInternalUser: ElasticsearchClient;
asCurrentUser: ElasticsearchClient;
};
resources: ObservabilityAIAssistantResourceNames;
logger: Logger;
user: {
Expand All @@ -67,7 +70,7 @@ export class ObservabilityAIAssistantClient {
private getConversationWithMetaFields = async (
conversationId: string
): Promise<SearchHit<Conversation> | undefined> => {
const response = await this.dependencies.esClient.search<Conversation>({
const response = await this.dependencies.esClient.asInternalUser.search<Conversation>({
index: this.dependencies.resources.aliases.conversations,
query: {
bool: {
Expand Down Expand Up @@ -113,7 +116,7 @@ export class ObservabilityAIAssistantClient {
throw notFound();
}

await this.dependencies.esClient.delete({
await this.dependencies.esClient.asInternalUser.delete({
id: conversation._id,
index: conversation._index,
refresh: true,
Expand Down Expand Up @@ -407,7 +410,7 @@ export class ObservabilityAIAssistantClient {
};

this.dependencies.logger.debug(`Sending conversation to connector`);
this.dependencies.logger.debug(JSON.stringify(request, null, 2));
this.dependencies.logger.trace(JSON.stringify(request, null, 2));

const executeResult = await this.dependencies.actionsClient.execute({
actionId: connectorId,
Expand All @@ -428,17 +431,15 @@ export class ObservabilityAIAssistantClient {
? (executeResult.data as Readable)
: (executeResult.data as CreateChatCompletionResponse);

if (response instanceof PassThrough) {
signal.addEventListener('abort', () => {
response.end();
});
if (response instanceof Readable) {
signal.addEventListener('abort', () => response.destroy());
}

return response as any;
};

find = async (options?: { query?: string }): Promise<{ conversations: Conversation[] }> => {
const response = await this.dependencies.esClient.search<Conversation>({
const response = await this.dependencies.esClient.asInternalUser.search<Conversation>({
index: this.dependencies.resources.aliases.conversations,
allow_no_indices: true,
query: {
Expand Down Expand Up @@ -475,7 +476,7 @@ export class ObservabilityAIAssistantClient {
this.getConversationUpdateValues(new Date().toISOString())
);

await this.dependencies.esClient.update({
await this.dependencies.esClient.asInternalUser.update({
id: document._id,
index: document._index,
doc: updatedConversation,
Expand Down Expand Up @@ -547,7 +548,7 @@ export class ObservabilityAIAssistantClient {
this.getConversationUpdateValues(new Date().toISOString())
);

await this.dependencies.esClient.update({
await this.dependencies.esClient.asInternalUser.update({
id: document._id,
index: document._index,
doc: { conversation: { title } },
Expand All @@ -570,7 +571,7 @@ export class ObservabilityAIAssistantClient {
this.getConversationUpdateValues(now)
);

await this.dependencies.esClient.index({
await this.dependencies.esClient.asInternalUser.index({
index: this.dependencies.resources.aliases.conversations,
document: createdConversation,
refresh: true,
Expand All @@ -591,6 +592,7 @@ export class ObservabilityAIAssistantClient {
user: this.dependencies.user,
queries,
contexts,
asCurrentUser: this.dependencies.esClient.asCurrentUser,
});
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,10 @@ export class ObservabilityAIAssistantService {
return new ObservabilityAIAssistantClient({
actionsClient: await plugins.actions.getActionsClientWithRequest(request),
namespace: spaceId,
esClient: coreStart.elasticsearch.client.asInternalUser,
esClient: {
asInternalUser: coreStart.elasticsearch.client.asInternalUser,
asCurrentUser: coreStart.elasticsearch.client.asScoped(request).asCurrentUser,
},
resources: this.resourceNames,
logger: this.logger,
user: {
Expand Down
Loading

0 comments on commit fc997b1

Please sign in to comment.