From 7b4d1480caf3d150d06cff518afcc4fe83fd9a3a Mon Sep 17 00:00:00 2001 From: Joshua Li Date: Fri, 12 Jan 2024 13:49:40 -0800 Subject: [PATCH] support searching agent by name (#1359) Signed-off-by: Joshua Li --- server/index.ts | 6 +- server/routes/query_assist/routes.ts | 88 +++++------- .../utils/__tests__/agents.test.ts | 125 ++++++++++++++++++ server/routes/query_assist/utils/agents.ts | 96 ++++++++++++++ 4 files changed, 255 insertions(+), 60 deletions(-) create mode 100644 server/routes/query_assist/utils/__tests__/agents.test.ts create mode 100644 server/routes/query_assist/utils/agents.ts diff --git a/server/index.ts b/server/index.ts index 5281e7b8e..aaecf7e92 100644 --- a/server/index.ts +++ b/server/index.ts @@ -17,9 +17,9 @@ const observabilityConfig = { schema: schema.object({ query_assist: schema.object({ enabled: schema.boolean({ defaultValue: false }), - ppl_agent_id: schema.maybe(schema.string()), - response_summary_agent_id: schema.maybe(schema.string()), - error_summary_agent_id: schema.maybe(schema.string()), + ppl_agent_name: schema.maybe(schema.string()), + response_summary_agent_name: schema.maybe(schema.string()), + error_summary_agent_name: schema.maybe(schema.string()), }), }), }; diff --git a/server/routes/query_assist/routes.ts b/server/routes/query_assist/routes.ts index edf030742..a9bd90348 100644 --- a/server/routes/query_assist/routes.ts +++ b/server/routes/query_assist/routes.ts @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { ApiResponse } from '@opensearch-project/opensearch'; import { schema } from '@osd/config-schema'; import { ObservabilityConfig } from '../..'; import { @@ -11,32 +10,15 @@ import { IRouter, ResponseError, } from '../../../../../src/core/server'; -import { ML_COMMONS_API_PREFIX, QUERY_ASSIST_API } from '../../../common/constants/query_assist'; +import { QUERY_ASSIST_API } from '../../../common/constants/query_assist'; import { generateFieldContext } from '../../common/helpers/query_assist/generate_field_context'; - -const AGENT_REQUEST_OPTIONS = { - /** - * It is time-consuming for LLM to generate final answer - * Give it a large timeout window - */ - requestTimeout: 5 * 60 * 1000, - /** - * Do not retry - */ - maxRetries: 0, -}; - -type AgentResponse = ApiResponse<{ - inference_results: Array<{ - output: Array<{ name: string; result?: string }>; - }>; -}>; +import { requestWithRetryAgentSearch } from './utils/agents'; export function registerQueryAssistRoutes(router: IRouter, config: ObservabilityConfig) { const { - ppl_agent_id: pplAgentId, - response_summary_agent_id: responseSummaryAgentId, - error_summary_agent_id: ErrorSummaryAgentId, + ppl_agent_name: pplAgentName, + response_summary_agent_name: responseSummaryAgentName, + error_summary_agent_name: errorSummaryAgentName, } = config.query_assist; router.post( @@ -54,28 +36,25 @@ export function registerQueryAssistRoutes(router: IRouter, config: Observability request, response ): Promise> => { - if (!pplAgentId) + if (!pplAgentName) return response.custom({ statusCode: 400, body: - 'PPL agent not found in opensearch_dashboards.yml. Expected observability.query_assist.ppl_agent_id', + 'PPL agent name not found in opensearch_dashboards.yml. Expected observability.query_assist.ppl_agent_name', }); const client = context.core.opensearch.client.asCurrentUser; try { - const pplRequest = (await client.transport.request( - { - method: 'POST', - path: `${ML_COMMONS_API_PREFIX}/agents/${pplAgentId}/_execute`, - body: { - parameters: { - index: request.body.index, - question: request.body.question, - }, + const pplRequest = await requestWithRetryAgentSearch({ + client, + agentName: pplAgentName, + body: { + parameters: { + index: request.body.index, + question: request.body.question, }, }, - AGENT_REQUEST_OPTIONS - )) as AgentResponse; + }); if (!pplRequest.body.inference_results[0].output[0].result) throw new Error('Generated PPL query not found.'); const result = JSON.parse(pplRequest.body.inference_results[0].output[0].result) as { @@ -116,45 +95,40 @@ export function registerQueryAssistRoutes(router: IRouter, config: Observability request, response ): Promise> => { - if (!responseSummaryAgentId || !ErrorSummaryAgentId) + if (!responseSummaryAgentName || !errorSummaryAgentName) return response.custom({ statusCode: 400, body: - 'Summary agent not found in opensearch_dashboards.yml. Expected observability.query_assist.response_summary_agent_id and observability.query_assist.error_summary_agent_id', + 'Summary agent name not found in opensearch_dashboards.yml. Expected observability.query_assist.response_summary_agent_name and observability.query_assist.error_summary_agent_name', }); const client = context.core.opensearch.client.asCurrentUser; const { index, question, query, response: _response, isError } = request.body; const queryResponse = JSON.stringify(_response); - let summaryRequest: AgentResponse; + let summaryRequest; + try { if (!isError) { - summaryRequest = (await client.transport.request( - { - method: 'POST', - path: `${ML_COMMONS_API_PREFIX}/agents/${responseSummaryAgentId}/_execute`, - body: { - parameters: { index, question, query, response: queryResponse }, - }, + summaryRequest = await requestWithRetryAgentSearch({ + client, + agentName: responseSummaryAgentName, + body: { + parameters: { index, question, query, response: queryResponse }, }, - AGENT_REQUEST_OPTIONS - )) as AgentResponse; + }); } else { const [mappings, sampleDoc] = await Promise.all([ client.indices.getMapping({ index }), client.search({ index, size: 1 }), ]); const fields = generateFieldContext(mappings, sampleDoc); - summaryRequest = (await client.transport.request( - { - method: 'POST', - path: `${ML_COMMONS_API_PREFIX}/agents/${ErrorSummaryAgentId}/_execute`, - body: { - parameters: { index, question, query, response: queryResponse, fields }, - }, + summaryRequest = await requestWithRetryAgentSearch({ + client, + agentName: errorSummaryAgentName, + body: { + parameters: { index, question, query, response: queryResponse, fields }, }, - AGENT_REQUEST_OPTIONS - )) as AgentResponse; + }); } const summary = summaryRequest.body.inference_results[0].output[0].result; if (!summary) throw new Error('Generated summary not found.'); diff --git a/server/routes/query_assist/utils/__tests__/agents.test.ts b/server/routes/query_assist/utils/__tests__/agents.test.ts new file mode 100644 index 000000000..713299fa3 --- /dev/null +++ b/server/routes/query_assist/utils/__tests__/agents.test.ts @@ -0,0 +1,125 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { CoreRouteHandlerContext } from '../../../../../../../src/core/server/core_route_handler_context'; +import { coreMock, httpServerMock } from '../../../../../../../src/core/server/mocks'; +import { agentIdMap, requestWithRetryAgentSearch, searchAgentIdByName } from '../agents'; + +describe('Agents helper functions', () => { + const coreContext = new CoreRouteHandlerContext( + coreMock.createInternalStart(), + httpServerMock.createOpenSearchDashboardsRequest() + ); + const client = coreContext.opensearch.client.asCurrentUser; + const mockedTransport = client.transport.request as jest.Mock; + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('searches agent id by name', async () => { + mockedTransport.mockResolvedValueOnce({ + body: { hits: { total: { value: 1 }, hits: [{ _id: 'agentId' }] } }, + }); + const id = await searchAgentIdByName(client, 'test agent'); + expect(id).toEqual('agentId'); + expect(mockedTransport.mock.calls[0]).toMatchInlineSnapshot(` + Array [ + Object { + "body": Object { + "query": Object { + "term": Object { + "name.keyword": "test agent", + }, + }, + "sort": Object { + "created_time": "desc", + }, + }, + "method": "GET", + "path": "/_plugins/_ml/agents/_search", + }, + ] + `); + }); + + it('handles not found errors', async () => { + mockedTransport.mockResolvedValueOnce({ body: { hits: { total: 0 } } }); + await expect( + searchAgentIdByName(client, 'test agent') + ).rejects.toThrowErrorMatchingInlineSnapshot( + `"search agent 'test agent' failed, reason: Error: cannot find any agent by name: test agent"` + ); + }); + + it('handles search errors', async () => { + mockedTransport.mockRejectedValueOnce('request failed'); + await expect( + searchAgentIdByName(client, 'test agent') + ).rejects.toThrowErrorMatchingInlineSnapshot( + `"search agent 'test agent' failed, reason: request failed"` + ); + }); + + it('requests with valid agent id', async () => { + agentIdMap['test agent'] = 'test-id'; + mockedTransport.mockResolvedValueOnce({ + body: { inference_results: [{ output: [{ result: 'test response' }] }] }, + }); + const response = await requestWithRetryAgentSearch({ + client, + agentName: 'test agent', + shouldRetryAgentSearch: true, + body: { parameters: { param1: 'value1' } }, + }); + expect(mockedTransport).toBeCalledWith( + expect.objectContaining({ + path: '/_plugins/_ml/agents/test-id/_execute', + }), + expect.anything() + ); + expect(response.body.inference_results[0].output[0].result).toEqual('test response'); + }); + + it('searches for agent id if id is undefined', async () => { + mockedTransport + .mockResolvedValueOnce({ body: { hits: { total: { value: 1 }, hits: [{ _id: 'new-id' }] } } }) + .mockResolvedValueOnce({ + body: { inference_results: [{ output: [{ result: 'test response' }] }] }, + }); + const response = await requestWithRetryAgentSearch({ + client, + agentName: 'new agent', + shouldRetryAgentSearch: true, + body: { parameters: { param1: 'value1' } }, + }); + expect(mockedTransport).toBeCalledWith( + expect.objectContaining({ path: '/_plugins/_ml/agents/new-id/_execute' }), + expect.anything() + ); + expect(response.body.inference_results[0].output[0].result).toEqual('test response'); + }); + + it('searches for agent id if id is not found', async () => { + agentIdMap['test agent'] = 'non-exist-agent'; + mockedTransport + .mockRejectedValueOnce({ statusCode: 404, body: {}, headers: {} }) + .mockResolvedValueOnce({ body: { hits: { total: { value: 1 }, hits: [{ _id: 'new-id' }] } } }) + .mockResolvedValueOnce({ + body: { inference_results: [{ output: [{ result: 'test response' }] }] }, + }); + const response = await requestWithRetryAgentSearch({ + client, + agentName: 'test agent', + shouldRetryAgentSearch: true, + body: { parameters: { param1: 'value1' } }, + }); + expect(mockedTransport).toBeCalledWith( + expect.objectContaining({ path: '/_plugins/_ml/agents/new-id/_execute' }), + expect.anything() + ); + expect(response.body.inference_results[0].output[0].result).toEqual('test response'); + }); +}); diff --git a/server/routes/query_assist/utils/agents.ts b/server/routes/query_assist/utils/agents.ts new file mode 100644 index 000000000..ae909790b --- /dev/null +++ b/server/routes/query_assist/utils/agents.ts @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ApiResponse } from '@opensearch-project/opensearch/.'; +import { SearchResponse, SearchTotalHits } from '@opensearch-project/opensearch/api/types'; +import { RequestBody } from '@opensearch-project/opensearch/lib/Transport'; +import { OpenSearchClient } from '../../../../../../src/core/server'; +import { isResponseError } from '../../../../../../src/core/server/opensearch/client/errors'; +import { ML_COMMONS_API_PREFIX } from '../../../../common/constants/query_assist'; + +const AGENT_REQUEST_OPTIONS = { + /** + * It is time-consuming for LLM to generate final answer + * Give it a large timeout window + */ + requestTimeout: 5 * 60 * 1000, + /** + * Do not retry + */ + maxRetries: 0, +}; + +type AgentResponse = ApiResponse<{ + inference_results: Array<{ + output: Array<{ name: string; result?: string }>; + }>; +}>; + +export const agentIdMap: Record = {}; + +export const searchAgentIdByName = async ( + opensearchClient: OpenSearchClient, + name: string +): Promise => { + try { + const response = (await opensearchClient.transport.request({ + method: 'GET', + path: `${ML_COMMONS_API_PREFIX}/agents/_search`, + body: { + query: { + term: { + 'name.keyword': name, + }, + }, + sort: { + created_time: 'desc', + }, + }, + })) as ApiResponse; + + if ( + !response || + (typeof response.body.hits.total === 'number' && response.body.hits.total === 0) || + (response.body.hits.total as SearchTotalHits).value === 0 + ) { + throw new Error('cannot find any agent by name: ' + name); + } + const id = response.body.hits.hits[0]._id; + return id; + } catch (error) { + const errorMessage = JSON.stringify(error.meta?.body) || error; + throw new Error(`search agent '${name}' failed, reason: ` + errorMessage); + } +}; + +export const requestWithRetryAgentSearch = async (options: { + client: OpenSearchClient; + agentName: string; + shouldRetryAgentSearch?: boolean; + body: RequestBody; +}): Promise => { + const { client, agentName, shouldRetryAgentSearch = true, body } = options; + let retry = shouldRetryAgentSearch; + if (!agentIdMap[agentName]) { + agentIdMap[agentName] = await searchAgentIdByName(client, agentName); + retry = false; + } + return client.transport + .request( + { + method: 'POST', + path: `${ML_COMMONS_API_PREFIX}/agents/${agentIdMap[agentName]}/_execute`, + body, + }, + AGENT_REQUEST_OPTIONS + ) + .catch(async (error) => { + if (retry && isResponseError(error) && error.statusCode === 404) { + agentIdMap[agentName] = await searchAgentIdByName(client, agentName); + return requestWithRetryAgentSearch({ ...options, shouldRetryAgentSearch: false }); + } + return Promise.reject(error); + }) as Promise; +};