-
Notifications
You must be signed in to change notification settings - Fork 138
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add search index tool Signed-off-by: yuye-aws <[email protected]> * add search index tool Signed-off-by: yuye-aws <[email protected]> * spotless apply Signed-off-by: yuye-aws <[email protected]> * add integration test for search index tool Signed-off-by: yuye-aws <[email protected]> * fix unit test error Signed-off-by: yuye-aws <[email protected]> * run spotless apply Signed-off-by: yuye-aws <[email protected]> * change json file path for search index tool it Signed-off-by: yuye-aws <[email protected]> * fix integration test for search index tool Signed-off-by: yuye-aws <[email protected]> * run spotless apply Signed-off-by: yuye-aws <[email protected]> * skip multiple node integration test for search index tool Signed-off-by: yuye-aws <[email protected]> * resolve code duplication in search index tool it Signed-off-by: yuye-aws <[email protected]> * change method access Signed-off-by: yuye-aws <[email protected]> * remove unused class Signed-off-by: yuye-aws <[email protected]> * remove wild card import Signed-off-by: yuye-aws <[email protected]> * add comment in gradle file Signed-off-by: yuye-aws <[email protected]> --------- Signed-off-by: yuye-aws <[email protected]>
- Loading branch information
Showing
8 changed files
with
701 additions
and
0 deletions.
There are no files selected for viewing
199 changes: 199 additions & 0 deletions
199
ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.engine.tools; | ||
|
||
import static org.opensearch.ml.common.CommonValue.*; | ||
|
||
import java.io.IOException; | ||
import java.security.AccessController; | ||
import java.security.PrivilegedExceptionAction; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import java.util.Optional; | ||
|
||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.xcontent.LoggingDeprecationHandler; | ||
import org.opensearch.common.xcontent.XContentType; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.spi.tools.Tool; | ||
import org.opensearch.ml.common.spi.tools.ToolAnnotation; | ||
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; | ||
import org.opensearch.ml.common.transport.model.MLModelSearchAction; | ||
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; | ||
import org.opensearch.ml.common.utils.StringUtils; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
|
||
import com.google.gson.JsonElement; | ||
import com.google.gson.JsonObject; | ||
|
||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Getter | ||
@Setter | ||
@Log4j2 | ||
@ToolAnnotation(SearchIndexTool.TYPE) | ||
public class SearchIndexTool implements Tool { | ||
|
||
public static final String INPUT_FIELD = "input"; | ||
public static final String INDEX_FIELD = "index"; | ||
public static final String QUERY_FIELD = "query"; | ||
|
||
public static final String TYPE = "SearchIndexTool"; | ||
private static final String DEFAULT_DESCRIPTION = | ||
"Use this tool to search an index by providing two parameters: 'index' for the index name, and 'query' for the OpenSearch DSL formatted query. Only use this tool when both index name and DSL query is available."; | ||
|
||
private String name = TYPE; | ||
|
||
private String description = DEFAULT_DESCRIPTION; | ||
|
||
private Client client; | ||
|
||
private NamedXContentRegistry xContentRegistry; | ||
|
||
public SearchIndexTool(Client client, NamedXContentRegistry xContentRegistry) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
} | ||
|
||
@Override | ||
public String getType() { | ||
return TYPE; | ||
} | ||
|
||
@Override | ||
public String getVersion() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public boolean validate(Map<String, String> parameters) { | ||
return parameters != null && parameters.containsKey(INPUT_FIELD) && parameters.get(INPUT_FIELD) != null; | ||
} | ||
|
||
private SearchRequest getSearchRequest(String index, String query) throws IOException { | ||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); | ||
XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); | ||
searchSourceBuilder.parseXContent(queryParser); | ||
return new SearchRequest().source(searchSourceBuilder).indices(index); | ||
} | ||
|
||
private static Map<String, Object> processResponse(SearchHit hit) { | ||
Map<String, Object> docContent = new HashMap<>(); | ||
docContent.put("_index", hit.getIndex()); | ||
docContent.put("_id", hit.getId()); | ||
docContent.put("_score", hit.getScore()); | ||
docContent.put("_source", hit.getSourceAsMap()); | ||
return docContent; | ||
} | ||
|
||
@Override | ||
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { | ||
try { | ||
String input = parameters.get(INPUT_FIELD); | ||
JsonObject jsonObject = StringUtils.gson.fromJson(input, JsonObject.class); | ||
String index = Optional.ofNullable(jsonObject).map(x -> x.get(INDEX_FIELD)).map(JsonElement::getAsString).orElse(null); | ||
String query = Optional.ofNullable(jsonObject).map(x -> x.get(QUERY_FIELD)).map(JsonElement::toString).orElse(null); | ||
if (index == null || query == null) { | ||
listener.onFailure(new IllegalArgumentException("SearchIndexTool's two parameter: index and query are required!")); | ||
return; | ||
} | ||
SearchRequest searchRequest = getSearchRequest(index, query); | ||
|
||
ActionListener<SearchResponse> actionListener = ActionListener.<SearchResponse>wrap(r -> { | ||
SearchHit[] hits = r.getHits().getHits(); | ||
|
||
if (hits != null && hits.length > 0) { | ||
StringBuilder contextBuilder = new StringBuilder(); | ||
for (SearchHit hit : hits) { | ||
String doc = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> { | ||
Map<String, Object> docContent = processResponse(hit); | ||
return StringUtils.gson.toJson(docContent); | ||
}); | ||
contextBuilder.append(doc).append("\n"); | ||
} | ||
listener.onResponse((T) contextBuilder.toString()); | ||
} else { | ||
listener.onResponse((T) ""); | ||
} | ||
}, e -> { | ||
log.error("Failed to search index", e); | ||
listener.onFailure(e); | ||
}); | ||
|
||
// since searching connector and model needs access control, we need | ||
// to forward the request corresponding transport action | ||
if (Objects.equals(index, ML_CONNECTOR_INDEX)) { | ||
client.execute(MLConnectorSearchAction.INSTANCE, searchRequest, actionListener); | ||
} else if (Objects.equals(index, ML_MODEL_INDEX)) { | ||
client.execute(MLModelSearchAction.INSTANCE, searchRequest, actionListener); | ||
} else if (Objects.equals(index, ML_MODEL_GROUP_INDEX)) { | ||
client.execute(MLModelGroupSearchAction.INSTANCE, searchRequest, actionListener); | ||
} else { | ||
client.search(searchRequest, actionListener); | ||
} | ||
} catch (Exception e) { | ||
log.error("Failed to search index", e); | ||
listener.onFailure(e); | ||
} | ||
} | ||
|
||
public static class Factory implements Tool.Factory<SearchIndexTool> { | ||
|
||
private Client client; | ||
private static Factory INSTANCE; | ||
|
||
private NamedXContentRegistry xContentRegistry; | ||
|
||
/** | ||
* Create or return the singleton factory instance | ||
*/ | ||
public static Factory getInstance() { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
synchronized (SearchIndexTool.class) { | ||
if (INSTANCE != null) { | ||
return INSTANCE; | ||
} | ||
INSTANCE = new Factory(); | ||
return INSTANCE; | ||
} | ||
} | ||
|
||
public void init(Client client, NamedXContentRegistry xContentRegistry) { | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
} | ||
|
||
@Override | ||
public SearchIndexTool create(Map<String, Object> params) { | ||
return new SearchIndexTool(client, xContentRegistry); | ||
} | ||
|
||
@Override | ||
public String getDefaultDescription() { | ||
return DEFAULT_DESCRIPTION; | ||
} | ||
|
||
@Override | ||
public String getDefaultType() { | ||
return TYPE; | ||
} | ||
|
||
@Override | ||
public String getDefaultVersion() { | ||
return null; | ||
} | ||
} | ||
} |
188 changes: 188 additions & 0 deletions
188
ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.engine.tools; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
import static org.junit.Assert.assertFalse; | ||
import static org.junit.Assert.assertTrue; | ||
import static org.mockito.ArgumentMatchers.any; | ||
import static org.mockito.Mockito.doAnswer; | ||
import static org.mockito.Mockito.eq; | ||
import static org.mockito.Mockito.mock; | ||
import static org.mockito.Mockito.never; | ||
import static org.mockito.Mockito.times; | ||
|
||
import java.io.InputStream; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.concurrent.CompletableFuture; | ||
|
||
import org.junit.Before; | ||
import org.junit.Test; | ||
import org.mockito.Mockito; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.settings.Settings; | ||
import org.opensearch.common.xcontent.json.JsonXContent; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.common.Strings; | ||
import org.opensearch.core.xcontent.DeprecationHandler; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; | ||
import org.opensearch.ml.common.transport.model.MLModelSearchAction; | ||
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; | ||
import org.opensearch.search.SearchModule; | ||
|
||
import lombok.SneakyThrows; | ||
|
||
public class SearchIndexToolTests { | ||
static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( | ||
new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() | ||
); | ||
|
||
private Client client; | ||
|
||
private SearchIndexTool mockedSearchIndexTool; | ||
|
||
private String mockedSearchResponseString; | ||
|
||
@Before | ||
@SneakyThrows | ||
public void setup() { | ||
client = mock(Client.class); | ||
mockedSearchIndexTool = mock( | ||
SearchIndexTool.class, | ||
Mockito.withSettings().useConstructor(client, TEST_XCONTENT_REGISTRY_FOR_QUERY).defaultAnswer(Mockito.CALLS_REAL_METHODS) | ||
); | ||
|
||
try (InputStream searchResponseIns = SearchIndexTool.class.getResourceAsStream("retrieval_tool_search_response.json")) { | ||
if (searchResponseIns != null) { | ||
mockedSearchResponseString = new String(searchResponseIns.readAllBytes()); | ||
} | ||
} | ||
} | ||
|
||
@Test | ||
@SneakyThrows | ||
public void testGetType() { | ||
String type = mockedSearchIndexTool.getType(); | ||
assertFalse(Strings.isNullOrEmpty(type)); | ||
assertEquals("SearchIndexTool", type); | ||
} | ||
|
||
@Test | ||
@SneakyThrows | ||
public void testValidate() { | ||
Map<String, String> parameters = Map.of("input", "{}"); | ||
assertTrue(mockedSearchIndexTool.validate(parameters)); | ||
} | ||
|
||
@Test | ||
@SneakyThrows | ||
public void testValidateWithEmptyInput() { | ||
Map<String, String> parameters = Map.of(); | ||
assertFalse(mockedSearchIndexTool.validate(parameters)); | ||
} | ||
|
||
@Test | ||
public void testRunWithNormalIndex() { | ||
String inputString = "{\"index\": \"test-index\", \"query\": {\"query\": {\"match_all\": {}}}}"; | ||
Map<String, String> parameters = Map.of("input", inputString); | ||
mockedSearchIndexTool.run(parameters, null); | ||
Mockito.verify(client, times(1)).search(any(), any()); | ||
Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); | ||
} | ||
|
||
@Test | ||
public void testRunWithConnectorIndex() { | ||
String inputString = "{\"index\": \".plugins-ml-connector\", \"query\": {\"query\": {\"match_all\": {}}}}"; | ||
Map<String, String> parameters = Map.of("input", inputString); | ||
mockedSearchIndexTool.run(parameters, null); | ||
Mockito.verify(client, never()).search(any(), any()); | ||
Mockito.verify(client, times(1)).execute(eq(MLConnectorSearchAction.INSTANCE), any(), any()); | ||
} | ||
|
||
@Test | ||
public void testRunWithModelIndex() { | ||
String inputString = "{\"index\": \".plugins-ml-model\", \"query\": {\"query\": {\"match_all\": {}}}}"; | ||
Map<String, String> parameters = Map.of("input", inputString); | ||
mockedSearchIndexTool.run(parameters, null); | ||
Mockito.verify(client, never()).search(any(), any()); | ||
Mockito.verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), any(), any()); | ||
} | ||
|
||
@Test | ||
public void testRunWithModelGroupIndex() { | ||
String inputString = "{\"index\": \".plugins-ml-model-group\", \"query\": {\"query\": {\"match_all\": {}}}}"; | ||
Map<String, String> parameters = Map.of("input", inputString); | ||
mockedSearchIndexTool.run(parameters, null); | ||
Mockito.verify(client, never()).search(any(), any()); | ||
Mockito.verify(client, times(1)).execute(eq(MLModelGroupSearchAction.INSTANCE), any(), any()); | ||
} | ||
|
||
@Test | ||
@SneakyThrows | ||
public void testRunWithSearchResults() { | ||
SearchResponse mockedSearchResponse = SearchResponse | ||
.fromXContent( | ||
JsonXContent.jsonXContent | ||
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedSearchResponseString) | ||
); | ||
doAnswer(invocation -> { | ||
ActionListener<SearchResponse> listener = invocation.getArgument(1); | ||
listener.onResponse(mockedSearchResponse); | ||
return null; | ||
}).when(client).search(any(), any()); | ||
|
||
String inputString = "{\"index\": \"test-index\", \"query\": {\"query\": {\"match_all\": {}}}}"; | ||
final CompletableFuture<String> future = new CompletableFuture<>(); | ||
ActionListener<String> listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); | ||
Map<String, String> parameters = Map.of("input", inputString); | ||
mockedSearchIndexTool.run(parameters, listener); | ||
|
||
future.join(); | ||
|
||
Mockito.verify(client, times(1)).search(any(), any()); | ||
Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); | ||
} | ||
|
||
@Test | ||
@SneakyThrows | ||
public void testRunWithEmptyQuery() { | ||
String inputString = "{\"index\": \"test_index\"}"; | ||
Map<String, String> parameters = Map.of("input", inputString); | ||
ActionListener<String> listener = mock(ActionListener.class); | ||
mockedSearchIndexTool.run(parameters, listener); | ||
Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); | ||
Mockito.verify(client, Mockito.never()).search(any(), any()); | ||
} | ||
|
||
@Test | ||
public void testRunWithInvalidQuery() { | ||
String inputString = "{\"index\": \"test-index\", \"query\": \"invalid query\"}"; | ||
Map<String, String> parameters = Map.of("input", inputString); | ||
ActionListener<String> listener = mock(ActionListener.class); | ||
mockedSearchIndexTool.run(parameters, listener); | ||
Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); | ||
Mockito.verify(client, Mockito.never()).search(any(), any()); | ||
} | ||
|
||
@Test | ||
public void testRunWithEmptyQueryBody() { | ||
String inputString = "{\"index\": \"test-index\", \"query\": {}}"; | ||
Map<String, String> parameters = Map.of("input", inputString); | ||
mockedSearchIndexTool.run(parameters, null); | ||
Mockito.verify(client, times(1)).search(any(), any()); | ||
Mockito.verify(client, Mockito.never()).execute(any(), any(), any()); | ||
} | ||
|
||
@Test | ||
public void testFactory() { | ||
SearchIndexTool searchIndexTool = SearchIndexTool.Factory.getInstance().create(Collections.emptyMap()); | ||
assertEquals(SearchIndexTool.TYPE, searchIndexTool.getType()); | ||
} | ||
} |
Oops, something went wrong.