Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: integrate regenerate API #58

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions public/hooks/use_chat_actions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,19 @@ export const useChatActions = (): AssistantActions => {
}
};

const regenerate = async () => {
const regenerate = async (interactionId: string) => {
if (chatContext.sessionId) {
const abortController = new AbortController();
abortControllerRef = abortController;
chatStateDispatch({ type: 'regenerate' });

try {
const response = await core.services.http.put(`${ASSISTANT_API.REGENERATE}`, {
body: JSON.stringify({ sessionId: chatContext.sessionId }),
body: JSON.stringify({
sessionId: chatContext.sessionId,
rootAgentId: chatContext.rootAgentId,
interactionId,
}),
});

if (abortController.signal.aborted) {
Expand Down
7 changes: 7 additions & 0 deletions public/tabs/chat/messages/message_bubble.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,13 @@ describe('<MessageBubble />', () => {
contentType: 'markdown',
content: 'here are the indices in your cluster: .alert',
}}
interaction={{
input: 'foo',
response: 'bar',
conversation_id: 'foo',
interaction_id: 'bar',
create_time: new Date().toLocaleString(),
}}
/>
);
expect(screen.queryAllByTitle('regenerate message')).toHaveLength(1);
Expand Down
8 changes: 4 additions & 4 deletions public/tabs/chat/messages/message_bubble.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type MessageBubbleProps = {
showActionBar: boolean;
showRegenerate?: boolean;
shouldActionBarVisibleOnHover?: boolean;
onRegenerate?: () => void;
onRegenerate?: (interactionId: string) => void;
} & (
| {
message: IMessage;
Expand Down Expand Up @@ -192,17 +192,17 @@ export const MessageBubble: React.FC<MessageBubbleProps> = React.memo((props) =>
</EuiCopy>
</EuiFlexItem>
)}
{props.showRegenerate && (
{props.showRegenerate && props.interaction?.interaction_id ? (
<EuiFlexItem grow={false}>
<EuiButtonIcon
aria-label="regenerate message"
onClick={props.onRegenerate}
onClick={() => props.onRegenerate?.(props.interaction?.interaction_id || '')}
title="regenerate message"
color="text"
iconType="refresh"
/>
</EuiFlexItem>
)}
) : null}
{showFeedback && (
// After feedback, only corresponding thumb icon will be kept and disabled.
<>
Expand Down
2 changes: 1 addition & 1 deletion public/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export interface AssistantActions {
openChatUI: (sessionId?: string) => void;
executeAction: (suggestedAction: ISuggestedAction, message: IMessage) => void;
abortAction: (sessionId?: string) => void;
regenerate: () => void;
regenerate: (interactionId: string) => void;
}

export interface AppPluginStartDependencies {
Expand Down
49 changes: 23 additions & 26 deletions server/routes/chat_routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import {
} from '../../../../src/core/server';
import { ASSISTANT_API } from '../../common/constants/llm';
import { OllyChatService } from '../services/chat/olly_chat_service';
import { IMessage, IInput } from '../../common/types/chat_saved_object_attributes';
import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service';
import { RoutesOptions } from '../types';
import { ChatService } from '../services/chat/chat_service';
Expand Down Expand Up @@ -64,6 +63,7 @@ const regenerateRoute = {
body: schema.object({
sessionId: schema.string(),
rootAgentId: schema.string(),
interactionId: schema.string(),
}),
},
};
Expand Down Expand Up @@ -314,42 +314,39 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise<IOpenSearchDashboardsResponse<HttpResponsePayload | ResponseError>> => {
const { sessionId, rootAgentId } = request.body;
const { sessionId: sessionIdInRequestBody, rootAgentId, interactionId } = request.body;
const storageService = createStorageService(context);
let messages: IMessage[] = [];
const chatService = createChatService();

let outputs: Awaited<ReturnType<ChatService['regenerate']>> | undefined;

/**
* Get final answer from Agent framework
*/
try {
const session = await storageService.getSession(sessionId);
messages.push(...session.messages);
outputs = await chatService.regenerate(
{ sessionId: sessionIdInRequestBody, rootAgentId, interactionId },
context
);
} catch (error) {
return response.custom({ statusCode: error.statusCode || 500, body: error.message });
context.assistant_plugin.logger.error(error);
}

const lastInputIndex = messages.findLastIndex((msg) => msg.type === 'input');
// Find last input message
const input = messages[lastInputIndex] as IInput;
// Take the messages before last input message as memory as regenerate will exclude the last outputs
messages = messages.slice(0, lastInputIndex);

/**
* Retrieve latest interactions from memory
*/
const sessionId = sessionIdInRequestBody;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we rename sessionIdInRequestBody to sessionId?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, done for that.

try {
const outputs = await chatService.requestLLM(
{ messages, input, sessionId, rootAgentId },
context
);
const title = input.content.substring(0, 50);
const saveMessagesResponse = await storageService.saveMessages(
title,
sessionId,
[...messages, input, ...outputs.messages].filter(
(message) => message.content !== 'AbortError'
)
);
const conversation = await storageService.getSession(sessionId);

return response.ok({
body: { ...saveMessagesResponse, title },
body: {
...conversation,
sessionId,
},
});
} catch (error) {
context.assistant_plugin.logger.warn(error);
context.assistant_plugin.logger.error(error);
return response.custom({ statusCode: error.statusCode || 500, body: error.message });
}
}
Expand Down
11 changes: 9 additions & 2 deletions server/services/chat/chat_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@ import { LLMRequestSchema } from '../../routes/chat_routes';
export interface ChatService {
requestLLM(
payload: { messages: IMessage[]; input: IInput; sessionId?: string },
context: RequestHandlerContext,
request: OpenSearchDashboardsRequest<unknown, unknown, LLMRequestSchema, 'post'>
context: RequestHandlerContext
): Promise<{
messages: IMessage[];
memoryId: string;
}>;

regenerate(
payload: { sessionId: string; interactionId: string; rootAgentId: string },
context: RequestHandlerContext
): Promise<{
messages: IMessage[];
memoryId: string;
Expand Down
176 changes: 176 additions & 0 deletions server/services/chat/olly_chat_service.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { OllyChatService } from './olly_chat_service';
import { CoreRouteHandlerContext } from '../../../../../src/core/server/core_route_handler_context';
import { coreMock, httpServerMock } from '../../../../../src/core/server/mocks';
import { loggerMock } from '../../../../../src/core/server/logging/logger.mock';

describe('OllyChatService', () => {
const ollyChatService = new OllyChatService();
const coreContext = new CoreRouteHandlerContext(
coreMock.createInternalStart(),
httpServerMock.createOpenSearchDashboardsRequest()
);
const mockedTransport = coreContext.opensearch.client.asCurrentUser.transport
.request as jest.Mock;
const contextMock = {
core: coreContext,
assistant_plugin: {
logger: loggerMock.create(),
},
};
beforeEach(() => {
mockedTransport.mockClear();
});
it('requestLLM should invoke client call with correct params', async () => {
mockedTransport.mockImplementationOnce(() => {
return {
body: {
inference_results: [
{
output: [
{
name: 'memory_id',
result: 'foo',
},
],
},
],
},
};
});
const result = await ollyChatService.requestLLM(
{
messages: [],
input: {
type: 'input',
contentType: 'text',
content: 'content',
},
sessionId: '',
rootAgentId: 'rootAgentId',
},
contextMock
);
expect(mockedTransport.mock.calls).toMatchInlineSnapshot(`
Array [
Array [
Object {
"body": Object {
"parameters": Object {
"question": "content",
"verbose": true,
},
},
"method": "POST",
"path": "/_plugins/_ml/agents/rootAgentId/_execute",
},
Object {
"maxRetries": 0,
"requestTimeout": 300000,
},
],
]
`);
expect(result).toMatchInlineSnapshot(`
Object {
"memoryId": "foo",
"messages": Array [],
}
`);
});

it('requestLLM should throw error when transport.request throws error', async () => {
mockedTransport.mockImplementationOnce(() => {
throw new Error('error');
});
expect(
ollyChatService.requestLLM(
{
messages: [],
input: {
type: 'input',
contentType: 'text',
content: 'content',
},
sessionId: '',
rootAgentId: 'rootAgentId',
},
contextMock
)
).rejects.toMatchInlineSnapshot(`[Error: error]`);
});

it('regenerate should invoke client call with correct params', async () => {
mockedTransport.mockImplementationOnce(() => {
return {
body: {
inference_results: [
{
output: [
{
name: 'memory_id',
result: 'foo',
},
],
},
],
},
};
});
const result = await ollyChatService.regenerate(
{
sessionId: 'sessionId',
rootAgentId: 'rootAgentId',
interactionId: 'interactionId',
},
contextMock
);
expect(mockedTransport.mock.calls).toMatchInlineSnapshot(`
Array [
Array [
Object {
"body": Object {
"parameters": Object {
"memory_id": "sessionId",
"regenerate_interaction_id": "interactionId",
"verbose": true,
},
},
"method": "POST",
"path": "/_plugins/_ml/agents/rootAgentId/_execute",
},
Object {
"maxRetries": 0,
"requestTimeout": 300000,
},
],
]
`);
expect(result).toMatchInlineSnapshot(`
Object {
"memoryId": "foo",
"messages": Array [],
}
`);
});

it('regenerate should throw error when transport.request throws error', async () => {
mockedTransport.mockImplementationOnce(() => {
throw new Error('error');
});
expect(
ollyChatService.regenerate(
{
sessionId: 'sessionId',
rootAgentId: 'rootAgentId',
interactionId: 'interactionId',
},
contextMock
)
).rejects.toMatchInlineSnapshot(`[Error: error]`);
});
});
Loading
Loading