Skip to content

Commit

Permalink
fix ut
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Oct 11, 2023
1 parent 3e75c0a commit b7bb290
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public void validateConnectorAccess(Client client, String connectorId, ActionLis
getConnector(client, connectorId, ActionListener.wrap(connector -> {
boolean hasPermission = hasPermission(user, connector);
wrappedListener.onResponse(hasPermission);
}, e -> { wrappedListener.onFailure(new IllegalStateException("Fail to get connector:" + connectorId)); }));
}, e -> { wrappedListener.onFailure(e); }));
} catch (Exception e) {
log.error("Failed to validate Access for connector:" + connectorId, e);
listener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
Expand Down Expand Up @@ -180,7 +181,7 @@ public void testDeleteConnector_BlockedByModel() throws IOException {
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"1 models are still using this connector, please delete or update the models first!",
"1 models are still using this connector, please delete or update the models first: [model_ID]",
argumentCaptor.getValue().getMessage()
);
}
Expand Down Expand Up @@ -291,8 +292,17 @@ private SearchResponse getEmptySearchResponse() {
return searchResponse;
}

private SearchResponse getNonEmptySearchResponse() {
private SearchResponse getNonEmptySearchResponse() throws IOException {
SearchHit[] hits = new SearchHit[1];
String modelContent = "{\n"
+ " \"created_time\": 1684981986069,\n"
+ " \"last_updated_time\": 1684981986069,\n"
+ " \"_id\": \"model_ID\",\n"
+ " \"name\": \"test_model\",\n"
+ " \"description\": \"This is an example description\"\n"
+ " }";
SearchHit model = SearchHit.fromXContent(TestHelper.parser(modelContent));
hits[0] = model;
SearchHits searchHits = new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f);
SearchResponseSections searchSections = new SearchResponseSections(
searchHits,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
Expand All @@ -41,9 +41,14 @@
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.utils.TestHelper;
Expand All @@ -56,6 +61,7 @@
import org.opensearch.transport.TransportService;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

public class TransportUpdateConnectorActionTests extends OpenSearchTestCase {

Expand Down Expand Up @@ -148,6 +154,9 @@ public void setup() throws IOException {
SearchResponse.Clusters.EMPTY
);

Encryptor encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor);

transportUpdateConnectorAction = new UpdateConnectorTransportAction(
transportService,
actionFilters,
Expand All @@ -162,15 +171,40 @@ public void setup() throws IOException {
when(mlModelManager.getAllModelIds()).thenReturn(new String[] {});
shardId = new ShardId(new Index("indexName", "uuid"), 1);
updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED);

doAnswer(invocation -> {
ActionListener<Connector> listener = invocation.getArgument(2);
Connector connector = HttpConnector
.builder()
.name("test")
.protocol("http")
.version("1")
.credential(ImmutableMap.of("api_key", "credential_value"))
.parameters(ImmutableMap.of("param1", "value1"))
.actions(
Arrays
.asList(
ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("https://api.openai.com/v1/chat/completions")
.headers(ImmutableMap.of("Authorization", "Bearer ${credential.api_key}"))
.requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }")
.build()
)
)
.build();
// Connector connector = mock(HttpConnector.class);
// doNothing().when(connector).update(any(), any());
listener.onResponse(connector);
return null;
}).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class));
}

@Test
public void test_execute_connectorAccessControl_success() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
listener.onResponse(true);
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));
doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
Expand All @@ -190,11 +224,7 @@ public void test_execute_connectorAccessControl_success() {

@Test
public void test_execute_connectorAccessControl_NoPermission() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
listener.onResponse(false);
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));
doReturn(false).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));

transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
Expand All @@ -207,11 +237,9 @@ public void test_execute_connectorAccessControl_NoPermission() {

@Test
public void test_execute_connectorAccessControl_AccessError() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
listener.onFailure(new RuntimeException("Connector Access Control Error"));
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));
doThrow(new RuntimeException("Connector Access Control Error"))
.when(connectorAccessControlHelper)
.validateConnectorAccess(any(Client.class), any(Connector.class));

transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
Expand All @@ -223,7 +251,7 @@ public void test_execute_connectorAccessControl_AccessError() {
public void test_execute_connectorAccessControl_Exception() {
doThrow(new RuntimeException("exception in access control"))
.when(connectorAccessControlHelper)
.validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));
.validateConnectorAccess(any(Client.class), any(Connector.class));

transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
Expand All @@ -233,11 +261,7 @@ public void test_execute_connectorAccessControl_Exception() {

@Test
public void test_execute_UpdateWrongStatus() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
listener.onResponse(true);
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));
doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
Expand All @@ -258,11 +282,7 @@ public void test_execute_UpdateWrongStatus() {

@Test
public void test_execute_UpdateException() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
listener.onResponse(true);
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));
doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
Expand All @@ -284,11 +304,7 @@ public void test_execute_UpdateException() {

@Test
public void test_execute_SearchResponseNotEmpty() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
listener.onResponse(true);
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));
doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
Expand All @@ -299,16 +315,14 @@ public void test_execute_SearchResponseNotEmpty() {
transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("1 models are still using this connector, please undeploy the models first!", argumentCaptor.getValue().getMessage());
assertTrue(
argumentCaptor.getValue().getMessage().contains("1 models are still using this connector, please undeploy the models first")
);
}

@Test
public void test_execute_SearchResponseError() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
listener.onResponse(true);
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));
doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
Expand All @@ -324,11 +338,36 @@ public void test_execute_SearchResponseError() {

@Test
public void test_execute_SearchIndexNotFoundError() {
doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
listener.onResponse(true);
ActionListener<Connector> listener = invocation.getArgument(2);
Connector connector = HttpConnector
.builder()
.name("test")
.protocol("http")
.version("1")
.credential(ImmutableMap.of("api_key", "credential_value"))
.parameters(ImmutableMap.of("param1", "value1"))
.actions(
Arrays
.asList(
ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("https://api.openai.com/v1/chat/completions")
.headers(ImmutableMap.of("Authorization", "Bearer ${credential.api_key}"))
.requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }")
.build()
)
)
.build();
// Connector connector = mock(HttpConnector.class);
// doNothing().when(connector).update(any(), any());
listener.onResponse(connector);
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class));
}).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
Expand Down

0 comments on commit b7bb290

Please sign in to comment.