diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java new file mode 100644 index 0000000000..79d339e4fc --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java @@ -0,0 +1,199 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import static org.opensearch.ml.common.CommonValue.*; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; + +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +@Getter +@Setter +@Log4j2 +@ToolAnnotation(SearchIndexTool.TYPE) +public class SearchIndexTool implements Tool { + + public static final String INPUT_FIELD = "input"; + public static final String INDEX_FIELD = "index"; + public static final String QUERY_FIELD = "query"; + + public static final String TYPE = "SearchIndexTool"; + private static final String DEFAULT_DESCRIPTION = + "Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query. Only use this tool when both index name and DSL query is available."; + + private String name = TYPE; + + private String description = DEFAULT_DESCRIPTION; + + private Client client; + + private NamedXContentRegistry xContentRegistry; + + public SearchIndexTool(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getVersion() { + return null; + } + + @Override + public boolean validate(Map parameters) { + return parameters != null && parameters.containsKey(INPUT_FIELD) && parameters.get(INPUT_FIELD) != null; + } + + private SearchRequest getSearchRequest(String index, String query) throws IOException { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); + searchSourceBuilder.parseXContent(queryParser); + return new SearchRequest().source(searchSourceBuilder).indices(index); + } + + private static Map processResponse(SearchHit hit) { + Map docContent = new HashMap<>(); + docContent.put("_index", hit.getIndex()); + docContent.put("_id", hit.getId()); + docContent.put("_score", hit.getScore()); + docContent.put("_source", hit.getSourceAsMap()); + return docContent; + } + + @Override + public void run(Map parameters, ActionListener listener) { + try { + String input = parameters.get(INPUT_FIELD); + JsonObject jsonObject = StringUtils.gson.fromJson(input, JsonObject.class); + String index = Optional.ofNullable(jsonObject).map(x -> x.get(INDEX_FIELD)).map(JsonElement::getAsString).orElse(null); + String query = Optional.ofNullable(jsonObject).map(x -> x.get(QUERY_FIELD)).map(JsonElement::toString).orElse(null); + if (index == null || query == null) { + listener.onFailure(new IllegalArgumentException("SearchIndexTool's two parameter: index and query are required!")); + return; + } + SearchRequest searchRequest = getSearchRequest(index, query); + + ActionListener actionListener = ActionListener.wrap(r -> { + SearchHit[] hits = r.getHits().getHits(); + + if (hits != null && hits.length > 0) { + StringBuilder contextBuilder = new StringBuilder(); + for (SearchHit hit : hits) { + String doc = AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + Map docContent = processResponse(hit); + return StringUtils.gson.toJson(docContent); + }); + contextBuilder.append(doc).append("\n"); + } + listener.onResponse((T) contextBuilder.toString()); + } else { + listener.onResponse((T) ""); + } + }, e -> { + log.error("Failed to search index", e); + listener.onFailure(e); + }); + + // since searching connector and model needs access control, we need + // to forward the request corresponding transport action + if (Objects.equals(index, ML_CONNECTOR_INDEX)) { + client.execute(MLConnectorSearchAction.INSTANCE, searchRequest, actionListener); + } else if (Objects.equals(index, ML_MODEL_INDEX)) { + client.execute(MLModelSearchAction.INSTANCE, searchRequest, actionListener); + } else if (Objects.equals(index, ML_MODEL_GROUP_INDEX)) { + client.execute(MLModelGroupSearchAction.INSTANCE, searchRequest, actionListener); + } else { + client.search(searchRequest, actionListener); + } + } catch (Exception e) { + log.error("Failed to search index", e); + listener.onFailure(e); + } + } + + public static class Factory implements Tool.Factory { + + private Client client; + private static Factory INSTANCE; + + private NamedXContentRegistry xContentRegistry; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (SearchIndexTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + public void init(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + public SearchIndexTool create(Map params) { + return new SearchIndexTool(client, xContentRegistry); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java new file mode 100644 index 0000000000..4ccbd33a17 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java @@ -0,0 +1,188 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; + +import java.io.InputStream; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; +import org.opensearch.search.SearchModule; + +import lombok.SneakyThrows; + +public class SearchIndexToolTests { + static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( + new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() + ); + + private Client client; + + private SearchIndexTool mockedSearchIndexTool; + + private String mockedSearchResponseString; + + @Before + @SneakyThrows + public void setup() { + client = mock(Client.class); + mockedSearchIndexTool = mock( + SearchIndexTool.class, + Mockito.withSettings().useConstructor(client, TEST_XCONTENT_REGISTRY_FOR_QUERY).defaultAnswer(Mockito.CALLS_REAL_METHODS) + ); + + try (InputStream searchResponseIns = SearchIndexTool.class.getResourceAsStream("retrieval_tool_search_response.json")) { + if (searchResponseIns != null) { + mockedSearchResponseString = new String(searchResponseIns.readAllBytes()); + } + } + } + + @Test + @SneakyThrows + public void testGetType() { + String type = mockedSearchIndexTool.getType(); + assertFalse(Strings.isNullOrEmpty(type)); + assertEquals("SearchIndexTool", type); + } + + @Test + @SneakyThrows + public void testValidate() { + Map parameters = Map.of("input", "{}"); + assertTrue(mockedSearchIndexTool.validate(parameters)); + } + + @Test + @SneakyThrows + public void testValidateWithEmptyInput() { + Map parameters = Map.of(); + assertFalse(mockedSearchIndexTool.validate(parameters)); + } + + @Test + public void testRunWithNormalIndex() { + String inputString = "{\"index\": \"test-index\", \"query\": {\"query\": {\"match_all\": {}}}}"; + Map parameters = Map.of("input", inputString); + mockedSearchIndexTool.run(parameters, null); + Mockito.verify(client, times(1)).search(any(), any()); + Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); + } + + @Test + public void testRunWithConnectorIndex() { + String inputString = "{\"index\": \".plugins-ml-connector\", \"query\": {\"query\": {\"match_all\": {}}}}"; + Map parameters = Map.of("input", inputString); + mockedSearchIndexTool.run(parameters, null); + Mockito.verify(client, never()).search(any(), any()); + Mockito.verify(client, times(1)).execute(eq(MLConnectorSearchAction.INSTANCE), any(), any()); + } + + @Test + public void testRunWithModelIndex() { + String inputString = "{\"index\": \".plugins-ml-model\", \"query\": {\"query\": {\"match_all\": {}}}}"; + Map parameters = Map.of("input", inputString); + mockedSearchIndexTool.run(parameters, null); + Mockito.verify(client, never()).search(any(), any()); + Mockito.verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), any(), any()); + } + + @Test + public void testRunWithModelGroupIndex() { + String inputString = "{\"index\": \".plugins-ml-model-group\", \"query\": {\"query\": {\"match_all\": {}}}}"; + Map parameters = Map.of("input", inputString); + mockedSearchIndexTool.run(parameters, null); + Mockito.verify(client, never()).search(any(), any()); + Mockito.verify(client, times(1)).execute(eq(MLModelGroupSearchAction.INSTANCE), any(), any()); + } + + @Test + @SneakyThrows + public void testRunWithSearchResults() { + SearchResponse mockedSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedSearchResponseString) + ); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedSearchResponse); + return null; + }).when(client).search(any(), any()); + + String inputString = "{\"index\": \"test-index\", \"query\": {\"query\": {\"match_all\": {}}}}"; + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + Map parameters = Map.of("input", inputString); + mockedSearchIndexTool.run(parameters, listener); + + future.join(); + + Mockito.verify(client, times(1)).search(any(), any()); + Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); + } + + @Test + @SneakyThrows + public void testRunWithEmptyQuery() { + String inputString = "{\"index\": \"test_index\"}"; + Map parameters = Map.of("input", inputString); + ActionListener listener = mock(ActionListener.class); + mockedSearchIndexTool.run(parameters, listener); + Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); + Mockito.verify(client, Mockito.never()).search(any(), any()); + } + + @Test + public void testRunWithInvalidQuery() { + String inputString = "{\"index\": \"test-index\", \"query\": \"invalid query\"}"; + Map parameters = Map.of("input", inputString); + ActionListener listener = mock(ActionListener.class); + mockedSearchIndexTool.run(parameters, listener); + Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); + Mockito.verify(client, Mockito.never()).search(any(), any()); + } + + @Test + public void testRunWithEmptyQueryBody() { + String inputString = "{\"index\": \"test-index\", \"query\": {}}"; + Map parameters = Map.of("input", inputString); + mockedSearchIndexTool.run(parameters, null); + Mockito.verify(client, times(1)).search(any(), any()); + Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); + } + + @Test + public void testFactory() { + SearchIndexTool searchIndexTool = SearchIndexTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(SearchIndexTool.TYPE, searchIndexTool.getType()); + } +} diff --git a/ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_search_response.json b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_search_response.json new file mode 100644 index 0000000000..d89ad3b0d9 --- /dev/null +++ b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/tools/retrieval_tool_search_response.json @@ -0,0 +1,35 @@ +{ + "took": 201, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 2, + "relation": "eq" + }, + "max_score": 89.2917, + "hits": [ + { + "_index": "hybrid-index", + "_id": "1", + "_score": 89.2917, + "_source": { + "passage_text": "Company test_mock have a history of 100 years." + } + }, + { + "_index": "hybrid-index", + "_id": "2", + "_score": 0.10702579, + "_source": { + "passage_text": "the price of the api is 2$ per invocation" + } + } + ] + } +} \ No newline at end of file diff --git a/plugin/build.gradle b/plugin/build.gradle index dbfa0c4f3f..e1c9e49100 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -184,6 +184,14 @@ integTest { // Set this to true this if you want to see the logs in the terminal test output. // note: if left false the log output will still show in your IDE testLogging.showStandardStreams = true + + // Exclude integration test for search index tool for multiple nodes + // because we cannot get concrete exception message in multi-node cluster + if (_numNodes > 1) { + filter { + excludeTestsMatching "org.opensearch.ml.rest.RestSearchIndexToolIT.*" + } + } } testClusters.integTest { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 7a958ff64a..f2f2521da2 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -169,6 +169,7 @@ import org.opensearch.ml.engine.tools.CatIndexTool; import org.opensearch.ml.engine.tools.IndexMappingTool; import org.opensearch.ml.engine.tools.MLModelTool; +import org.opensearch.ml.engine.tools.SearchIndexTool; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.memory.ConversationalMemoryHandler; @@ -566,11 +567,13 @@ public Collection createComponents( AgentTool.Factory.getInstance().init(client); CatIndexTool.Factory.getInstance().init(client, clusterService); IndexMappingTool.Factory.getInstance().init(client); + SearchIndexTool.Factory.getInstance().init(client, xContentRegistry); toolFactories.put(MLModelTool.TYPE, MLModelTool.Factory.getInstance()); toolFactories.put(AgentTool.TYPE, AgentTool.Factory.getInstance()); toolFactories.put(CatIndexTool.TYPE, CatIndexTool.Factory.getInstance()); toolFactories.put(IndexMappingTool.TYPE, IndexMappingTool.Factory.getInstance()); + toolFactories.put(SearchIndexTool.TYPE, SearchIndexTool.Factory.getInstance()); if (externalToolFactories != null) { toolFactories.putAll(externalToolFactories); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBaseAgentToolsIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBaseAgentToolsIT.java new file mode 100644 index 0000000000..bd1d2e431f --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBaseAgentToolsIT.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.apache.hc.core5.http.ParseException; +import org.junit.After; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.utils.TestHelper; + +public abstract class RestBaseAgentToolsIT extends MLCommonsRestTestCase { + + private static final String INTERNAL_INDICES_PREFIX = "."; + + private Object parseFieldFromResponse(Response response, String field) throws IOException, ParseException { + assertNotNull(field); + Map map = parseResponseToMap(response); + Object result = map.get(field); + assertNotNull(result); + return result; + } + + protected void createIndexWithConfiguration(String indexName, String indexConfiguration) throws Exception { + Response response = TestHelper.makeRequest(client(), "PUT", indexName, null, indexConfiguration, null); + Map responseInMap = parseResponseToMap(response); + assertEquals("true", responseInMap.get("acknowledged").toString()); + assertEquals(indexName, responseInMap.get("index").toString()); + } + + protected void addDocToIndex(String indexName, String docId, List fieldNames, List fieldContents) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + for (int i = 0; i < fieldNames.size(); i++) { + builder.field(fieldNames.get(i), fieldContents.get(i)); + } + builder.endObject(); + Response response = TestHelper + .makeRequest(client(), "POST", "/" + indexName + "/_doc/" + docId + "?refresh=true", null, builder.toString(), null); + assertEquals(RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + protected String createAgent(String requestBody) throws IOException, ParseException { + Response response = TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/agents/_register", null, requestBody, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return parseFieldFromResponse(response, AgentMLInput.AGENT_ID_FIELD).toString(); + } + + private String parseStringResponseFromExecuteAgentResponse(Response response) throws IOException, ParseException { + Map responseInMap = parseResponseToMap(response); + Optional optionalResult = Optional + .ofNullable(responseInMap) + .map(m -> (List) m.get(ModelTensorOutput.INFERENCE_RESULT_FIELD)) + .map(l -> (Map) l.get(0)) + .map(m -> (List) m.get(ModelTensors.OUTPUT_FIELD)) + .map(l -> (Map) l.get(0)) + .map(m -> (String) (m.get(ModelTensor.RESULT_FIELD))); + return optionalResult.get(); + } + + // execute the agent, and return the String response from the json structure + // {"inference_results": [{"output": [{"name": "response","result": "the result to return."}]}]} + protected String executeAgent(String agentId, String requestBody) throws IOException, ParseException { + Response response = TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, requestBody, null); + return parseStringResponseFromExecuteAgentResponse(response); + } + + @After + public void deleteExternalIndices() throws IOException { + final Response response = client().performRequest(new Request("GET", "/_cat/indices?format=json" + "&expand_wildcards=all")); + final MediaType xContentType = MediaType.fromMediaType(response.getEntity().getContentType()); + try ( + final XContentParser parser = xContentType + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + response.getEntity().getContent() + ) + ) { + final XContentParser.Token token = parser.nextToken(); + final List> parserList; + if (token == XContentParser.Token.START_ARRAY) { + parserList = parser.listOrderedMap().stream().map(obj -> (Map) obj).collect(Collectors.toList()); + } else { + parserList = Collections.singletonList(parser.mapOrdered()); + } + + final List externalIndices = parserList + .stream() + .map(index -> (String) index.get("index")) + .filter(indexName -> indexName != null) + .filter(indexName -> !indexName.startsWith(INTERNAL_INDICES_PREFIX)) + .collect(Collectors.toList()); + + for (final String indexName : externalIndices) { + adminClient().performRequest(new Request("DELETE", "/" + indexName)); + } + } + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestSearchIndexToolIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestSearchIndexToolIT.java new file mode 100644 index 0000000000..6330037401 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestSearchIndexToolIT.java @@ -0,0 +1,137 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.hamcrest.Matchers.containsString; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Objects; + +import org.apache.hc.core5.http.ParseException; +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.ResponseException; + +public class RestSearchIndexToolIT extends RestBaseAgentToolsIT { + public static String TEST_INDEX_NAME = "test_index"; + private String registerAgentRequestBody; + + private void prepareIndex() throws Exception { + createIndexWithConfiguration( + TEST_INDEX_NAME, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"text\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(TEST_INDEX_NAME, "0", List.of("text"), List.of("text doc 1")); + addDocToIndex(TEST_INDEX_NAME, "1", List.of("text"), List.of("text doc 2")); + addDocToIndex(TEST_INDEX_NAME, "2", List.of("text"), List.of("text doc 3")); + } + + @Before + public void setUp() throws Exception { + super.setUp(); + prepareIndex(); + registerAgentRequestBody = Files + .readString( + Path + .of( + Objects + .requireNonNull( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/ml/rest/tools/register_flow_agent_of_search_index_tool_request_body.json") + ) + .toURI() + ) + ); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + deleteExternalIndices(); + } + + public void testSearchIndexToolInFlowAgent_withMatchAllQuery() throws IOException, ParseException { + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\n" + + " \"parameters\": {\n" + + " \"input\": {\n" + + " \"index\": \"test_index\",\n" + + " \"query\": {\n" + + " \"query\": {\n" + + " \"match_all\": {}\n" + + " }\n" + + " }\n" + + " } \n" + + " }\n" + + "}\n"; + String result = executeAgent(agentId, agentInput); + assertEquals( + "The search index result not equal with expected.", + "{\"_index\":\"test_index\",\"_source\":{\"text\":\"text doc 1\"},\"_id\":\"0\",\"_score\":1.0}\n" + + "{\"_index\":\"test_index\",\"_source\":{\"text\":\"text doc 2\"},\"_id\":\"1\",\"_score\":1.0}\n" + + "{\"_index\":\"test_index\",\"_source\":{\"text\":\"text doc 3\"},\"_id\":\"2\",\"_score\":1.0}\n", + result + ); + } + + public void testSearchIndexToolInFlowAgent_withEmptyIndexField_thenThrowException() throws IOException, ParseException { + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\n" + + " \"parameters\": {\n" + + " \"input\": {\n" + + " \"query\": {\n" + + " \"query\": {\n" + + " \"match_all\": {}\n" + + " }\n" + + " }\n" + + " } \n" + + " }\n" + + "}\n"; + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, agentInput)); + MatcherAssert.assertThat(exception.getMessage(), containsString("SearchIndexTool's two parameter: index and query are required!")); + } + + public void testSearchIndexToolInFlowAgent_withEmptyQueryField_thenThrowException() throws IOException, ParseException { + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\n" + + " \"parameters\": {\n" + + " \"input\": {\n" + + " \"index\": \"test_index\"\n" + + " } \n" + + " }\n" + + "}\n"; + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, agentInput)); + MatcherAssert.assertThat(exception.getMessage(), containsString("SearchIndexTool's two parameter: index and query are required!")); + } + + public void testSearchIndexToolInFlowAgent_withIllegalQueryField_thenThrowException() throws IOException, ParseException { + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\n" + + " \"parameters\": {\n" + + " \"input\": {\n" + + " \"index\": \"test_index\",\n" + + " \"query\": \"Invalid Query\"\n" + + " } \n" + + " }\n" + + "}\n"; + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, agentInput)); + MatcherAssert.assertThat(exception.getMessage(), containsString("ParsingException")); + } +} diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/tools/register_flow_agent_of_search_index_tool_request_body.json b/plugin/src/test/resources/org/opensearch/ml/rest/tools/register_flow_agent_of_search_index_tool_request_body.json new file mode 100644 index 0000000000..52a6707390 --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/rest/tools/register_flow_agent_of_search_index_tool_request_body.json @@ -0,0 +1,10 @@ +{ + "name": "Test_Search_Index_Agent", + "type": "flow", + "tools": [ + { + "type": "SearchIndexTool", + "description": "Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query. Only use this tool when a DSL query is available." + } + ] +} \ No newline at end of file