From f200384888a28a8fa1b298c789400d5d9a176b0b Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Thu, 31 Mar 2022 13:44:49 -0700 Subject: [PATCH] add more UT coverage for model/task transport actions (#268) Signed-off-by: Xun Zhang --- plugin/build.gradle | 2 +- .../DeleteModelTransportActionTests.java | 100 ++++++++++++++++ .../ml/action/models/GetModelITTests.java | 40 +++++++ .../models/GetModelTransportActionTests.java | 94 +++++++++++++++ .../SearchModelTransportActionTests.java | 57 +++++++++ .../tasks/DeleteTaskTransportActionTests.java | 101 ++++++++++++++++ .../tasks/GetTaskTransportActionTests.java | 108 ++++++++++++++++++ .../tasks/SearchTaskTransportActionTests.java | 57 +++++++++ 8 files changed, 558 insertions(+), 1 deletion(-) create mode 100644 plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/models/GetModelITTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/tasks/DeleteTaskTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java diff --git a/plugin/build.gradle b/plugin/build.gradle index 79a4c6191c..9f218635af 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -204,7 +204,7 @@ jacocoTestReport { List jacocoExclusions = [ // TODO: add more unit test to meet the minimal test coverage. - 'org.opensearch.ml.action.*', + 'org.opensearch.ml.action.handler.*', 'org.opensearch.ml.constant.CommonValue', 'org.opensearch.ml.indices.MLInputDatasetHandler', 'org.opensearch.ml.plugin.*', diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java new file mode 100644 index 0000000000..3456e1a78b --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.models; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class DeleteModelTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + DeleteResponse deleteResponse; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + DeleteModelTransportAction deleteModelTransportAction; + MLModelDeleteRequest mlModelDeleteRequest; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId("test_id").build(); + deleteModelTransportAction = spy(new DeleteModelTransportAction(transportService, actionFilters, client)); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + public void testDeleteModel_Success() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + verify(actionListener).onResponse(deleteResponse); + } + + public void testDeleteModel_RuntimeException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("errorMessage")); + return null; + }).when(client).delete(any(), any()); + + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + } + + public void testDeleteModel_ThreadContextError() { + when(threadPool.getThreadContext()).thenThrow(new RuntimeException("thread context error")); + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("thread context error", argumentCaptor.getValue().getMessage()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelITTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelITTests.java new file mode 100644 index 0000000000..d7320ce42f --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelITTests.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.models; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ml.action.MLCommonsIntegTestCase; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.test.OpenSearchIntegTestCase; + +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 2) +public class GetModelITTests extends MLCommonsIntegTestCase { + private String irisIndexName; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() throws Exception { + super.setUp(); + irisIndexName = "iris_data_for_model_it"; + loadIrisData(irisIndexName); + } + + public void testGetModel_IndexNotFound() { + exceptionRule.expect(MLResourceNotFoundException.class); + MLModel model = getModel("test_id"); + } + + public void testGetModel_NullModelIdException() { + exceptionRule.expect(ActionRequestValidationException.class); + MLModel model = getModel(null); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java new file mode 100644 index 0000000000..c617a36058 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java @@ -0,0 +1,94 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.models; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.model.MLModelGetRequest; +import org.opensearch.ml.common.transport.model.MLModelGetResponse; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class GetModelTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + GetModelTransportAction getModelTransportAction; + MLModelGetRequest mlModelGetRequest; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + mlModelGetRequest = MLModelGetRequest.builder().modelId("test_id").build(); + + getModelTransportAction = spy(new GetModelTransportAction(transportService, actionFilters, client, xContentRegistry)); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + public void testGetModel_NullResponse() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(), any()); + getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Fail to find model", argumentCaptor.getValue().getMessage()); + } + + public void testGetModel_RuntimeException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("errorMessage")); + return null; + }).when(client).get(any(), any()); + getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java new file mode 100644 index 0000000000..064c449f25 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.models; + +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.ml.action.handler.MLSearchHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.TransportService; + +public class SearchModelTransportActionTests extends OpenSearchTestCase { + @Mock + Client client; + + @Mock + NamedXContentRegistry namedXContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + SearchRequest searchRequest; + + @Mock + ActionListener actionListener; + + MLSearchHandler mlSearchHandler; + SearchModelTransportAction searchModelTransportAction; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry)); + searchModelTransportAction = new SearchModelTransportAction(transportService, actionFilters, mlSearchHandler); + } + + public void test_DoExecute() { + searchModelTransportAction.doExecute(null, searchRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, actionListener); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/DeleteTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/DeleteTaskTransportActionTests.java new file mode 100644 index 0000000000..3d03435be7 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/DeleteTaskTransportActionTests.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.tasks; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.verify; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class DeleteTaskTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + DeleteResponse deleteResponse; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + DeleteTaskTransportAction deleteTaskTransportAction; + MLTaskDeleteRequest mlTaskDeleteRequest; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + mlTaskDeleteRequest = MLTaskDeleteRequest.builder().taskId("test_id").build(); + deleteTaskTransportAction = spy(new DeleteTaskTransportAction(transportService, actionFilters, client)); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + public void testDeleteModel_Success() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + deleteTaskTransportAction.doExecute(null, mlTaskDeleteRequest, actionListener); + verify(actionListener).onResponse(deleteResponse); + } + + public void testDeleteModel_RuntimeException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("errorMessage")); + return null; + }).when(client).delete(any(), any()); + + deleteTaskTransportAction.doExecute(null, mlTaskDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + } + + public void testDeleteModel_ThreadContextError() { + when(threadPool.getThreadContext()).thenThrow(new RuntimeException("thread context error")); + deleteTaskTransportAction.doExecute(null, mlTaskDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("thread context error", argumentCaptor.getValue().getMessage()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java new file mode 100644 index 0000000000..f4b0f4508d --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java @@ -0,0 +1,108 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.tasks; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.verify; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.transport.task.MLTaskGetRequest; +import org.opensearch.ml.common.transport.task.MLTaskGetResponse; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class GetTaskTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + GetTaskTransportAction getTaskTransportAction; + MLTaskGetRequest mlTaskGetRequest; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + mlTaskGetRequest = MLTaskGetRequest.builder().taskId("test_id").build(); + + getTaskTransportAction = spy(new GetTaskTransportAction(transportService, actionFilters, client, xContentRegistry)); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + public void testGetTask_NullResponse() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(), any()); + getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Fail to find task", argumentCaptor.getValue().getMessage()); + } + + public void testGetTask_RuntimeException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("errorMessage")); + return null; + }).when(client).get(any(), any()); + getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + } + + public void testGetTask_IndexNotFoundException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IndexNotFoundException("Index Not Found")); + return null; + }).when(client).get(any(), any()); + getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Fail to find task", argumentCaptor.getValue().getMessage()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java new file mode 100644 index 0000000000..34732bd17c --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.tasks; + +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.ml.action.handler.MLSearchHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.TransportService; + +public class SearchTaskTransportActionTests extends OpenSearchTestCase { + @Mock + Client client; + + @Mock + NamedXContentRegistry namedXContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + SearchRequest searchRequest; + + @Mock + ActionListener actionListener; + + MLSearchHandler mlSearchHandler; + SearchTaskTransportAction searchTaskTransportAction; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry)); + searchTaskTransportAction = new SearchTaskTransportAction(transportService, actionFilters, mlSearchHandler); + } + + public void test_DoExecute() { + searchTaskTransportAction.doExecute(null, searchRequest, actionListener); + verify(mlSearchHandler).search(searchRequest, actionListener); + } +}