
feat: integrate with memory APIs
Signed-off-by: SuZhou-Joe <[email protected]>
SuZhou-Joe committed Nov 20, 2023
1 parent 52753f4 commit 53971bb
Showing 5 changed files with 204 additions and 79 deletions.
4 changes: 2 additions & 2 deletions common/types/chat_saved_object_attributes.ts
@@ -19,15 +19,15 @@ export interface ISessionFindResponse {
total: number;
}

interface IInput {
export interface IInput {
type: 'input';
contentType: 'text';
content: string;
context?: {
appId?: string;
};
}
interface IOutput {
export interface IOutput {
type: 'output';
traceId?: string; // used for tracing agent calls
toolsUsed?: string[];
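
These two interfaces were previously module-private; exporting them lets the new storage service construct typed message pairs. A minimal sketch of a consumer, assuming the shapes shown in this diff (the helper itself is hypothetical, not part of the commit):

```ts
import { IInput, IOutput } from './chat_saved_object_attributes';

// Hypothetical helper, not part of this commit: builds a typed question/answer pair.
const buildMessagePair = (question: string, answer: string): [IInput, IOutput] => {
  const input: IInput = { type: 'input', contentType: 'text', content: question };
  const output: IOutput = { type: 'output', contentType: 'markdown', content: answer };
  return [input, output];
};
```
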
31 changes: 12 additions & 19 deletions server/routes/chat_routes.ts
@@ -13,7 +13,7 @@ import {
} from '../../../../src/core/server';
import { ASSISTANT_API } from '../../common/constants/llm';
import { OllyChatService } from '../services/chat/olly_chat_service';
import { SavedObjectsStorageService } from '../services/storage/saved_objects_storage_service';
import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service';

const llmRequestRoute = {
path: ASSISTANT_API.SEND_MESSAGE,
@@ -60,7 +60,7 @@ export type GetSessionsSchema = TypeOf<typeof getSessionsRoute.validate.query>;

export function registerChatRoutes(router: IRouter) {
const createStorageService = (context: RequestHandlerContext) =>
new SavedObjectsStorageService(context.core.savedObjects.client);
new AgentFrameworkStorageService(context.core.opensearch.client.asCurrentUser);
const createChatService = () => new OllyChatService();

router.post(
@@ -70,28 +70,21 @@ export function registerChatRoutes(router: IRouter) {
request,
response
): Promise<IOpenSearchDashboardsResponse<HttpResponsePayload | ResponseError>> => {
const { sessionId, input, messages = [] } = request.body;
const { messages = [] } = request.body;
const storageService = createStorageService(context);
const chatService = createChatService();

// get history from the chat object for existing chats
if (sessionId && messages.length === 0) {
try {
const session = await storageService.getSession(sessionId);
messages.push(...session.messages);
} catch (error) {
return response.custom({ statusCode: error.statusCode || 500, body: error.message });
}
}

try {
const outputs = await chatService.requestLLM(messages, context, request);
const saveMessagesResponse = await storageService.saveMessages(
input.content.substring(0, 50),
sessionId,
[...messages, input, ...outputs]
);
return response.ok({ body: saveMessagesResponse });
const sessionId = outputs.memoryId;
const finalMessage = await storageService.getSession(sessionId);

return response.ok({
body: {
messages: finalMessage.messages,
sessionId: outputs.memoryId,
},
});
} catch (error) {
context.assistant_plugin.logger.warn(error);
return response.custom({ statusCode: error.statusCode || 500, body: error.message });
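
After this change the handler no longer replays history itself: it forwards the message, takes the `memoryId` from the agent framework response, and re-reads the whole session from storage before responding. A rough sketch of what a client now receives (the fetch URL is a placeholder; the real path comes from `ASSISTANT_API.SEND_MESSAGE`):

```ts
import { IMessage } from '../../common/types/chat_saved_object_attributes';

// Shape of the route's success payload, as returned by response.ok above.
interface SendMessageResponse {
  messages: IMessage[]; // full conversation as reconstructed from the memory API
  sessionId: string; // the agent framework memory id
}

// Placeholder URL; the actual path is defined by ASSISTANT_API.SEND_MESSAGE.
const res = await fetch('/api/assistant/send_message', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    input: { type: 'input', contentType: 'text', content: 'What is my cluster health?' },
  }),
});
const payload: SendMessageResponse = await res.json();
```
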
5 changes: 4 additions & 1 deletion server/services/chat/chat_service.ts
@@ -13,7 +13,10 @@ export interface ChatService {
messages: IMessage[],
context: RequestHandlerContext,
request: OpenSearchDashboardsRequest<unknown, unknown, LLMRequestSchema, 'post'>
): Promise<IMessage[]>;
): Promise<{
messages: IMessage[];
memoryId: string;
}>;
generatePPL(
context: RequestHandlerContext,
request: OpenSearchDashboardsRequest<unknown, unknown, PPLGenerationRequestSchema, 'post'>
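
Every `ChatService` implementation now has to return the memory id alongside the generated messages. A minimal conforming stub (the `context` and `request` parameters are omitted for brevity; this class is an assumption for illustration, not part of the commit):

```ts
import { IMessage } from '../../../common/types/chat_saved_object_attributes';

// Hypothetical stub showing the widened return type; a real service calls the LLM.
class EchoChatService {
  public async requestLLM(
    messages: IMessage[]
  ): Promise<{ messages: IMessage[]; memoryId: string }> {
    const reply: IMessage = {
      type: 'output',
      contentType: 'markdown',
      content: 'echo',
    };
    // No backing memory here, so the memory id is left empty.
    return { messages: [...messages, reply], memoryId: '' };
  }
}
```
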
101 changes: 44 additions & 57 deletions server/services/chat/olly_chat_service.ts
@@ -9,96 +9,83 @@ import { ApiResponse } from '@opensearch-project/opensearch';
import { OpenSearchDashboardsRequest, RequestHandlerContext } from '../../../../../src/core/server';
import { IMessage } from '../../../common/types/chat_saved_object_attributes';
import { convertToTraces } from '../../../common/utils/llm_chat/traces';
import { chatAgentInit } from '../../olly/agents/agent_helpers';
import { OpenSearchTracer } from '../../olly/callbacks/opensearch_tracer';
import { requestSuggestionsChain } from '../../olly/chains/suggestions_generator';
import { memoryInit } from '../../olly/memory/chat_agent_memory';
import { LLMModelFactory } from '../../olly/models/llm_model_factory';
import { initTools } from '../../olly/tools/tools_helper';
import { PPLTools } from '../../olly/tools/tool_sets/ppl';
import { buildOutputs } from '../../olly/utils/output_builders/build_outputs';
import { LLMRequestSchema } from '../../routes/chat_routes';
import { PPLGenerationRequestSchema } from '../../routes/langchain_routes';
import { ChatService } from './chat_service';

const MEMORY_ID_FIELD = 'memory_id';
const RESPONSE_FIELD = 'response';

export class OllyChatService implements ChatService {
public async requestLLM(
messages: IMessage[],
context: RequestHandlerContext,
request: OpenSearchDashboardsRequest<unknown, unknown, LLMRequestSchema, 'post'>
): Promise<IMessage[]> {
const { input } = request.body;
const traceId = uuid();
const observabilityClient = context.assistant_plugin.observabilityClient.asScoped(request);
): Promise<{
messages: IMessage[];
memoryId: string;
}> {
const { input, sessionId } = request.body;
const opensearchClient = context.core.opensearch.client.asCurrentUser;
const savedObjectsClient = context.core.savedObjects.client;

try {
const runs: Run[] = [];
const callbacks = [new OpenSearchTracer(opensearchClient, traceId, runs)];
const model = LLMModelFactory.createModel({ client: opensearchClient });
const embeddings = LLMModelFactory.createEmbeddings({ client: opensearchClient });
const pluginTools = initTools(
model,
embeddings,
opensearchClient,
observabilityClient,
savedObjectsClient,
callbacks
);
const memory = memoryInit(messages);

/**
* Wait for an API to fetch root agent id.
*/
const parametersPayload: {
question: string;
verbose?: boolean;
memory_id?: string;
} = {
question: input.content,
verbose: true,
};
if (sessionId) {
parametersPayload.memory_id = sessionId;
}
const agentFrameworkResponse = (await opensearchClient.transport.request({
method: 'POST',
path: '/_plugins/_ml/agents/_UoprosBZFp32K9Rsfqe/_execute',
path: '/_plugins/_ml/agents/-jld3IsBXlmiPBu-5dDC/_execute',
body: {
parameters: {
question: input.content,
},
parameters: parametersPayload,
},
})) as ApiResponse<{
inference_results: Array<{
output: Array<{ name: string; result?: string; dataAsMap?: { response: string } }>;
output: Array<{ name: string; result?: string }>;
}>;
}>;
const outputBody = agentFrameworkResponse.body.inference_results?.[0]?.output?.[0];
const agentFrameworkAnswer =
outputBody?.dataAsMap?.response || (outputBody?.result as string);
const outputBody = agentFrameworkResponse.body.inference_results?.[0]?.output;
const memoryIdItem = outputBody?.find((item) => item.name === MEMORY_ID_FIELD);
const reversedOutputBody = [...outputBody].reverse();
const finalAnswerItem = reversedOutputBody.find((item) => item.name === RESPONSE_FIELD);

/**
* Append history manually as suggestion requires latest history.
* Please delete following line after memory API is ready
*/
await memory.chatHistory.addUserMessage(input.content);
await memory.chatHistory.addAIChatMessage(agentFrameworkAnswer);

const suggestions = await requestSuggestionsChain(
model,
pluginTools.flatMap((tool) => tool.toolsList),
memory,
callbacks
);
const agentFrameworkAnswer = finalAnswerItem?.result || '';

return buildOutputs(
input.content,
agentFrameworkAnswer,
traceId,
suggestions,
convertToTraces(runs)
);
return {
messages: buildOutputs(input.content, agentFrameworkAnswer, '', {}, convertToTraces(runs)),
memoryId: memoryIdItem?.result || '',
};
} catch (error) {
context.assistant_plugin.logger.error(error);
return [
{
type: 'output',
traceId,
contentType: 'error',
content: error.message,
},
];
return {
messages: [
{
type: 'output',
traceId: '',
contentType: 'error',
content: error.message,
},
],
memoryId: '',
};
}
}

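
The agent framework packs everything into a flat `output` array in `inference_results`: a `memory_id` entry plus one or more `response` entries, the last of which is the final answer. The extraction above can be read in isolation as roughly the following (field names taken from the diff):

```ts
interface AgentOutputItem {
  name: string;
  result?: string;
}

const MEMORY_ID_FIELD = 'memory_id';
const RESPONSE_FIELD = 'response';

// Mirrors the extraction logic in requestLLM above.
function extractAnswer(output: AgentOutputItem[]): { memoryId: string; answer: string } {
  const memoryIdItem = output.find((item) => item.name === MEMORY_ID_FIELD);
  // Scan from the end so intermediate responses are skipped in favor of the final one.
  const finalAnswerItem = [...output].reverse().find((item) => item.name === RESPONSE_FIELD);
  return { memoryId: memoryIdItem?.result || '', answer: finalAnswerItem?.result || '' };
}

// Example: a payload with a memory id and two response entries keeps the last one.
extractAnswer([
  { name: 'memory_id', result: 'abc123' },
  { name: 'response', result: 'intermediate tool output' },
  { name: 'response', result: 'final answer' },
]); // => { memoryId: 'abc123', answer: 'final answer' }
```
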
142 changes: 142 additions & 0 deletions server/services/storage/agent_framework_storage_service.ts
@@ -0,0 +1,142 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { ApiResponse } from '@opensearch-project/opensearch';
import { OpenSearchClient } from '../../../../../src/core/server';
import { LLM_INDEX } from '../../../common/constants/llm';
import {
IInput,
IMessage,
IOutput,
ISession,
ISessionFindResponse,
} from '../../../common/types/chat_saved_object_attributes';
import { GetSessionsSchema } from '../../routes/chat_routes';
import { StorageService } from './storage_service';

export class AgentFrameworkStorageService implements StorageService {
constructor(private readonly client: OpenSearchClient) {}
async getSession(sessionId: string): Promise<ISession> {
const session = (await this.client.transport.request({
method: 'GET',
path: `/_plugins/_ml/memory/conversation/${sessionId}`,
})) as ApiResponse<{
interactions: Array<{
input: string;
response: string;
parent_interaction_id: string;
interaction_id: string;
}>;
}>;
return {
title: 'test',
version: 1,
createdTimeMs: Date.now(),
updatedTimeMs: Date.now(),
messages: session.body.interactions
.filter((item) => !item.parent_interaction_id)
.reduce((total, current) => {
const inputItem: IInput = {
type: 'input',
contentType: 'text',
content: current.input,
};
const outputItems: IOutput[] = [
{
type: 'output',
contentType: 'markdown',
content: current.response,
traceId: current.interaction_id,
},
];
return [...total, inputItem, ...outputItems];
}, [] as IMessage[]),
};
}

async getSessions(query: GetSessionsSchema): Promise<ISessionFindResponse> {
await this.createIndex();
const sessions = await this.client.search<ISession>({
index: LLM_INDEX.SESSIONS,
body: {
from: (query.page - 1) * query.perPage,
size: query.perPage,
...(query.sortField &&
query.sortOrder && { sort: [{ [query.sortField]: query.sortOrder }] }),
},
});

return {
objects: sessions.body.hits.hits
.filter(
(hit): hit is RequiredKey<typeof hit, '_source'> =>
hit._source !== null && hit._source !== undefined
)
.map((session) => ({ ...session._source, id: session._id })),
total:
typeof sessions.body.hits.total === 'number'
? sessions.body.hits.total
: sessions.body.hits.total.value,
};
}

async saveMessages(
title: string,
sessionId: string | undefined,
messages: IMessage[]
): Promise<{ sessionId: string; messages: IMessage[] }> {
await this.createIndex();
const timestamp = new Date().getTime();
if (!sessionId) {
const createResponse = await this.client.index<ISession>({
index: LLM_INDEX.SESSIONS,
body: {
title,
version: 1,
createdTimeMs: timestamp,
updatedTimeMs: timestamp,
messages,
},
});
return { sessionId: createResponse.body._id, messages };
}
const updateResponse = await this.client.update<Partial<ISession>>({
index: LLM_INDEX.SESSIONS,
id: sessionId,
body: {
doc: {
messages,
updatedTimeMs: timestamp,
},
},
});
return { sessionId, messages };
}

private async createIndex() {
const existsResponse = await this.client.indices.exists({ index: LLM_INDEX.SESSIONS });
if (!existsResponse.body) {
return this.client.indices.create({
index: LLM_INDEX.SESSIONS,
body: {
settings: {
index: {
number_of_shards: '1',
auto_expand_replicas: '0-2',
mapping: { ignore_malformed: true },
},
},
mappings: {
properties: {
title: { type: 'keyword' },
createdTimeMs: { type: 'date' },
updatedTimeMs: { type: 'date' },
},
},
},
});
}
}
}
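
The most interesting part of the new service is `getSession`, which flattens the memory API's interactions into the plugin's alternating input/output message list, keeping only root interactions (those without a `parent_interaction_id`). A standalone sketch of that transform with sample data, assuming the interaction shape in the `ApiResponse` typing above (written with `flatMap` for brevity; the committed code uses an equivalent `reduce`):

```ts
interface Interaction {
  input: string;
  response: string;
  parent_interaction_id: string;
  interaction_id: string;
}

// Root interactions only; each becomes an input/output message pair.
function toMessages(interactions: Interaction[]) {
  return interactions
    .filter((item) => !item.parent_interaction_id)
    .flatMap((current) => [
      { type: 'input' as const, contentType: 'text' as const, content: current.input },
      {
        type: 'output' as const,
        contentType: 'markdown' as const,
        content: current.response,
        traceId: current.interaction_id,
      },
    ]);
}

toMessages([
  { input: 'hi', response: 'hello!', parent_interaction_id: '', interaction_id: 't1' },
]);
// => [{ type: 'input', content: 'hi', ... }, { type: 'output', content: 'hello!', traceId: 't1', ... }]
```
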
