-
Notifications
You must be signed in to change notification settings - Fork 8.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Obs AI Assistant] Fix auto-generation of titles (#182923)
If the LLM doesn't send the title_conversation in a single go, `getGeneratedTitle` fails, because it doesn't wait until the message has been fully completed. This PR fixes that issue, and adds tests, and also improves quote matching.
- Loading branch information
1 parent
71784ad
commit 27fdb2d
Showing
2 changed files
with
227 additions
and
10 deletions.
There are no files selected for viewing
215 changes: 215 additions & 0 deletions
215
...on/observability_ai_assistant/server/service/client/operators/get_generated_title.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
import { filter, lastValueFrom, of, throwError, toArray } from 'rxjs'; | ||
import { | ||
ChatCompletionChunkEvent, | ||
Message, | ||
MessageRole, | ||
StreamingChatResponseEventType, | ||
} from '../../../../common'; | ||
import { ChatEvent } from '../../../../common/conversation_complete'; | ||
import { getGeneratedTitle } from './get_generated_title'; | ||
|
||
describe('getGeneratedTitle', () => { | ||
const messages: Message[] = [ | ||
{ | ||
'@timestamp': new Date().toISOString(), | ||
message: { | ||
content: 'A message', | ||
role: MessageRole.User, | ||
}, | ||
}, | ||
]; | ||
|
||
function createChatCompletionChunk( | ||
content: string | { content?: string; function_call?: { name: string; arguments: string } } | ||
): ChatCompletionChunkEvent { | ||
const msg = typeof content === 'string' ? { content } : content; | ||
|
||
return { | ||
type: StreamingChatResponseEventType.ChatCompletionChunk, | ||
id: 'id', | ||
message: msg, | ||
}; | ||
} | ||
|
||
function callGenerateTitle( | ||
...rest: [ChatEvent[]] | [{ responseLanguage?: string }, ChatEvent[]] | ||
) { | ||
const options = rest.length === 1 ? {} : rest[0]; | ||
const chunks = rest.length === 1 ? rest[0] : rest[1]; | ||
|
||
const chatSpy = jest.fn().mockImplementation(() => of(...chunks)); | ||
|
||
const title$ = getGeneratedTitle({ | ||
chat: chatSpy, | ||
logger: { | ||
debug: jest.fn(), | ||
error: jest.fn(), | ||
}, | ||
messages, | ||
...options, | ||
}); | ||
|
||
return { chatSpy, title$ }; | ||
} | ||
|
||
it('returns the given title as a string', async () => { | ||
const { title$ } = callGenerateTitle([ | ||
createChatCompletionChunk({ | ||
function_call: { | ||
name: 'title_conversation', | ||
arguments: JSON.stringify({ title: 'My title' }), | ||
}, | ||
}), | ||
]); | ||
|
||
const title = await lastValueFrom( | ||
title$.pipe(filter((event): event is string => typeof event === 'string')) | ||
); | ||
|
||
expect(title).toEqual('My title'); | ||
}); | ||
|
||
it('calls chat with the user message', async () => { | ||
const { chatSpy, title$ } = callGenerateTitle([ | ||
createChatCompletionChunk({ | ||
function_call: { | ||
name: 'title_conversation', | ||
arguments: JSON.stringify({ title: 'My title' }), | ||
}, | ||
}), | ||
]); | ||
|
||
await lastValueFrom(title$); | ||
|
||
const [name, params] = chatSpy.mock.calls[0]; | ||
|
||
expect(name).toEqual('generate_title'); | ||
expect(params.messages.length).toBe(2); | ||
expect(params.messages[1].message.content).toContain('A message'); | ||
}); | ||
|
||
it('strips quotes from the title', async () => { | ||
async function testTitle(title: string) { | ||
const { title$ } = callGenerateTitle([ | ||
createChatCompletionChunk({ | ||
function_call: { | ||
name: 'title_conversation', | ||
arguments: JSON.stringify({ title }), | ||
}, | ||
}), | ||
]); | ||
|
||
return await lastValueFrom( | ||
title$.pipe(filter((event): event is string => typeof event === 'string')) | ||
); | ||
} | ||
|
||
expect(await testTitle(`"My title"`)).toEqual('My title'); | ||
expect(await testTitle(`'My title'`)).toEqual('My title'); | ||
expect(await testTitle(`"User's request for a title"`)).toEqual(`User's request for a title`); | ||
}); | ||
|
||
it('mentions the given response language in the instruction', async () => { | ||
const { chatSpy, title$ } = callGenerateTitle( | ||
{ | ||
responseLanguage: 'Orcish', | ||
}, | ||
[ | ||
createChatCompletionChunk({ | ||
function_call: { | ||
name: 'title_conversation', | ||
arguments: JSON.stringify({ title: 'My title' }), | ||
}, | ||
}), | ||
] | ||
); | ||
|
||
await lastValueFrom(title$); | ||
|
||
const [, params] = chatSpy.mock.calls[0]; | ||
expect(params.messages[0].message.content).toContain('Orcish'); | ||
}); | ||
|
||
it('handles partial updates', async () => { | ||
const { title$ } = callGenerateTitle([ | ||
createChatCompletionChunk({ | ||
function_call: { | ||
name: 'title_conversation', | ||
arguments: '', | ||
}, | ||
}), | ||
createChatCompletionChunk({ | ||
function_call: { | ||
name: '', | ||
arguments: JSON.stringify({ title: 'My title' }), | ||
}, | ||
}), | ||
]); | ||
|
||
const title = await lastValueFrom(title$); | ||
|
||
expect(title).toEqual('My title'); | ||
}); | ||
|
||
it('ignores token count events and still passes them through', async () => { | ||
const { title$ } = callGenerateTitle([ | ||
createChatCompletionChunk({ | ||
function_call: { | ||
name: 'title_conversation', | ||
arguments: JSON.stringify({ title: 'My title' }), | ||
}, | ||
}), | ||
{ | ||
type: StreamingChatResponseEventType.TokenCount, | ||
tokens: { | ||
completion: 10, | ||
prompt: 10, | ||
total: 10, | ||
}, | ||
}, | ||
]); | ||
|
||
const events = await lastValueFrom(title$.pipe(toArray())); | ||
|
||
expect(events).toEqual([ | ||
'My title', | ||
{ | ||
tokens: { | ||
completion: 10, | ||
prompt: 10, | ||
total: 10, | ||
}, | ||
type: StreamingChatResponseEventType.TokenCount, | ||
}, | ||
]); | ||
}); | ||
|
||
it('handles errors in chat and falls back to the default title', async () => { | ||
const chatSpy = jest | ||
.fn() | ||
.mockImplementation(() => throwError(() => new Error('Error generating title'))); | ||
|
||
const logger = { | ||
debug: jest.fn(), | ||
error: jest.fn(), | ||
}; | ||
|
||
const title$ = getGeneratedTitle({ | ||
chat: chatSpy, | ||
logger, | ||
messages, | ||
}); | ||
|
||
const title = await lastValueFrom(title$); | ||
|
||
expect(title).toEqual('New conversation'); | ||
|
||
expect(logger.error).toHaveBeenCalledWith('Error generating title'); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters