Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use agent framework API to generate answer #2

Merged
merged 6 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions babel.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ module.exports = function (api) {
],
plugins: [
[require('@babel/plugin-transform-runtime'), { regenerator: true }],
require('@babel/plugin-proposal-class-properties'),
require('@babel/plugin-proposal-object-rest-spread'),
require('@babel/plugin-transform-class-properties'),
require('@babel/plugin-transform-object-rest-spread'),
[require('@babel/plugin-transform-modules-commonjs'), { allowTopLevelThis: true }],
],
};
Expand Down
4 changes: 4 additions & 0 deletions public/chat_header_button.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ export const HeaderChatButton: React.FC<HeaderChatButtonProps> = (props) => {
const [traceId, setTraceId] = useState<string | undefined>(undefined);
const [chatSize, setChatSize] = useState<number | 'fullscreen' | 'dock-right'>('dock-right');
const flyoutFullScreen = chatSize === 'fullscreen';
const [rootAgentId, setRootAgentId] = useState<string>(
new URL(window.location.href).searchParams.get('agent_id') || ''
);

if (!flyoutLoaded && flyoutVisible) flyoutLoaded = true;

Expand Down Expand Up @@ -73,6 +76,7 @@ export const HeaderChatButton: React.FC<HeaderChatButtonProps> = (props) => {
setTitle,
traceId,
setTraceId,
rootAgentId,
}),
[
appId,
Expand Down
1 change: 1 addition & 0 deletions public/contexts/chat_context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ export interface IChatContext {
setTitle: React.Dispatch<React.SetStateAction<string | undefined>>;
traceId?: string;
setTraceId: React.Dispatch<React.SetStateAction<string | undefined>>;
rootAgentId?: string;
}
export const ChatContext = React.createContext<IChatContext | null>(null);

Expand Down
1 change: 1 addition & 0 deletions public/hooks/use_chat_actions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export const useChatActions = (): AssistantActions => {
// do not send abort signal to http client to allow LLM call run in background
body: JSON.stringify({
sessionId: chatContext.sessionId,
rootAgentId: chatContext.rootAgentId,
...(!chatContext.sessionId && { messages: chatState.messages }), // include all previous messages for new chats
input,
}),
Expand Down
37 changes: 15 additions & 22 deletions server/routes/chat_routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ import { ASSISTANT_API } from '../../common/constants/llm';
import { OllyChatService } from '../services/chat/olly_chat_service';
import { SavedObjectsStorageService } from '../services/storage/saved_objects_storage_service';
import { IMessage, IInput } from '../../common/types/chat_saved_object_attributes';
import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service';

const llmRequestRoute = {
path: ASSISTANT_API.SEND_MESSAGE,
validate: {
body: schema.object({
sessionId: schema.maybe(schema.string()),
messages: schema.maybe(schema.arrayOf(schema.any())),
rootAgentId: schema.string(),
input: schema.object({
type: schema.literal('input'),
context: schema.object({
Expand Down Expand Up @@ -104,7 +106,7 @@ const updateSessionRoute = {

export function registerChatRoutes(router: IRouter) {
const createStorageService = (context: RequestHandlerContext) =>
new SavedObjectsStorageService(context.core.savedObjects.client);
new AgentFrameworkStorageService(context.core.opensearch.client.asCurrentUser);
const createChatService = () => new OllyChatService();

router.post(
Expand All @@ -114,34 +116,25 @@ export function registerChatRoutes(router: IRouter) {
request,
response
): Promise<IOpenSearchDashboardsResponse<HttpResponsePayload | ResponseError>> => {
const { sessionId, input, messages = [] } = request.body;
const { messages = [], input, sessionId: sessionIdInRequestBody } = request.body;
const storageService = createStorageService(context);
const chatService = createChatService();

// get history from the chat object for existing chats
if (sessionId && messages.length === 0) {
try {
const session = await storageService.getSession(sessionId);
messages.push(...session.messages);
} catch (error) {
return response.custom({ statusCode: error.statusCode || 500, body: error.message });
}
}

try {
const outputs = await chatService.requestLLM(
{ messages, input, sessionId },
{ messages, input, sessionId: sessionIdInRequestBody },
context,
request
);
const title = input.content.substring(0, 50);
const saveMessagesResponse = await storageService.saveMessages(
title,
sessionId,
[...messages, input, ...outputs].filter((message) => message.content !== 'AbortError')
);
const sessionId = outputs.memoryId;
const finalMessage = await storageService.getSession(sessionId);

return response.ok({
body: { ...saveMessagesResponse, title },
body: {
messages: finalMessage.messages,
sessionId: outputs.memoryId,
title: finalMessage.title
},
});
} catch (error) {
context.assistant_plugin.logger.warn(error);
Expand Down Expand Up @@ -278,13 +271,13 @@ export function registerChatRoutes(router: IRouter) {
const outputs = await chatService.requestLLM(
{ messages, input, sessionId },
context,
request
request as any
);
const title = input.content.substring(0, 50);
const saveMessagesResponse = await storageService.saveMessages(
title,
sessionId,
[...messages, input, ...outputs].filter((message) => message.content !== 'AbortError')
[...messages, input, ...outputs.messages].filter((message) => message.content !== 'AbortError')
);
return response.ok({
body: { ...saveMessagesResponse, title },
Expand Down
5 changes: 4 additions & 1 deletion server/services/chat/chat_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ export interface ChatService {
payload: { messages: IMessage[]; input: IInput; sessionId?: string },
context: RequestHandlerContext,
request: OpenSearchDashboardsRequest<unknown, unknown, LLMRequestSchema, 'post'>
): Promise<IMessage[]>;
): Promise<{
messages: IMessage[];
memoryId: string;
}>;
generatePPL(
context: RequestHandlerContext,
request: OpenSearchDashboardsRequest<unknown, unknown, PPLGenerationRequestSchema, 'post'>
Expand Down
110 changes: 58 additions & 52 deletions server/services/chat/olly_chat_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,87 +5,93 @@

import { Run } from 'langchain/callbacks';
import { v4 as uuid } from 'uuid';
import { ApiResponse } from '@opensearch-project/opensearch';
import { OpenSearchDashboardsRequest, RequestHandlerContext } from '../../../../../src/core/server';
import { IMessage, IInput } from '../../../common/types/chat_saved_object_attributes';
import { convertToTraces } from '../../../common/utils/llm_chat/traces';
import { chatAgentInit } from '../../olly/agents/agent_helpers';
import { OpenSearchTracer } from '../../olly/callbacks/opensearch_tracer';
import { requestSuggestionsChain } from '../../olly/chains/suggestions_generator';
import { memoryInit } from '../../olly/memory/chat_agent_memory';
import { LLMModelFactory } from '../../olly/models/llm_model_factory';
import { initTools } from '../../olly/tools/tools_helper';
import { PPLTools } from '../../olly/tools/tool_sets/ppl';
import { buildOutputs } from '../../olly/utils/output_builders/build_outputs';
import { AbortAgentExecutionSchema, LLMRequestSchema } from '../../routes/chat_routes';
import { PPLGenerationRequestSchema } from '../../routes/langchain_routes';
import { ChatService } from './chat_service';

const MEMORY_ID_FIELD = 'memory_id';
const RESPONSE_FIELD = 'response';

export class OllyChatService implements ChatService {
static abortControllers: Map<string, AbortController> = new Map();

public async requestLLM(
payload: { messages: IMessage[]; input: IInput; sessionId?: string },
context: RequestHandlerContext,
request: OpenSearchDashboardsRequest
): Promise<IMessage[]> {
const traceId = uuid();
const observabilityClient = context.assistant_plugin.observabilityClient.asScoped(request);
request: OpenSearchDashboardsRequest<unknown, unknown, LLMRequestSchema, 'post'>
): Promise<{
messages: IMessage[];
memoryId: string;
}> {
const { input, sessionId, rootAgentId } = request.body;
const opensearchClient = context.core.opensearch.client.asCurrentUser;
const savedObjectsClient = context.core.savedObjects.client;

if (payload.sessionId) {
OllyChatService.abortControllers.set(payload.sessionId, new AbortController());
}

try {
const runs: Run[] = [];
const callbacks = [new OpenSearchTracer(opensearchClient, traceId, runs)];
const model = LLMModelFactory.createModel({ client: opensearchClient });
const embeddings = LLMModelFactory.createEmbeddings({ client: opensearchClient });
const pluginTools = initTools(
model,
embeddings,
opensearchClient,
observabilityClient,
savedObjectsClient,
callbacks
);
const memory = memoryInit(payload.messages);
const chatAgent = chatAgentInit(
model,
pluginTools.flatMap((tool) => tool.toolsList),
callbacks,
memory
);
const agentResponse = await chatAgent.run(
payload.input.content,
payload.sessionId ? OllyChatService.abortControllers.get(payload.sessionId) : undefined
);

const suggestions = await requestSuggestionsChain(
model,
pluginTools.flatMap((tool) => tool.toolsList),
memory,
callbacks
);
/**
* Wait for an API to fetch root agent id.
*/
const parametersPayload: {
question: string;
verbose?: boolean;
memory_id?: string;
} = {
question: input.content,
verbose: true,
};
if (sessionId) {
parametersPayload.memory_id = sessionId;
}
const agentFrameworkResponse = (await opensearchClient.transport.request({
method: 'POST',
path: `/_plugins/_ml/agents/${rootAgentId}/_execute`,
body: {
parameters: parametersPayload,
},
})) as ApiResponse<{
inference_results: Array<{
output: Array<{ name: string; result?: string }>;
}>;
}>;
const outputBody =
agentFrameworkResponse.body.inference_results?.[0]?.output ||
agentFrameworkResponse.body.inference_results?.[0]?.output;
const memoryIdItem = outputBody?.find((item) => item.name === MEMORY_ID_FIELD);
const reversedOutputBody = [...outputBody].reverse();
const finalAnswerItem = reversedOutputBody.find((item) => item.name === RESPONSE_FIELD);

return buildOutputs(
payload.input.content,
agentResponse,
traceId,
suggestions,
convertToTraces(runs)
);
const agentFrameworkAnswer = finalAnswerItem?.result || '';

return {
messages: buildOutputs(input.content, agentFrameworkAnswer, '', {}, convertToTraces(runs)),
memoryId: memoryIdItem?.result || '',
};
} catch (error) {
context.assistant_plugin.logger.error(error);
return [
{
type: 'output',
traceId,
contentType: 'error',
content: error.message,
},
];
return {
messages: [
{
type: 'output',
traceId: '',
contentType: 'error',
content: error.message,
},
],
memoryId: '',
};
} finally {
if (payload.sessionId) {
OllyChatService.abortControllers.delete(payload.sessionId);
Expand Down
76 changes: 76 additions & 0 deletions server/services/storage/agent_framework_storage_service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { ApiResponse } from '@opensearch-project/opensearch/.';
import { OpenSearchClient } from '../../../../../src/core/server';
import { LLM_INDEX } from '../../../common/constants/llm';
import {
IInput,
IMessage,
IOutput,
ISession,
ISessionFindResponse,
} from '../../../common/types/chat_saved_object_attributes';
import { GetSessionsSchema } from '../../routes/chat_routes';
import { StorageService } from './storage_service';

export class AgentFrameworkStorageService implements StorageService {
constructor(private readonly client: OpenSearchClient) {}
async getSession(sessionId: string): Promise<ISession> {
const session = (await this.client.transport.request({
method: 'GET',
path: `/_plugins/_ml/memory/conversation/${sessionId}`,
})) as ApiResponse<{
interactions: Array<{
input: string;
response: string;
parent_interaction_id: string;
interaction_id: string;
}>;
}>;
return {
title: 'test',
version: 1,
createdTimeMs: Date.now(),
updatedTimeMs: Date.now(),
messages: session.body.interactions
.filter((item) => !item.parent_interaction_id)
.reduce((total, current) => {
const inputItem: IInput = {
type: 'input',
contentType: 'text',
content: current.input,
};
const outputItems: IOutput[] = [
{
type: 'output',
contentType: 'markdown',
content: current.response,
traceId: current.interaction_id,
},
];
return [...total, inputItem, ...outputItems];
}, [] as IMessage[]),
};
}

async getSessions(query: GetSessionsSchema): Promise<ISessionFindResponse> {
throw new Error('Method not implemented.');
}

async saveMessages(
title: string,
sessionId: string | undefined,
messages: IMessage[]
): Promise<{ sessionId: string; messages: IMessage[] }> {
throw new Error('Method not implemented.');
}
deleteSession(sessionId: string): Promise<{}> {
throw new Error('Method not implemented.');
}
updateSession(sessionId: string, title: string): Promise<{}> {
throw new Error('Method not implemented.');
}
}
2 changes: 2 additions & 0 deletions server/services/storage/storage_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ export interface StorageService {
sessionId: string | undefined,
messages: IMessage[]
): Promise<{ sessionId: string; messages: IMessage[] }>;
deleteSession(sessionId: string): Promise<{}>;
updateSession(sessionId: string, title: string): Promise<{}>;
}
Loading