From 53971bbce96abf5d08180b2edfc37a406e01f533 Mon Sep 17 00:00:00 2001 From: SuZhou-Joe Date: Mon, 20 Nov 2023 10:58:00 +0800 Subject: [PATCH] feat: integrate with memory APIs Signed-off-by: SuZhou-Joe --- common/types/chat_saved_object_attributes.ts | 4 +- server/routes/chat_routes.ts | 31 ++-- server/services/chat/chat_service.ts | 5 +- server/services/chat/olly_chat_service.ts | 101 ++++++------- .../agent_framework_storage_service.ts | 142 ++++++++++++++++++ 5 files changed, 204 insertions(+), 79 deletions(-) create mode 100644 server/services/storage/agent_framework_storage_service.ts diff --git a/common/types/chat_saved_object_attributes.ts b/common/types/chat_saved_object_attributes.ts index 341ca795..5eab4d8e 100644 --- a/common/types/chat_saved_object_attributes.ts +++ b/common/types/chat_saved_object_attributes.ts @@ -19,7 +19,7 @@ export interface ISessionFindResponse { total: number; } -interface IInput { +export interface IInput { type: 'input'; contentType: 'text'; content: string; @@ -27,7 +27,7 @@ interface IInput { appId?: string; }; } -interface IOutput { +export interface IOutput { type: 'output'; traceId?: string; // used for tracing agent calls toolsUsed?: string[]; diff --git a/server/routes/chat_routes.ts b/server/routes/chat_routes.ts index 079181c2..8ec4498f 100644 --- a/server/routes/chat_routes.ts +++ b/server/routes/chat_routes.ts @@ -13,7 +13,7 @@ import { } from '../../../../src/core/server'; import { ASSISTANT_API } from '../../common/constants/llm'; import { OllyChatService } from '../services/chat/olly_chat_service'; -import { SavedObjectsStorageService } from '../services/storage/saved_objects_storage_service'; +import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service'; const llmRequestRoute = { path: ASSISTANT_API.SEND_MESSAGE, @@ -60,7 +60,7 @@ export type GetSessionsSchema = TypeOf; export function registerChatRoutes(router: IRouter) { const createStorageService = (context: RequestHandlerContext) => - new SavedObjectsStorageService(context.core.savedObjects.client); + new AgentFrameworkStorageService(context.core.opensearch.client.asCurrentUser); const createChatService = () => new OllyChatService(); router.post( @@ -70,28 +70,21 @@ export function registerChatRoutes(router: IRouter) { request, response ): Promise> => { - const { sessionId, input, messages = [] } = request.body; + const { messages = [] } = request.body; const storageService = createStorageService(context); const chatService = createChatService(); - // get history from the chat object for existing chats - if (sessionId && messages.length === 0) { - try { - const session = await storageService.getSession(sessionId); - messages.push(...session.messages); - } catch (error) { - return response.custom({ statusCode: error.statusCode || 500, body: error.message }); - } - } - try { const outputs = await chatService.requestLLM(messages, context, request); - const saveMessagesResponse = await storageService.saveMessages( - input.content.substring(0, 50), - sessionId, - [...messages, input, ...outputs] - ); - return response.ok({ body: saveMessagesResponse }); + const sessionId = outputs.memoryId; + const finalMessage = await storageService.getSession(sessionId); + + return response.ok({ + body: { + messages: finalMessage.messages, + sessionId: outputs.memoryId, + }, + }); } catch (error) { context.assistant_plugin.logger.warn(error); return response.custom({ statusCode: error.statusCode || 500, body: error.message }); diff --git a/server/services/chat/chat_service.ts b/server/services/chat/chat_service.ts index cb2da552..c65313ad 100644 --- a/server/services/chat/chat_service.ts +++ b/server/services/chat/chat_service.ts @@ -13,7 +13,10 @@ export interface ChatService { messages: IMessage[], context: RequestHandlerContext, request: OpenSearchDashboardsRequest - ): Promise; + ): Promise<{ + messages: IMessage[]; + memoryId: string; + }>; generatePPL( context: RequestHandlerContext, request: OpenSearchDashboardsRequest diff --git a/server/services/chat/olly_chat_service.ts b/server/services/chat/olly_chat_service.ts index 417e1525..0525e804 100644 --- a/server/services/chat/olly_chat_service.ts +++ b/server/services/chat/olly_chat_service.ts @@ -9,96 +9,83 @@ import { ApiResponse } from '@opensearch-project/opensearch'; import { OpenSearchDashboardsRequest, RequestHandlerContext } from '../../../../../src/core/server'; import { IMessage } from '../../../common/types/chat_saved_object_attributes'; import { convertToTraces } from '../../../common/utils/llm_chat/traces'; -import { chatAgentInit } from '../../olly/agents/agent_helpers'; import { OpenSearchTracer } from '../../olly/callbacks/opensearch_tracer'; -import { requestSuggestionsChain } from '../../olly/chains/suggestions_generator'; -import { memoryInit } from '../../olly/memory/chat_agent_memory'; import { LLMModelFactory } from '../../olly/models/llm_model_factory'; -import { initTools } from '../../olly/tools/tools_helper'; import { PPLTools } from '../../olly/tools/tool_sets/ppl'; import { buildOutputs } from '../../olly/utils/output_builders/build_outputs'; import { LLMRequestSchema } from '../../routes/chat_routes'; import { PPLGenerationRequestSchema } from '../../routes/langchain_routes'; import { ChatService } from './chat_service'; +const MEMORY_ID_FIELD = 'memory_id'; +const RESPONSE_FIELD = 'response'; + export class OllyChatService implements ChatService { public async requestLLM( messages: IMessage[], context: RequestHandlerContext, request: OpenSearchDashboardsRequest - ): Promise { - const { input } = request.body; - const traceId = uuid(); - const observabilityClient = context.assistant_plugin.observabilityClient.asScoped(request); + ): Promise<{ + messages: IMessage[]; + memoryId: string; + }> { + const { input, sessionId } = request.body; const opensearchClient = context.core.opensearch.client.asCurrentUser; - const savedObjectsClient = context.core.savedObjects.client; try { const runs: Run[] = []; - const callbacks = [new OpenSearchTracer(opensearchClient, traceId, runs)]; - const model = LLMModelFactory.createModel({ client: opensearchClient }); - const embeddings = LLMModelFactory.createEmbeddings({ client: opensearchClient }); - const pluginTools = initTools( - model, - embeddings, - opensearchClient, - observabilityClient, - savedObjectsClient, - callbacks - ); - const memory = memoryInit(messages); /** * Wait for an API to fetch root agent id. */ + const parametersPayload: { + question: string; + verbose?: boolean; + memory_id?: string; + } = { + question: input.content, + verbose: true, + }; + if (sessionId) { + parametersPayload.memory_id = sessionId; + } const agentFrameworkResponse = (await opensearchClient.transport.request({ method: 'POST', - path: '/_plugins/_ml/agents/_UoprosBZFp32K9Rsfqe/_execute', + path: '/_plugins/_ml/agents/-jld3IsBXlmiPBu-5dDC/_execute', body: { - parameters: { - question: input.content, - }, + parameters: parametersPayload, }, })) as ApiResponse<{ inference_results: Array<{ - output: Array<{ name: string; result?: string; dataAsMap?: { response: string } }>; + output: Array<{ name: string; result?: string }>; }>; }>; - const outputBody = agentFrameworkResponse.body.inference_results?.[0]?.output?.[0]; - const agentFrameworkAnswer = - outputBody?.dataAsMap?.response || (outputBody?.result as string); + const outputBody = + agentFrameworkResponse.body.inference_results?.[0]?.output || + agentFrameworkResponse.body.inference_results?.[0]?.output; + const memoryIdItem = outputBody?.find((item) => item.name === MEMORY_ID_FIELD); + const reversedOutputBody = [...outputBody].reverse(); + const finalAnswerItem = reversedOutputBody.find((item) => item.name === RESPONSE_FIELD); - /** - * Append history manually as suggestion requires latest history. - * Please delete following line after memory API is ready - */ - await memory.chatHistory.addUserMessage(input.content); - await memory.chatHistory.addAIChatMessage(agentFrameworkAnswer); - - const suggestions = await requestSuggestionsChain( - model, - pluginTools.flatMap((tool) => tool.toolsList), - memory, - callbacks - ); + const agentFrameworkAnswer = finalAnswerItem?.result || ''; - return buildOutputs( - input.content, - agentFrameworkAnswer, - traceId, - suggestions, - convertToTraces(runs) - ); + return { + messages: buildOutputs(input.content, agentFrameworkAnswer, '', {}, convertToTraces(runs)), + memoryId: memoryIdItem?.result || '', + }; } catch (error) { context.assistant_plugin.logger.error(error); - return [ - { - type: 'output', - traceId, - contentType: 'error', - content: error.message, - }, - ]; + return { + messages: [ + { + type: 'output', + traceId: '', + contentType: 'error', + content: error.message, + }, + ], + memoryId: '', + }; } } diff --git a/server/services/storage/agent_framework_storage_service.ts b/server/services/storage/agent_framework_storage_service.ts new file mode 100644 index 00000000..ea80197c --- /dev/null +++ b/server/services/storage/agent_framework_storage_service.ts @@ -0,0 +1,142 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ApiResponse } from '@opensearch-project/opensearch/.'; +import { OpenSearchClient } from '../../../../../src/core/server'; +import { LLM_INDEX } from '../../../common/constants/llm'; +import { + IInput, + IMessage, + IOutput, + ISession, + ISessionFindResponse, +} from '../../../common/types/chat_saved_object_attributes'; +import { GetSessionsSchema } from '../../routes/chat_routes'; +import { StorageService } from './storage_service'; + +export class AgentFrameworkStorageService implements StorageService { + constructor(private readonly client: OpenSearchClient) {} + async getSession(sessionId: string): Promise { + const session = (await this.client.transport.request({ + method: 'GET', + path: `/_plugins/_ml/memory/conversation/${sessionId}`, + })) as ApiResponse<{ + interactions: Array<{ + input: string; + response: string; + parent_interaction_id: string; + interaction_id: string; + }>; + }>; + return { + title: 'test', + version: 1, + createdTimeMs: Date.now(), + updatedTimeMs: Date.now(), + messages: session.body.interactions + .filter((item) => !item.parent_interaction_id) + .reduce((total, current) => { + const inputItem: IInput = { + type: 'input', + contentType: 'text', + content: current.input, + }; + const outputItems: IOutput[] = [ + { + type: 'output', + contentType: 'markdown', + content: current.response, + traceId: current.interaction_id, + }, + ]; + return [...total, inputItem, ...outputItems]; + }, [] as IMessage[]), + }; + } + + async getSessions(query: GetSessionsSchema): Promise { + await this.createIndex(); + const sessions = await this.client.search({ + index: LLM_INDEX.SESSIONS, + body: { + from: (query.page - 1) * query.perPage, + size: query.perPage, + ...(query.sortField && + query.sortOrder && { sort: [{ [query.sortField]: query.sortOrder }] }), + }, + }); + + return { + objects: sessions.body.hits.hits + .filter( + (hit): hit is RequiredKey => + hit._source !== null && hit._source !== undefined + ) + .map((session) => ({ ...session._source, id: session._id })), + total: + typeof sessions.body.hits.total === 'number' + ? sessions.body.hits.total + : sessions.body.hits.total.value, + }; + } + + async saveMessages( + title: string, + sessionId: string | undefined, + messages: IMessage[] + ): Promise<{ sessionId: string; messages: IMessage[] }> { + await this.createIndex(); + const timestamp = new Date().getTime(); + if (!sessionId) { + const createResponse = await this.client.index({ + index: LLM_INDEX.SESSIONS, + body: { + title, + version: 1, + createdTimeMs: timestamp, + updatedTimeMs: timestamp, + messages, + }, + }); + return { sessionId: createResponse.body._id, messages }; + } + const updateResponse = await this.client.update>({ + index: LLM_INDEX.SESSIONS, + id: sessionId, + body: { + doc: { + messages, + updatedTimeMs: timestamp, + }, + }, + }); + return { sessionId, messages }; + } + + private async createIndex() { + const existsResponse = await this.client.indices.exists({ index: LLM_INDEX.SESSIONS }); + if (!existsResponse.body) { + return this.client.indices.create({ + index: LLM_INDEX.SESSIONS, + body: { + settings: { + index: { + number_of_shards: '1', + auto_expand_replicas: '0-2', + mapping: { ignore_malformed: true }, + }, + }, + mappings: { + properties: { + title: { type: 'keyword' }, + createdTimeMs: { type: 'date' }, + updatedTimeMs: { type: 'date' }, + }, + }, + }, + }); + } + } +}