Skip to content

Commit

Permalink
support searching agent by name (opensearch-project#1359)
Browse files Browse the repository at this point in the history
Signed-off-by: Joshua Li <[email protected]>
  • Loading branch information
joshuali925 authored Jan 12, 2024
1 parent 38957cd commit 7b4d148
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 60 deletions.
6 changes: 3 additions & 3 deletions server/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ const observabilityConfig = {
schema: schema.object({
query_assist: schema.object({
enabled: schema.boolean({ defaultValue: false }),
ppl_agent_id: schema.maybe(schema.string()),
response_summary_agent_id: schema.maybe(schema.string()),
error_summary_agent_id: schema.maybe(schema.string()),
ppl_agent_name: schema.maybe(schema.string()),
response_summary_agent_name: schema.maybe(schema.string()),
error_summary_agent_name: schema.maybe(schema.string()),
}),
}),
};
Expand Down
88 changes: 31 additions & 57 deletions server/routes/query_assist/routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,22 @@
* SPDX-License-Identifier: Apache-2.0
*/

import { ApiResponse } from '@opensearch-project/opensearch';
import { schema } from '@osd/config-schema';
import { ObservabilityConfig } from '../..';
import {
IOpenSearchDashboardsResponse,
IRouter,
ResponseError,
} from '../../../../../src/core/server';
import { ML_COMMONS_API_PREFIX, QUERY_ASSIST_API } from '../../../common/constants/query_assist';
import { QUERY_ASSIST_API } from '../../../common/constants/query_assist';
import { generateFieldContext } from '../../common/helpers/query_assist/generate_field_context';

const AGENT_REQUEST_OPTIONS = {
/**
* It is time-consuming for LLM to generate final answer
* Give it a large timeout window
*/
requestTimeout: 5 * 60 * 1000,
/**
* Do not retry
*/
maxRetries: 0,
};

type AgentResponse = ApiResponse<{
inference_results: Array<{
output: Array<{ name: string; result?: string }>;
}>;
}>;
import { requestWithRetryAgentSearch } from './utils/agents';

export function registerQueryAssistRoutes(router: IRouter, config: ObservabilityConfig) {
const {
ppl_agent_id: pplAgentId,
response_summary_agent_id: responseSummaryAgentId,
error_summary_agent_id: ErrorSummaryAgentId,
ppl_agent_name: pplAgentName,
response_summary_agent_name: responseSummaryAgentName,
error_summary_agent_name: errorSummaryAgentName,
} = config.query_assist;

router.post(
Expand All @@ -54,28 +36,25 @@ export function registerQueryAssistRoutes(router: IRouter, config: Observability
request,
response
): Promise<IOpenSearchDashboardsResponse<any | ResponseError>> => {
if (!pplAgentId)
if (!pplAgentName)
return response.custom({
statusCode: 400,
body:
'PPL agent not found in opensearch_dashboards.yml. Expected observability.query_assist.ppl_agent_id',
'PPL agent name not found in opensearch_dashboards.yml. Expected observability.query_assist.ppl_agent_name',
});

const client = context.core.opensearch.client.asCurrentUser;
try {
const pplRequest = (await client.transport.request(
{
method: 'POST',
path: `${ML_COMMONS_API_PREFIX}/agents/${pplAgentId}/_execute`,
body: {
parameters: {
index: request.body.index,
question: request.body.question,
},
const pplRequest = await requestWithRetryAgentSearch({
client,
agentName: pplAgentName,
body: {
parameters: {
index: request.body.index,
question: request.body.question,
},
},
AGENT_REQUEST_OPTIONS
)) as AgentResponse;
});
if (!pplRequest.body.inference_results[0].output[0].result)
throw new Error('Generated PPL query not found.');
const result = JSON.parse(pplRequest.body.inference_results[0].output[0].result) as {
Expand Down Expand Up @@ -116,45 +95,40 @@ export function registerQueryAssistRoutes(router: IRouter, config: Observability
request,
response
): Promise<IOpenSearchDashboardsResponse<any | ResponseError>> => {
if (!responseSummaryAgentId || !ErrorSummaryAgentId)
if (!responseSummaryAgentName || !errorSummaryAgentName)
return response.custom({
statusCode: 400,
body:
'Summary agent not found in opensearch_dashboards.yml. Expected observability.query_assist.response_summary_agent_id and observability.query_assist.error_summary_agent_id',
'Summary agent name not found in opensearch_dashboards.yml. Expected observability.query_assist.response_summary_agent_name and observability.query_assist.error_summary_agent_name',
});

const client = context.core.opensearch.client.asCurrentUser;
const { index, question, query, response: _response, isError } = request.body;
const queryResponse = JSON.stringify(_response);
let summaryRequest: AgentResponse;
let summaryRequest;

try {
if (!isError) {
summaryRequest = (await client.transport.request(
{
method: 'POST',
path: `${ML_COMMONS_API_PREFIX}/agents/${responseSummaryAgentId}/_execute`,
body: {
parameters: { index, question, query, response: queryResponse },
},
summaryRequest = await requestWithRetryAgentSearch({
client,
agentName: responseSummaryAgentName,
body: {
parameters: { index, question, query, response: queryResponse },
},
AGENT_REQUEST_OPTIONS
)) as AgentResponse;
});
} else {
const [mappings, sampleDoc] = await Promise.all([
client.indices.getMapping({ index }),
client.search({ index, size: 1 }),
]);
const fields = generateFieldContext(mappings, sampleDoc);
summaryRequest = (await client.transport.request(
{
method: 'POST',
path: `${ML_COMMONS_API_PREFIX}/agents/${ErrorSummaryAgentId}/_execute`,
body: {
parameters: { index, question, query, response: queryResponse, fields },
},
summaryRequest = await requestWithRetryAgentSearch({
client,
agentName: errorSummaryAgentName,
body: {
parameters: { index, question, query, response: queryResponse, fields },
},
AGENT_REQUEST_OPTIONS
)) as AgentResponse;
});
}
const summary = summaryRequest.body.inference_results[0].output[0].result;
if (!summary) throw new Error('Generated summary not found.');
Expand Down
125 changes: 125 additions & 0 deletions server/routes/query_assist/utils/__tests__/agents.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { CoreRouteHandlerContext } from '../../../../../../../src/core/server/core_route_handler_context';
import { coreMock, httpServerMock } from '../../../../../../../src/core/server/mocks';
import { agentIdMap, requestWithRetryAgentSearch, searchAgentIdByName } from '../agents';

describe('Agents helper functions', () => {
const coreContext = new CoreRouteHandlerContext(
coreMock.createInternalStart(),
httpServerMock.createOpenSearchDashboardsRequest()
);
const client = coreContext.opensearch.client.asCurrentUser;
const mockedTransport = client.transport.request as jest.Mock;

afterEach(() => {
jest.clearAllMocks();
});

it('searches agent id by name', async () => {
mockedTransport.mockResolvedValueOnce({
body: { hits: { total: { value: 1 }, hits: [{ _id: 'agentId' }] } },
});
const id = await searchAgentIdByName(client, 'test agent');
expect(id).toEqual('agentId');
expect(mockedTransport.mock.calls[0]).toMatchInlineSnapshot(`
Array [
Object {
"body": Object {
"query": Object {
"term": Object {
"name.keyword": "test agent",
},
},
"sort": Object {
"created_time": "desc",
},
},
"method": "GET",
"path": "/_plugins/_ml/agents/_search",
},
]
`);
});

it('handles not found errors', async () => {
mockedTransport.mockResolvedValueOnce({ body: { hits: { total: 0 } } });
await expect(
searchAgentIdByName(client, 'test agent')
).rejects.toThrowErrorMatchingInlineSnapshot(
`"search agent 'test agent' failed, reason: Error: cannot find any agent by name: test agent"`
);
});

it('handles search errors', async () => {
mockedTransport.mockRejectedValueOnce('request failed');
await expect(
searchAgentIdByName(client, 'test agent')
).rejects.toThrowErrorMatchingInlineSnapshot(
`"search agent 'test agent' failed, reason: request failed"`
);
});

it('requests with valid agent id', async () => {
agentIdMap['test agent'] = 'test-id';
mockedTransport.mockResolvedValueOnce({
body: { inference_results: [{ output: [{ result: 'test response' }] }] },
});
const response = await requestWithRetryAgentSearch({
client,
agentName: 'test agent',
shouldRetryAgentSearch: true,
body: { parameters: { param1: 'value1' } },
});
expect(mockedTransport).toBeCalledWith(
expect.objectContaining({
path: '/_plugins/_ml/agents/test-id/_execute',
}),
expect.anything()
);
expect(response.body.inference_results[0].output[0].result).toEqual('test response');
});

it('searches for agent id if id is undefined', async () => {
mockedTransport
.mockResolvedValueOnce({ body: { hits: { total: { value: 1 }, hits: [{ _id: 'new-id' }] } } })
.mockResolvedValueOnce({
body: { inference_results: [{ output: [{ result: 'test response' }] }] },
});
const response = await requestWithRetryAgentSearch({
client,
agentName: 'new agent',
shouldRetryAgentSearch: true,
body: { parameters: { param1: 'value1' } },
});
expect(mockedTransport).toBeCalledWith(
expect.objectContaining({ path: '/_plugins/_ml/agents/new-id/_execute' }),
expect.anything()
);
expect(response.body.inference_results[0].output[0].result).toEqual('test response');
});

it('searches for agent id if id is not found', async () => {
agentIdMap['test agent'] = 'non-exist-agent';
mockedTransport
.mockRejectedValueOnce({ statusCode: 404, body: {}, headers: {} })
.mockResolvedValueOnce({ body: { hits: { total: { value: 1 }, hits: [{ _id: 'new-id' }] } } })
.mockResolvedValueOnce({
body: { inference_results: [{ output: [{ result: 'test response' }] }] },
});
const response = await requestWithRetryAgentSearch({
client,
agentName: 'test agent',
shouldRetryAgentSearch: true,
body: { parameters: { param1: 'value1' } },
});
expect(mockedTransport).toBeCalledWith(
expect.objectContaining({ path: '/_plugins/_ml/agents/new-id/_execute' }),
expect.anything()
);
expect(response.body.inference_results[0].output[0].result).toEqual('test response');
});
});
96 changes: 96 additions & 0 deletions server/routes/query_assist/utils/agents.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { ApiResponse } from '@opensearch-project/opensearch/.';
import { SearchResponse, SearchTotalHits } from '@opensearch-project/opensearch/api/types';
import { RequestBody } from '@opensearch-project/opensearch/lib/Transport';
import { OpenSearchClient } from '../../../../../../src/core/server';
import { isResponseError } from '../../../../../../src/core/server/opensearch/client/errors';
import { ML_COMMONS_API_PREFIX } from '../../../../common/constants/query_assist';

const AGENT_REQUEST_OPTIONS = {
/**
* It is time-consuming for LLM to generate final answer
* Give it a large timeout window
*/
requestTimeout: 5 * 60 * 1000,
/**
* Do not retry
*/
maxRetries: 0,
};

type AgentResponse = ApiResponse<{
inference_results: Array<{
output: Array<{ name: string; result?: string }>;
}>;
}>;

export const agentIdMap: Record<string, string> = {};

export const searchAgentIdByName = async (
opensearchClient: OpenSearchClient,
name: string
): Promise<string> => {
try {
const response = (await opensearchClient.transport.request({
method: 'GET',
path: `${ML_COMMONS_API_PREFIX}/agents/_search`,
body: {
query: {
term: {
'name.keyword': name,
},
},
sort: {
created_time: 'desc',
},
},
})) as ApiResponse<SearchResponse>;

if (
!response ||
(typeof response.body.hits.total === 'number' && response.body.hits.total === 0) ||
(response.body.hits.total as SearchTotalHits).value === 0
) {
throw new Error('cannot find any agent by name: ' + name);
}
const id = response.body.hits.hits[0]._id;
return id;
} catch (error) {
const errorMessage = JSON.stringify(error.meta?.body) || error;
throw new Error(`search agent '${name}' failed, reason: ` + errorMessage);
}
};

export const requestWithRetryAgentSearch = async (options: {
client: OpenSearchClient;
agentName: string;
shouldRetryAgentSearch?: boolean;
body: RequestBody;
}): Promise<AgentResponse> => {
const { client, agentName, shouldRetryAgentSearch = true, body } = options;
let retry = shouldRetryAgentSearch;
if (!agentIdMap[agentName]) {
agentIdMap[agentName] = await searchAgentIdByName(client, agentName);
retry = false;
}
return client.transport
.request(
{
method: 'POST',
path: `${ML_COMMONS_API_PREFIX}/agents/${agentIdMap[agentName]}/_execute`,
body,
},
AGENT_REQUEST_OPTIONS
)
.catch(async (error) => {
if (retry && isResponseError(error) && error.statusCode === 404) {
agentIdMap[agentName] = await searchAgentIdByName(client, agentName);
return requestWithRetryAgentSearch({ ...options, shouldRetryAgentSearch: false });
}
return Promise.reject(error);
}) as Promise<AgentResponse>;
};

0 comments on commit 7b4d148

Please sign in to comment.