From 3dfd5ef0825ca13f14b1e9f92b654c52eed7c0c2 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Fri, 8 Apr 2022 13:46:33 -0700 Subject: [PATCH] add more test to improve coverage of abstract search action Signed-off-by: Yaliang Wu --- plugin/build.gradle | 1 - .../ml/rest/RestMLSearchModelActionTests.java | 150 +++++++++++++++++- .../org/opensearch/ml/utils/TestHelper.java | 14 +- 3 files changed, 162 insertions(+), 3 deletions(-) diff --git a/plugin/build.gradle b/plugin/build.gradle index 5e20abacdb..9cdb653c06 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -208,7 +208,6 @@ List jacocoExclusions = [ // TODO: add more unit test to meet the minimal test coverage. 'org.opensearch.ml.constant.CommonValue', 'org.opensearch.ml.plugin.MachineLearningPlugin*', - 'org.opensearch.ml.rest.AbstractMLSearchAction*', 'org.opensearch.ml.rest.RestMLExecuteAction' //0.3 ] diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java index 71f57fee61..51c49b446b 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java @@ -5,22 +5,104 @@ package org.opensearch.ml.rest; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX; +import static org.opensearch.ml.utils.TestHelper.getSearchAllRestRequest; + +import java.io.IOException; import java.util.List; +import org.apache.lucene.search.TotalHits; import org.hamcrest.Matchers; import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.client.node.NodeClient; import org.opensearch.common.Strings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; +import org.opensearch.rest.RestStatus; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; public class RestMLSearchModelActionTests extends OpenSearchTestCase { private RestMLSearchModelAction restMLSearchModelAction; + NodeClient client; + private ThreadPool threadPool; + @Mock + RestChannel channel; + @Before - public void setup() { + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); restMLSearchModelAction = new RestMLSearchModelAction(); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + + doReturn(builder).when(channel).newBuilder(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + + String modelContent = "{\"name\":\"FIT_RCF\",\"algorithm\":\"FIT_RCF\",\"version\":1,\"content\":\"xxx\"}"; + SearchHit model = SearchHit.fromXContent(TestHelper.parser(modelContent)); + SearchHits hits = new SearchHits(new SearchHit[] { model }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections( + hits, + InternalAggregations.EMPTY, + null, + false, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + actionListener.onResponse(searchResponse); + return null; + }).when(client).execute(eq(MLModelSearchAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); } public void testConstructor() { @@ -43,4 +125,70 @@ public void testRoutes() { assertThat(postRoute.getMethod(), Matchers.either(Matchers.is(RestRequest.Method.POST)).or(Matchers.is(RestRequest.Method.GET))); assertEquals("/_plugins/_ml/models/_search", postRoute.getPath()); } + + public void testPrepareRequest() throws Exception { + RestRequest request = getSearchAllRestRequest(); + restMLSearchModelAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), argumentCaptor.capture(), any()); + verify(channel, times(1)).sendResponse(responseCaptor.capture()); + SearchRequest searchRequest = argumentCaptor.getValue(); + String[] indices = searchRequest.indices(); + assertArrayEquals(new String[] { ML_MODEL_INDEX }, indices); + assertEquals( + "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"ui_metadata\"]}}", + searchRequest.source().toString() + ); + RestResponse restResponse = responseCaptor.getValue(); + assertNotEquals(RestStatus.REQUEST_TIMEOUT, restResponse.status()); + } + + public void testPrepareRequest_timeout() throws Exception { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + + SearchHits hits = new SearchHits(new SearchHit[0], null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections( + hits, + InternalAggregations.EMPTY, + null, + true, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + actionListener.onResponse(searchResponse); + return null; + }).when(client).execute(eq(MLModelSearchAction.INSTANCE), any(), any()); + + RestRequest request = getSearchAllRestRequest(); + restMLSearchModelAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(client, times(1)).execute(eq(MLModelSearchAction.INSTANCE), argumentCaptor.capture(), any()); + verify(channel, times(1)).sendResponse(responseCaptor.capture()); + SearchRequest searchRequest = argumentCaptor.getValue(); + String[] indices = searchRequest.indices(); + assertArrayEquals(new String[] { ML_MODEL_INDEX }, indices); + assertEquals( + "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"ui_metadata\"]}}", + searchRequest.source().toString() + ); + ; + RestResponse restResponse = responseCaptor.getValue(); + assertEquals(RestStatus.REQUEST_TIMEOUT, restResponse.status()); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index dc0885ee48..2344ac53bc 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -13,6 +13,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -163,6 +164,13 @@ public static RestRequest getKMeansRestRequest() { return request; } + public static RestRequest getSearchAllRestRequest() { + RestRequest request = new FakeRestRequest.Builder(getXContentRegistry()) + .withContent(new BytesArray(TestData.matchAllSearchQuery()), XContentType.JSON) + .build(); + return request; + } + public static void verifyParsedKMeansMLInput(MLInput mlInput) { assertEquals(FunctionName.KMEANS, mlInput.getAlgorithm()); assertEquals(MLInputDataType.SEARCH_QUERY, mlInput.getInputDataset().getInputDataType()); @@ -174,6 +182,10 @@ public static void verifyParsedKMeansMLInput(MLInput mlInput) { } private static NamedXContentRegistry getXContentRegistry() { - return new NamedXContentRegistry(Collections.singletonList(KMeansParams.XCONTENT_REGISTRY)); + SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList()); + List entries = new ArrayList<>(); + entries.addAll(searchModule.getNamedXContents()); + entries.add(KMeansParams.XCONTENT_REGISTRY); + return new NamedXContentRegistry(entries); } }