diff --git a/plugin/build.gradle b/plugin/build.gradle index 99ea12cb65..dad5d3470f 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -206,7 +206,6 @@ jacocoTestReport { List jacocoExclusions = [ // TODO: add more unit test to meet the minimal test coverage. - 'org.opensearch.ml.action.handler.*', 'org.opensearch.ml.constant.CommonValue', 'org.opensearch.ml.plugin.*', 'org.opensearch.ml.task.MLPredictTaskRunner', diff --git a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java index 35cc12218e..f3cd4f2a5e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java @@ -84,16 +84,10 @@ public static ActionListener wrapRestActionListener(ActionListener act } public static boolean isProperExceptionToReturn(Throwable e) { - if (e == null) { - return false; - } return e instanceof OpenSearchStatusException || e instanceof IndexNotFoundException || e instanceof InvalidIndexNameException; } public static boolean isBadRequest(Throwable e) { - if (e == null) { - return false; - } return e instanceof IllegalArgumentException || e instanceof MLResourceNotFoundException; } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index 3456e1a78b..a8417341b9 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -6,7 +6,10 @@ package org.opensearch.ml.action.models; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.io.IOException; diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java index c617a36058..0fae7a9e52 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java @@ -6,7 +6,10 @@ package org.opensearch.ml.action.models; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.io.IOException; diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java index 064c449f25..589512cf41 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java @@ -5,20 +5,31 @@ package org.opensearch.ml.action.models; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import org.junit.Before; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.action.handler.MLSearchHandler; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.rest.RestStatus; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; public class SearchModelTransportActionTests extends OpenSearchTestCase { @@ -40,18 +51,95 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase { @Mock ActionListener actionListener; + @Mock + ThreadPool threadPool; + MLSearchHandler mlSearchHandler; SearchModelTransportAction searchModelTransportAction; + ThreadContext threadContext; @Before public void setup() { MockitoAnnotations.openMocks(this); mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry)); searchModelTransportAction = new SearchModelTransportAction(transportService, actionFilters, mlSearchHandler); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); } public void test_DoExecute() { searchModelTransportAction.doExecute(null, searchRequest, actionListener); verify(mlSearchHandler).search(searchRequest, actionListener); + verify(client).search(any(), any()); + } + + public void test_IndexNotFoundException() { + setupSearchMocks(new IndexNotFoundException("index not found")); + + searchModelTransportAction.doExecute(null, searchRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, actionListener); + verify(client).search(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals(IndexNotFoundException.class, argumentCaptor.getValue().getClass()); + } + + public void test_IllegalArgumentException() { + setupSearchMocks(new IllegalArgumentException("illegal arguments")); + + searchModelTransportAction.doExecute(null, searchRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, actionListener); + verify(client).search(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals(OpenSearchStatusException.class, argumentCaptor.getValue().getClass()); + } + + public void test_OpenSearchStatusException() { + setupSearchMocks(new OpenSearchStatusException("test error", RestStatus.CONFLICT, "args")); + + searchModelTransportAction.doExecute(null, searchRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, actionListener); + verify(client).search(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals(OpenSearchStatusException.class, argumentCaptor.getValue().getClass()); + } + + public void test_CauseByMLException() { + Exception exception = new Exception(); + exception.initCause(new MLException("ml exception")); + setupSearchMocks(exception); + + searchModelTransportAction.doExecute(null, searchRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, actionListener); + verify(client).search(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals(OpenSearchStatusException.class, argumentCaptor.getValue().getClass()); + } + + public void test_CauseByInvalidIndexNameException() { + Exception exception = new Exception(); + exception.initCause(new IndexNotFoundException("Index not Found")); + setupSearchMocks(exception); + + searchModelTransportAction.doExecute(null, searchRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, actionListener); + verify(client).search(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals(IndexNotFoundException.class, argumentCaptor.getValue().getClass()); + } + + private void setupSearchMocks(Exception exception) { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(exception); + return null; + }).when(client).search(any(), any()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/DeleteTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/DeleteTaskTransportActionTests.java index 3d03435be7..d82876bc70 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/DeleteTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/DeleteTaskTransportActionTests.java @@ -6,8 +6,10 @@ package org.opensearch.ml.action.tasks; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.io.IOException; diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java index f4b0f4508d..8408ede99f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java @@ -6,8 +6,10 @@ package org.opensearch.ml.action.tasks; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.io.IOException;