Skip to content

Commit

Permalink
[Obs AI Assistant] Expose recall function as API (elastic#185058)
Browse files Browse the repository at this point in the history
Exposes a `POST /internal/observability_ai_assistant/chat/recall`
endpoint for [Investigate UI
](elastic#183293). It is mostly just
moving stuff around, some small refactorings and a new way to generate
short ids. Previously we were using indexes for scoring suggestions, we
are now generating a short but unique id (ie 4-5 chars) which generates
a fairly unique token which strengthens the relationship between the id
and the object but still allows for quick output. LLMs are slow to
generate UUIDs, but indexes are very generic and the LLM might not pay a
lot of attention to it.
  • Loading branch information
dgieselaar authored Jun 15, 2024
1 parent ee15561 commit 1338287
Show file tree
Hide file tree
Showing 35 changed files with 897 additions and 509 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
export type { Message, Conversation, KnowledgeBaseEntry } from './types';
export type { ConversationCreateRequest } from './types';
export { KnowledgeBaseEntryRole, MessageRole } from './types';
export type { FunctionDefinition } from './functions/types';
export type { FunctionDefinition, CompatibleJSONSchema } from './functions/types';
export { FunctionVisibility } from './functions/function_visibility';
export {
VISUALIZE_ESQL_USER_INTENTIONS,
Expand Down Expand Up @@ -49,3 +49,5 @@ export { concatenateChatCompletionChunks } from './utils/concatenate_chat_comple
export { DEFAULT_LANGUAGE_OPTION, LANGUAGE_OPTIONS } from './ui_settings/language_options';

export { isSupportedConnectorType } from './connectors';

export { ShortIdTable } from './utils/short_id_table';
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ export interface KnowledgeBaseEntry {
export interface UserInstruction {
doc_id: string;
text: string;
system?: boolean;
}

export type UserInstructionOrPlainText = string | UserInstruction;
Expand All @@ -109,15 +110,15 @@ export interface ObservabilityAIAssistantScreenContextRequest {
actions?: Array<{ name: string; description: string; parameters?: CompatibleJSONSchema }>;
}

export type ScreenContextActionRespondFunction<TArguments extends unknown> = ({}: {
export type ScreenContextActionRespondFunction<TArguments> = ({}: {
args: TArguments;
signal: AbortSignal;
connectorId: string;
client: Pick<ObservabilityAIAssistantChatService, 'chat' | 'complete'>;
messages: Message[];
}) => Promise<FunctionResponse>;

export interface ScreenContextActionDefinition<TArguments = undefined> {
export interface ScreenContextActionDefinition<TArguments = any> {
name: string;
description: string;
parameters?: CompatibleJSONSchema;
Expand All @@ -137,6 +138,6 @@ export interface ObservabilityAIAssistantScreenContext {
description: string;
value: any;
}>;
actions?: ScreenContextActionDefinition[];
actions?: Array<ScreenContextActionDefinition<any>>;
starterPrompts?: StarterPrompt[];
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export const concatenateChatCompletionChunks =
acc.message.content += message.content ?? '';
acc.message.function_call.name += message.function_call?.name ?? '';
acc.message.function_call.arguments += message.function_call?.arguments ?? '';

return cloneDeep(acc);
},
{
Expand All @@ -43,6 +44,6 @@ export const concatenateChatCompletionChunks =
},
role: MessageRole.Assistant,
},
}
} as ConcatenatedMessage
)
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { ShortIdTable } from './short_id_table';

describe('shortIdTable', () => {
it('generates at least 10k unique ids consistently', () => {
const ids = new Set();

const table = new ShortIdTable();

let i = 10_000;
while (i--) {
const id = table.take(String(i));
ids.add(id);
}

expect(ids.size).toBe(10_000);
});

it('returns the original id based on the generated id', () => {
const table = new ShortIdTable();

const idsByOriginal = new Map<string, string>();

let i = 100;
while (i--) {
const id = table.take(String(i));
idsByOriginal.set(String(i), id);
}

expect(idsByOriginal.size).toBe(100);

expect(() => {
Array.from(idsByOriginal.entries()).forEach(([originalId, shortId]) => {
const returnedOriginalId = table.lookup(shortId);
if (returnedOriginalId !== originalId) {
throw Error(
`Expected shortId ${shortId} to return ${originalId}, but ${returnedOriginalId} was returned instead`
);
}
});
}).not.toThrow();
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

const ALPHABET = 'abcdefghijklmnopqrstuvwxyz';

function generateShortId(size: number): string {
let id = '';
let i = size;
while (i--) {
const index = Math.floor(Math.random() * ALPHABET.length);
id += ALPHABET[index];
}
return id;
}

const MAX_ATTEMPTS_AT_LENGTH = 100;

export class ShortIdTable {
private byShortId: Map<string, string> = new Map();
private byOriginalId: Map<string, string> = new Map();

constructor() {}

take(originalId: string) {
if (this.byOriginalId.has(originalId)) {
return this.byOriginalId.get(originalId)!;
}

let uniqueId: string | undefined;
let attemptsAtLength = 0;
let length = 4;
while (!uniqueId) {
const nextId = generateShortId(length);
attemptsAtLength++;
if (!this.byShortId.has(nextId)) {
uniqueId = nextId;
} else if (attemptsAtLength >= MAX_ATTEMPTS_AT_LENGTH) {
attemptsAtLength = 0;
length++;
}
}

this.byShortId.set(uniqueId, originalId);
this.byOriginalId.set(originalId, uniqueId);

return uniqueId;
}

lookup(shortId: string) {
return this.byShortId.get(shortId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export function throwSerializedChatCompletionErrors<
return (source$) =>
source$.pipe(
tap((event) => {
// de-serialise error
// de-serialize error
if (event.type === StreamingChatResponseEventType.ChatCompletionError) {
const code = event.error.code ?? ChatCompletionErrorCode.InternalError;
const message = event.error.message;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { Observable, OperatorFunction, takeUntil } from 'rxjs';
import { AbortError } from '@kbn/kibana-utils-plugin/common';

export function untilAborted<T>(signal: AbortSignal): OperatorFunction<T, T> {
return (source$) => {
const signal$ = new Observable((subscriber) => {
if (signal.aborted) {
subscriber.error(new AbortError());
}
signal.addEventListener('abort', () => {
subscriber.error(new AbortError());
});
});

return source$.pipe(takeUntil(signal$));
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export interface AssistantAvatarProps {
size?: keyof typeof sizeMap;
children?: ReactNode;
css?: React.SVGProps<SVGElement>['css'];
className?: string;
}

export const sizeMap = {
Expand All @@ -20,7 +21,7 @@ export const sizeMap = {
xs: 16,
};

export function AssistantAvatar({ size = 's', css }: AssistantAvatarProps) {
export function AssistantAvatar({ size = 's', css, className }: AssistantAvatarProps) {
const sizePx = sizeMap[size];
return (
<svg
Expand All @@ -30,6 +31,7 @@ export function AssistantAvatar({ size = 's', css }: AssistantAvatarProps) {
viewBox="0 0 64 64"
fill="none"
css={css}
className={className}
>
<path fill="#F04E98" d="M36 28h24v36H36V28Z" />
<path fill="#00BFB3" d="M4 46c0-9.941 8.059-18 18-18h6v36h-6c-9.941 0-18-8.059-18-18Z" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,18 @@ export function useAbortableAsync<T>(

if (clearValueOnNext) {
setValue(undefined);
setError(undefined);
}

try {
const response = fn({ signal: controller.signal });
if (isPromise(response)) {
setLoading(true);
response
.then(setValue)
.then((nextValue) => {
setError(undefined);
setValue(nextValue);
})
.catch((err) => {
setValue(undefined);
setError(err);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
* 2.0.
*/
import type { PluginInitializer, PluginInitializerContext } from '@kbn/core/public';
export type { CompatibleJSONSchema } from '../common/functions/types';

import { ObservabilityAIAssistantPlugin } from './plugin';
import type {
Expand All @@ -18,6 +17,7 @@ import type {
ObservabilityAIAssistantChatService,
RegisterRenderFunctionDefinition,
RenderFunction,
DiscoveredDataset,
} from './types';

export type {
Expand All @@ -27,6 +27,7 @@ export type {
ObservabilityAIAssistantChatService,
RegisterRenderFunctionDefinition,
RenderFunction,
DiscoveredDataset,
};

export { aiAssistantCapabilities } from '../common/capabilities';
Expand Down Expand Up @@ -59,15 +60,27 @@ export {
VISUALIZE_ESQL_USER_INTENTIONS,
} from '../common/functions/visualize_esql';

export { isSupportedConnectorType } from '../common';
export { FunctionVisibility } from '../common';
export {
isSupportedConnectorType,
FunctionVisibility,
MessageRole,
KnowledgeBaseEntryRole,
concatenateChatCompletionChunks,
StreamingChatResponseEventType,
} from '../common';
export type {
CompatibleJSONSchema,
Conversation,
Message,
KnowledgeBaseEntry,
FunctionDefinition,
ChatCompletionChunkEvent,
ShortIdTable,
} from '../common';

export type { TelemetryEventTypeWithPayload } from './analytics';
export { ObservabilityAIAssistantTelemetryEventType } from './analytics/telemetry_event_type';

export type { Conversation, Message, KnowledgeBaseEntry } from '../common';
export { MessageRole, KnowledgeBaseEntryRole } from '../common';

export { createFunctionRequestMessage } from '../common/utils/create_function_request_message';
export { createFunctionResponseMessage } from '../common/utils/create_function_response_message';

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import { i18n } from '@kbn/i18n';
import { noop } from 'lodash';
import React from 'react';
import { Observable, of } from 'rxjs';
import type { StreamingChatResponseEventWithoutError } from '../common/conversation_complete';
import type {
ChatCompletionChunkEvent,
StreamingChatResponseEventWithoutError,
} from '../common/conversation_complete';
import { MessageRole, ScreenContextActionDefinition } from '../common/types';
import type { ObservabilityAIAssistantAPIClient } from './api';
import type {
Expand All @@ -21,7 +24,7 @@ import { buildFunctionElasticsearch, buildFunctionServiceSummary } from './utils

export const mockChatService: ObservabilityAIAssistantChatService = {
sendAnalyticsEvent: noop,
chat: (options) => new Observable<StreamingChatResponseEventWithoutError>(),
chat: (options) => new Observable<ChatCompletionChunkEvent>(),
complete: (options) => new Observable<StreamingChatResponseEventWithoutError>(),
getFunctions: () => [buildFunctionElasticsearch(), buildFunctionServiceSummary()],
renderFunction: (name) => (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import {
ChatCompletionError,
MessageAddEvent,
createInternalServerError,
createConversationNotFoundError,
StreamingChatResponseEventWithoutError,
} from '../../common';
import type { ObservabilityAIAssistantChatService } from '../types';
import { complete } from './complete';
Expand Down Expand Up @@ -45,7 +47,7 @@ const messages: Message[] = [

const createLlmResponse = (
chunks: Array<{ content: string; function_call?: { name: string; arguments: string } }>
): StreamingChatResponseEvent[] => {
): StreamingChatResponseEventWithoutError[] => {
const id = v4();
const message = chunks.reduce<Message['message']>(
(prev, current) => {
Expand All @@ -61,7 +63,7 @@ const createLlmResponse = (
}
);

const events: StreamingChatResponseEvent[] = [
const events: StreamingChatResponseEventWithoutError[] = [
...chunks.map((msg) => ({
id,
message: msg,
Expand Down Expand Up @@ -108,20 +110,12 @@ describe('complete', () => {

describe('when an error is emitted', () => {
beforeEach(() => {
requestCallback.mockImplementation(() =>
of({
type: StreamingChatResponseEventType.ChatCompletionError,
error: {
message: 'Not found',
code: ChatCompletionErrorCode.NotFoundError,
},
})
);
requestCallback.mockImplementation(() => throwError(() => createConversationNotFoundError()));
});

it('the observable errors out', async () => {
await expect(async () => await lastValueFrom(callComplete())).rejects.toThrowError(
'Not found'
'Conversation not found'
);

await expect(async () => await lastValueFrom(callComplete())).rejects.toBeInstanceOf(
Expand Down
Loading

0 comments on commit 1338287

Please sign in to comment.