Skip to content

Commit

Permalink
add more UT for search transport handler (#272)
Browse files Browse the repository at this point in the history
* add more UT for search transport handler

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored Apr 4, 2022
1 parent 3910316 commit 1e23a79
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 11 deletions.
1 change: 0 additions & 1 deletion plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ jacocoTestReport {

List<String> jacocoExclusions = [
// TODO: add more unit test to meet the minimal test coverage.
'org.opensearch.ml.action.handler.*',
'org.opensearch.ml.constant.CommonValue',
'org.opensearch.ml.plugin.*',
'org.opensearch.ml.task.MLPredictTaskRunner',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,10 @@ public static <T> ActionListener<T> wrapRestActionListener(ActionListener<T> act
}

public static boolean isProperExceptionToReturn(Throwable e) {
if (e == null) {
return false;
}
return e instanceof OpenSearchStatusException || e instanceof IndexNotFoundException || e instanceof InvalidIndexNameException;
}

public static boolean isBadRequest(Throwable e) {
if (e == null) {
return false;
}
return e instanceof IllegalArgumentException || e instanceof MLResourceNotFoundException;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
package org.opensearch.ml.action.models;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
package org.opensearch.ml.action.models;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,31 @@

package org.opensearch.ml.action.models;

import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.action.handler.MLSearchHandler;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.rest.RestStatus;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

public class SearchModelTransportActionTests extends OpenSearchTestCase {
Expand All @@ -40,18 +51,95 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase {
@Mock
ActionListener<SearchResponse> actionListener;

@Mock
ThreadPool threadPool;

MLSearchHandler mlSearchHandler;
SearchModelTransportAction searchModelTransportAction;
ThreadContext threadContext;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);
mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry));
searchModelTransportAction = new SearchModelTransportAction(transportService, actionFilters, mlSearchHandler);

Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

public void test_DoExecute() {
searchModelTransportAction.doExecute(null, searchRequest, actionListener);
verify(mlSearchHandler).search(searchRequest, actionListener);
verify(client).search(any(), any());
}

public void test_IndexNotFoundException() {
setupSearchMocks(new IndexNotFoundException("index not found"));

searchModelTransportAction.doExecute(null, searchRequest, actionListener);
verify(mlSearchHandler).search(searchRequest, actionListener);
verify(client).search(any(), any());
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(IndexNotFoundException.class, argumentCaptor.getValue().getClass());
}

public void test_IllegalArgumentException() {
setupSearchMocks(new IllegalArgumentException("illegal arguments"));

searchModelTransportAction.doExecute(null, searchRequest, actionListener);
verify(mlSearchHandler).search(searchRequest, actionListener);
verify(client).search(any(), any());
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(OpenSearchStatusException.class, argumentCaptor.getValue().getClass());
}

public void test_OpenSearchStatusException() {
setupSearchMocks(new OpenSearchStatusException("test error", RestStatus.CONFLICT, "args"));

searchModelTransportAction.doExecute(null, searchRequest, actionListener);
verify(mlSearchHandler).search(searchRequest, actionListener);
verify(client).search(any(), any());
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(OpenSearchStatusException.class, argumentCaptor.getValue().getClass());
}

public void test_CauseByMLException() {
Exception exception = new Exception();
exception.initCause(new MLException("ml exception"));
setupSearchMocks(exception);

searchModelTransportAction.doExecute(null, searchRequest, actionListener);
verify(mlSearchHandler).search(searchRequest, actionListener);
verify(client).search(any(), any());
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(OpenSearchStatusException.class, argumentCaptor.getValue().getClass());
}

public void test_CauseByInvalidIndexNameException() {
Exception exception = new Exception();
exception.initCause(new IndexNotFoundException("Index not Found"));
setupSearchMocks(exception);

searchModelTransportAction.doExecute(null, searchRequest, actionListener);
verify(mlSearchHandler).search(searchRequest, actionListener);
verify(client).search(any(), any());
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(IndexNotFoundException.class, argumentCaptor.getValue().getClass());
}

private void setupSearchMocks(Exception exception) {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onFailure(exception);
return null;
}).when(client).search(any(), any());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
package org.opensearch.ml.action.tasks;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
package org.opensearch.ml.action.tasks;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.IOException;

Expand Down

0 comments on commit 1e23a79

Please sign in to comment.