Skip to content

Commit

Permalink
feat: Add search index tool (#2356)
Browse files Browse the repository at this point in the history
* add search index tool

Signed-off-by: yuye-aws <[email protected]>

* add search index tool

Signed-off-by: yuye-aws <[email protected]>

* spotless apply

Signed-off-by: yuye-aws <[email protected]>

* add integration test for search index tool

Signed-off-by: yuye-aws <[email protected]>

* fix unit test error

Signed-off-by: yuye-aws <[email protected]>

* run spotless apply

Signed-off-by: yuye-aws <[email protected]>

* change json file path for search index tool it

Signed-off-by: yuye-aws <[email protected]>

* fix integration test for search index tool

Signed-off-by: yuye-aws <[email protected]>

* run spotless apply

Signed-off-by: yuye-aws <[email protected]>

* skip multiple node integration test for search index tool

Signed-off-by: yuye-aws <[email protected]>

* resolve code duplication in search index tool it

Signed-off-by: yuye-aws <[email protected]>

* change method access

Signed-off-by: yuye-aws <[email protected]>

* remove unused class

Signed-off-by: yuye-aws <[email protected]>

* remove wild card import

Signed-off-by: yuye-aws <[email protected]>

* add comment in gradle file

Signed-off-by: yuye-aws <[email protected]>

---------

Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
yuye-aws authored Apr 26, 2024
1 parent 05b0e5e commit ea7fefa
Show file tree
Hide file tree
Showing 8 changed files with 701 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.tools;

import static org.opensearch.ml.common.CommonValue.*;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

@Getter
@Setter
@Log4j2
@ToolAnnotation(SearchIndexTool.TYPE)
public class SearchIndexTool implements Tool {

public static final String INPUT_FIELD = "input";
public static final String INDEX_FIELD = "index";
public static final String QUERY_FIELD = "query";

public static final String TYPE = "SearchIndexTool";
private static final String DEFAULT_DESCRIPTION =
"Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query. Only use this tool when both index name and DSL query is available.";

private String name = TYPE;

private String description = DEFAULT_DESCRIPTION;

private Client client;

private NamedXContentRegistry xContentRegistry;

public SearchIndexTool(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public boolean validate(Map<String, String> parameters) {
return parameters != null && parameters.containsKey(INPUT_FIELD) && parameters.get(INPUT_FIELD) != null;
}

private SearchRequest getSearchRequest(String index, String query) throws IOException {
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query);
searchSourceBuilder.parseXContent(queryParser);
return new SearchRequest().source(searchSourceBuilder).indices(index);
}

private static Map<String, Object> processResponse(SearchHit hit) {
Map<String, Object> docContent = new HashMap<>();
docContent.put("_index", hit.getIndex());
docContent.put("_id", hit.getId());
docContent.put("_score", hit.getScore());
docContent.put("_source", hit.getSourceAsMap());
return docContent;
}

@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
try {
String input = parameters.get(INPUT_FIELD);
JsonObject jsonObject = StringUtils.gson.fromJson(input, JsonObject.class);
String index = Optional.ofNullable(jsonObject).map(x -> x.get(INDEX_FIELD)).map(JsonElement::getAsString).orElse(null);
String query = Optional.ofNullable(jsonObject).map(x -> x.get(QUERY_FIELD)).map(JsonElement::toString).orElse(null);
if (index == null || query == null) {
listener.onFailure(new IllegalArgumentException("SearchIndexTool's two parameter: index and query are required!"));
return;
}
SearchRequest searchRequest = getSearchRequest(index, query);

ActionListener<SearchResponse> actionListener = ActionListener.<SearchResponse>wrap(r -> {
SearchHit[] hits = r.getHits().getHits();

if (hits != null && hits.length > 0) {
StringBuilder contextBuilder = new StringBuilder();
for (SearchHit hit : hits) {
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> {
Map<String, Object> docContent = processResponse(hit);
return StringUtils.gson.toJson(docContent);
});
contextBuilder.append(doc).append("\n");
}
listener.onResponse((T) contextBuilder.toString());
} else {
listener.onResponse((T) "");
}
}, e -> {
log.error("Failed to search index", e);
listener.onFailure(e);
});

// since searching connector and model needs access control, we need
// to forward the request corresponding transport action
if (Objects.equals(index, ML_CONNECTOR_INDEX)) {
client.execute(MLConnectorSearchAction.INSTANCE, searchRequest, actionListener);
} else if (Objects.equals(index, ML_MODEL_INDEX)) {
client.execute(MLModelSearchAction.INSTANCE, searchRequest, actionListener);
} else if (Objects.equals(index, ML_MODEL_GROUP_INDEX)) {
client.execute(MLModelGroupSearchAction.INSTANCE, searchRequest, actionListener);
} else {
client.search(searchRequest, actionListener);
}
} catch (Exception e) {
log.error("Failed to search index", e);
listener.onFailure(e);
}
}

public static class Factory implements Tool.Factory<SearchIndexTool> {

private Client client;
private static Factory INSTANCE;

private NamedXContentRegistry xContentRegistry;

/**
* Create or return the singleton factory instance
*/
public static Factory getInstance() {
if (INSTANCE != null) {
return INSTANCE;
}
synchronized (SearchIndexTool.class) {
if (INSTANCE != null) {
return INSTANCE;
}
INSTANCE = new Factory();
return INSTANCE;
}
}

public void init(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
}

@Override
public SearchIndexTool create(Map<String, Object> params) {
return new SearchIndexTool(client, xContentRegistry);
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}

@Override
public String getDefaultType() {
return TYPE;
}

@Override
public String getDefaultVersion() {
return null;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.tools;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;

import java.io.InputStream;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.search.SearchModule;

import lombok.SneakyThrows;

public class SearchIndexToolTests {
static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry(
new SearchModule(Settings.EMPTY, List.of()).getNamedXContents()
);

private Client client;

private SearchIndexTool mockedSearchIndexTool;

private String mockedSearchResponseString;

@Before
@SneakyThrows
public void setup() {
client = mock(Client.class);
mockedSearchIndexTool = mock(
SearchIndexTool.class,
Mockito.withSettings().useConstructor(client, TEST_XCONTENT_REGISTRY_FOR_QUERY).defaultAnswer(Mockito.CALLS_REAL_METHODS)
);

try (InputStream searchResponseIns = SearchIndexTool.class.getResourceAsStream("retrieval_tool_search_response.json")) {
if (searchResponseIns != null) {
mockedSearchResponseString = new String(searchResponseIns.readAllBytes());
}
}
}

@Test
@SneakyThrows
public void testGetType() {
String type = mockedSearchIndexTool.getType();
assertFalse(Strings.isNullOrEmpty(type));
assertEquals("SearchIndexTool", type);
}

@Test
@SneakyThrows
public void testValidate() {
Map<String, String> parameters = Map.of("input", "{}");
assertTrue(mockedSearchIndexTool.validate(parameters));
}

@Test
@SneakyThrows
public void testValidateWithEmptyInput() {
Map<String, String> parameters = Map.of();
assertFalse(mockedSearchIndexTool.validate(parameters));
}

@Test
public void testRunWithNormalIndex() {
String inputString = "{\"index\": \"test-index\", \"query\": {\"query\": {\"match_all\": {}}}}";
Map<String, String> parameters = Map.of("input", inputString);
mockedSearchIndexTool.run(parameters, null);
Mockito.verify(client, times(1)).search(any(), any());
Mockito.verify(client, Mockito.never()).execute(any(), any(), any());
}

@Test
public void testRunWithConnectorIndex() {
String inputString = "{\"index\": \".plugins-ml-connector\", \"query\": {\"query\": {\"match_all\": {}}}}";
Map<String, String> parameters = Map.of("input", inputString);
mockedSearchIndexTool.run(parameters, null);
Mockito.verify(client, never()).search(any(), any());
Mockito.verify(client, times(1)).execute(eq(MLConnectorSearchAction.INSTANCE), any(), any());
}

@Test
public void testRunWithModelIndex() {
String inputString = "{\"index\": \".plugins-ml-model\", \"query\": {\"query\": {\"match_all\": {}}}}";
Map<String, String> parameters = Map.of("input", inputString);
mockedSearchIndexTool.run(parameters, null);
Mockito.verify(client, never()).search(any(), any());
Mockito.verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), any(), any());
}

@Test
public void testRunWithModelGroupIndex() {
String inputString = "{\"index\": \".plugins-ml-model-group\", \"query\": {\"query\": {\"match_all\": {}}}}";
Map<String, String> parameters = Map.of("input", inputString);
mockedSearchIndexTool.run(parameters, null);
Mockito.verify(client, never()).search(any(), any());
Mockito.verify(client, times(1)).execute(eq(MLModelGroupSearchAction.INSTANCE), any(), any());
}

@Test
@SneakyThrows
public void testRunWithSearchResults() {
SearchResponse mockedSearchResponse = SearchResponse
.fromXContent(
JsonXContent.jsonXContent
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedSearchResponseString)
);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(mockedSearchResponse);
return null;
}).when(client).search(any(), any());

String inputString = "{\"index\": \"test-index\", \"query\": {\"query\": {\"match_all\": {}}}}";
final CompletableFuture<String> future = new CompletableFuture<>();
ActionListener<String> listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); });
Map<String, String> parameters = Map.of("input", inputString);
mockedSearchIndexTool.run(parameters, listener);

future.join();

Mockito.verify(client, times(1)).search(any(), any());
Mockito.verify(client, Mockito.never()).execute(any(), any(), any());
}

@Test
@SneakyThrows
public void testRunWithEmptyQuery() {
String inputString = "{\"index\": \"test_index\"}";
Map<String, String> parameters = Map.of("input", inputString);
ActionListener<String> listener = mock(ActionListener.class);
mockedSearchIndexTool.run(parameters, listener);
Mockito.verify(client, Mockito.never()).execute(any(), any(), any());
Mockito.verify(client, Mockito.never()).search(any(), any());
}

@Test
public void testRunWithInvalidQuery() {
String inputString = "{\"index\": \"test-index\", \"query\": \"invalid query\"}";
Map<String, String> parameters = Map.of("input", inputString);
ActionListener<String> listener = mock(ActionListener.class);
mockedSearchIndexTool.run(parameters, listener);
Mockito.verify(client, Mockito.never()).execute(any(), any(), any());
Mockito.verify(client, Mockito.never()).search(any(), any());
}

@Test
public void testRunWithEmptyQueryBody() {
String inputString = "{\"index\": \"test-index\", \"query\": {}}";
Map<String, String> parameters = Map.of("input", inputString);
mockedSearchIndexTool.run(parameters, null);
Mockito.verify(client, times(1)).search(any(), any());
Mockito.verify(client, Mockito.never()).execute(any(), any(), any());
}

@Test
public void testFactory() {
SearchIndexTool searchIndexTool = SearchIndexTool.Factory.getInstance().create(Collections.emptyMap());
assertEquals(SearchIndexTool.TYPE, searchIndexTool.getType());
}
}
Loading

0 comments on commit ea7fefa

Please sign in to comment.