Skip to content

Commit

Permalink
[Obs AI Assistant] Fix auto-generation of titles (#182923)
Browse files Browse the repository at this point in the history
If the LLM doesn't send the title_conversation in a single go,
`getGeneratedTitle` fails, because it doesn't wait until the message has
been fully completed. This PR fixes that issue, and adds tests, and also
improves quote matching.
  • Loading branch information
dgieselaar authored May 8, 2024
1 parent 71784ad commit 27fdb2d
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { filter, lastValueFrom, of, throwError, toArray } from 'rxjs';
import {
ChatCompletionChunkEvent,
Message,
MessageRole,
StreamingChatResponseEventType,
} from '../../../../common';
import { ChatEvent } from '../../../../common/conversation_complete';
import { getGeneratedTitle } from './get_generated_title';

describe('getGeneratedTitle', () => {
const messages: Message[] = [
{
'@timestamp': new Date().toISOString(),
message: {
content: 'A message',
role: MessageRole.User,
},
},
];

function createChatCompletionChunk(
content: string | { content?: string; function_call?: { name: string; arguments: string } }
): ChatCompletionChunkEvent {
const msg = typeof content === 'string' ? { content } : content;

return {
type: StreamingChatResponseEventType.ChatCompletionChunk,
id: 'id',
message: msg,
};
}

function callGenerateTitle(
...rest: [ChatEvent[]] | [{ responseLanguage?: string }, ChatEvent[]]
) {
const options = rest.length === 1 ? {} : rest[0];
const chunks = rest.length === 1 ? rest[0] : rest[1];

const chatSpy = jest.fn().mockImplementation(() => of(...chunks));

const title$ = getGeneratedTitle({
chat: chatSpy,
logger: {
debug: jest.fn(),
error: jest.fn(),
},
messages,
...options,
});

return { chatSpy, title$ };
}

it('returns the given title as a string', async () => {
const { title$ } = callGenerateTitle([
createChatCompletionChunk({
function_call: {
name: 'title_conversation',
arguments: JSON.stringify({ title: 'My title' }),
},
}),
]);

const title = await lastValueFrom(
title$.pipe(filter((event): event is string => typeof event === 'string'))
);

expect(title).toEqual('My title');
});

it('calls chat with the user message', async () => {
const { chatSpy, title$ } = callGenerateTitle([
createChatCompletionChunk({
function_call: {
name: 'title_conversation',
arguments: JSON.stringify({ title: 'My title' }),
},
}),
]);

await lastValueFrom(title$);

const [name, params] = chatSpy.mock.calls[0];

expect(name).toEqual('generate_title');
expect(params.messages.length).toBe(2);
expect(params.messages[1].message.content).toContain('A message');
});

it('strips quotes from the title', async () => {
async function testTitle(title: string) {
const { title$ } = callGenerateTitle([
createChatCompletionChunk({
function_call: {
name: 'title_conversation',
arguments: JSON.stringify({ title }),
},
}),
]);

return await lastValueFrom(
title$.pipe(filter((event): event is string => typeof event === 'string'))
);
}

expect(await testTitle(`"My title"`)).toEqual('My title');
expect(await testTitle(`'My title'`)).toEqual('My title');
expect(await testTitle(`"User's request for a title"`)).toEqual(`User's request for a title`);
});

it('mentions the given response language in the instruction', async () => {
const { chatSpy, title$ } = callGenerateTitle(
{
responseLanguage: 'Orcish',
},
[
createChatCompletionChunk({
function_call: {
name: 'title_conversation',
arguments: JSON.stringify({ title: 'My title' }),
},
}),
]
);

await lastValueFrom(title$);

const [, params] = chatSpy.mock.calls[0];
expect(params.messages[0].message.content).toContain('Orcish');
});

it('handles partial updates', async () => {
const { title$ } = callGenerateTitle([
createChatCompletionChunk({
function_call: {
name: 'title_conversation',
arguments: '',
},
}),
createChatCompletionChunk({
function_call: {
name: '',
arguments: JSON.stringify({ title: 'My title' }),
},
}),
]);

const title = await lastValueFrom(title$);

expect(title).toEqual('My title');
});

it('ignores token count events and still passes them through', async () => {
const { title$ } = callGenerateTitle([
createChatCompletionChunk({
function_call: {
name: 'title_conversation',
arguments: JSON.stringify({ title: 'My title' }),
},
}),
{
type: StreamingChatResponseEventType.TokenCount,
tokens: {
completion: 10,
prompt: 10,
total: 10,
},
},
]);

const events = await lastValueFrom(title$.pipe(toArray()));

expect(events).toEqual([
'My title',
{
tokens: {
completion: 10,
prompt: 10,
total: 10,
},
type: StreamingChatResponseEventType.TokenCount,
},
]);
});

it('handles errors in chat and falls back to the default title', async () => {
const chatSpy = jest
.fn()
.mockImplementation(() => throwError(() => new Error('Error generating title')));

const logger = {
debug: jest.fn(),
error: jest.fn(),
};

const title$ = getGeneratedTitle({
chat: chatSpy,
logger,
messages,
});

const title = await lastValueFrom(title$);

expect(title).toEqual('New conversation');

expect(logger.error).toHaveBeenCalledWith('Error generating title');
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

import { catchError, map, Observable, of, tap } from 'rxjs';
import { catchError, last, map, Observable, of, tap } from 'rxjs';
import { Logger } from '@kbn/logging';
import type { ObservabilityAIAssistantClient } from '..';
import { Message, MessageRole } from '../../../../common';
Expand All @@ -30,7 +30,7 @@ export function getGeneratedTitle({
responseLanguage?: string;
messages: Message[];
chat: ChatFunctionWithoutConnectorAndTokenCount;
logger: Logger;
logger: Pick<Logger, 'debug' | 'error'>;
}): Observable<string | TokenCountEvent> {
return hideTokenCountEvents((hide) =>
chat('generate_title', {
Expand All @@ -46,9 +46,11 @@ export function getGeneratedTitle({
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
content: messages.slice(1).reduce((acc, curr) => {
return `${acc} ${curr.message.role}: ${curr.message.content}`;
}, 'Generate a title, using the title_conversation_function, based on the following conversation:\n\n'),
content: messages
.filter((msg) => msg.message.role !== MessageRole.System)
.reduce((acc, curr) => {
return `${acc} ${curr.message.role}: ${curr.message.content}`;
}, 'Generate a title, using the title_conversation_function, based on the following conversation:\n\n'),
},
},
],
Expand All @@ -72,21 +74,21 @@ export function getGeneratedTitle({
}).pipe(
hide(),
concatenateChatCompletionChunks(),
last(),
map((concatenatedMessage) => {
const input =
const title: string =
(concatenatedMessage.message.function_call.name
? JSON.parse(concatenatedMessage.message.function_call.arguments).title
: concatenatedMessage.message?.content) || '';

// This regular expression captures a string enclosed in single or double quotes.
// This captures a string enclosed in single or double quotes.
// It extracts the string content without the quotes.
// Example matches:
// - "Hello, World!" => Captures: Hello, World!
// - 'Another Example' => Captures: Another Example
// - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes
const match = input.match(/^["']?([^"']+)["']?$/);
const title = match ? match[1] : input;
return title;

return title.replace(/^"(.*)"$/g, '$1').replace(/^'(.*)'$/g, '$1');
}),
tap((event) => {
if (typeof event === 'string') {
Expand Down

0 comments on commit 27fdb2d

Please sign in to comment.