Skip to content

Commit

Permalink
add more UT coverage for model/task transport actions (opensearch-pro…
Browse files Browse the repository at this point in the history
…ject#268)

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored and ylwu-amzn committed Apr 22, 2022
1 parent 415a171 commit f200384
Show file tree
Hide file tree
Showing 8 changed files with 558 additions and 1 deletion.
2 changes: 1 addition & 1 deletion plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ jacocoTestReport {

List<String> jacocoExclusions = [
// TODO: add more unit test to meet the minimal test coverage.
'org.opensearch.ml.action.*',
'org.opensearch.ml.action.handler.*',
'org.opensearch.ml.constant.CommonValue',
'org.opensearch.ml.indices.MLInputDatasetHandler',
'org.opensearch.ml.plugin.*',
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.models;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

import java.io.IOException;

import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

public class DeleteModelTransportActionTests extends OpenSearchTestCase {
@Mock
ThreadPool threadPool;

@Mock
Client client;

@Mock
TransportService transportService;

@Mock
ActionFilters actionFilters;

@Mock
ActionListener<DeleteResponse> actionListener;

@Mock
DeleteResponse deleteResponse;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

DeleteModelTransportAction deleteModelTransportAction;
MLModelDeleteRequest mlModelDeleteRequest;
ThreadContext threadContext;

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);

mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId("test_id").build();
deleteModelTransportAction = spy(new DeleteModelTransportAction(transportService, actionFilters, client));

Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

public void testDeleteModel_Success() {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
verify(actionListener).onResponse(deleteResponse);
}

public void testDeleteModel_RuntimeException() {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException("errorMessage"));
return null;
}).when(client).delete(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("errorMessage", argumentCaptor.getValue().getMessage());
}

public void testDeleteModel_ThreadContextError() {
when(threadPool.getThreadContext()).thenThrow(new RuntimeException("thread context error"));
deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("thread context error", argumentCaptor.getValue().getMessage());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.models;

import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.ml.action.MLCommonsIntegTestCase;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.test.OpenSearchIntegTestCase;

@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 2)
public class GetModelITTests extends MLCommonsIntegTestCase {
private String irisIndexName;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

@Before
public void setUp() throws Exception {
super.setUp();
irisIndexName = "iris_data_for_model_it";
loadIrisData(irisIndexName);
}

public void testGetModel_IndexNotFound() {
exceptionRule.expect(MLResourceNotFoundException.class);
MLModel model = getModel("test_id");
}

public void testGetModel_NullModelIdException() {
exceptionRule.expect(ActionRequestValidationException.class);
MLModel model = getModel(null);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.models;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

import java.io.IOException;

import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

public class GetModelTransportActionTests extends OpenSearchTestCase {
@Mock
ThreadPool threadPool;

@Mock
Client client;

@Mock
NamedXContentRegistry xContentRegistry;

@Mock
TransportService transportService;

@Mock
ActionFilters actionFilters;

@Mock
ActionListener<MLModelGetResponse> actionListener;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

GetModelTransportAction getModelTransportAction;
MLModelGetRequest mlModelGetRequest;
ThreadContext threadContext;

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
mlModelGetRequest = MLModelGetRequest.builder().modelId("test_id").build();

getModelTransportAction = spy(new GetModelTransportAction(transportService, actionFilters, client, xContentRegistry));

Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

public void testGetModel_NullResponse() {
doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onResponse(null);
return null;
}).when(client).get(any(), any());
getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Fail to find model", argumentCaptor.getValue().getMessage());
}

public void testGetModel_RuntimeException() {
doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException("errorMessage"));
return null;
}).when(client).get(any(), any());
getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("errorMessage", argumentCaptor.getValue().getMessage());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.models;

import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;

import org.junit.Before;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.ml.action.handler.MLSearchHandler;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.TransportService;

public class SearchModelTransportActionTests extends OpenSearchTestCase {
@Mock
Client client;

@Mock
NamedXContentRegistry namedXContentRegistry;

@Mock
TransportService transportService;

@Mock
ActionFilters actionFilters;

@Mock
SearchRequest searchRequest;

@Mock
ActionListener<SearchResponse> actionListener;

MLSearchHandler mlSearchHandler;
SearchModelTransportAction searchModelTransportAction;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);
mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry));
searchModelTransportAction = new SearchModelTransportAction(transportService, actionFilters, mlSearchHandler);
}

public void test_DoExecute() {
searchModelTransportAction.doExecute(null, searchRequest, actionListener);
verify(mlSearchHandler).search(searchRequest, actionListener);
}
}
Loading

0 comments on commit f200384

Please sign in to comment.