Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jun 6, 2024
1 parent a196399 commit 3edc9a7
Show file tree
Hide file tree
Showing 16 changed files with 688 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,34 +36,24 @@ public class MLExecuteConnectorRequest extends MLTaskRequest {
String connectorId;
String connectorAction;
MLInput mlInput;
@Setter
User user;

@Builder
public MLExecuteConnectorRequest(String connectorId, String connectorAction, MLInput mlInput, boolean dispatchTask, User user) {
public MLExecuteConnectorRequest(String connectorId, String connectorAction, MLInput mlInput, boolean dispatchTask) {
super(dispatchTask);
this.mlInput = mlInput;
this.connectorAction = connectorAction == null ? "predict" : connectorAction;
this.connectorId = connectorId;
this.user = user;
}

public MLExecuteConnectorRequest(String connectorId, String connectorAction, MLInput mlInput) {
this(connectorId, connectorAction, mlInput, true, null);
}

public MLExecuteConnectorRequest(String connectorId, String connectorAction, MLInput mlInput, User user) {
this(connectorId, connectorAction, mlInput, true, user);
this(connectorId, connectorAction, mlInput, true);
}

public MLExecuteConnectorRequest(StreamInput in) throws IOException {
super(in);
this.connectorId = in.readString();
this.connectorAction = in.readString();
this.mlInput = new MLInput(in);
if (in.readBoolean()) {
this.user = new User(in);
}
}

@Override
Expand All @@ -72,12 +62,6 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(this.connectorId);
out.writeString(this.connectorAction);
this.mlInput.writeTo(out);
if (user != null) {
out.writeBoolean(true);
user.writeTo(out);
} else {
out.writeBoolean(false);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;

Expand Down Expand Up @@ -110,7 +111,7 @@ public void constructor_NoPredictAction() {
Assert.assertNotNull(connector);

connector.encrypt(encryptFunction);
connector.decrypt(decryptFunction);
connector.decrypt(PREDICT.name(), decryptFunction);
Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey());
Assert.assertEquals(null, connector.getSessionToken());
Expand Down Expand Up @@ -149,13 +150,13 @@ public void constructor() {

AwsConnector connector = createAwsConnector(parameters, credential, url);
connector.encrypt(encryptFunction);
connector.decrypt(decryptFunction);
connector.decrypt(PREDICT.name(), decryptFunction);
Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken());
Assert.assertEquals("test_service", connector.getServiceName());
Assert.assertEquals("us-west-2", connector.getRegion());
Assert.assertEquals("https://test.com/model1", connector.getPredictEndpoint(parameters));
Assert.assertEquals("https://test.com/model1", connector.getActionEndpoint(PREDICT.name(), parameters));
}

@Test
Expand All @@ -170,13 +171,13 @@ public void constructor_NoParameter() {
String url = "https://test.com";
AwsConnector connector = createAwsConnector(null, credential, url);
connector.encrypt(encryptFunction);
connector.decrypt(decryptFunction);
connector.decrypt(PREDICT.name(), decryptFunction);
Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SERVICE", connector.getServiceName());
Assert.assertEquals("decrypted: ENCRYPTED: US-WEST-2", connector.getRegion());
Assert.assertEquals("https://test.com", connector.getPredictEndpoint(null));
Assert.assertEquals("https://test.com", connector.getActionEndpoint(PREDICT.name(), null));
}

@Test
Expand All @@ -201,7 +202,7 @@ private AwsConnector createAwsConnector() {
}

private AwsConnector createAwsConnector(Map<String, String> parameters, Map<String, String> credential, String url) {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
ConnectorAction.ActionType actionType = PREDICT;
String method = "POST";
Map<String, String> headers = new HashMap<>();
headers.put("api_key", "${credential.key}");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import java.util.Optional;
import java.util.function.Function;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;

public class HttpConnectorTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();
Expand Down Expand Up @@ -118,7 +120,7 @@ public void cloneConnector() {
@Test
public void decrypt() {
HttpConnector connector = createHttpConnector();
connector.decrypt(decryptFunction);
connector.decrypt(PREDICT.name(), decryptFunction);
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
Assert.assertEquals(1, decryptedCredential.size());
Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key"));
Expand Down Expand Up @@ -149,21 +151,21 @@ public void encrypted() {
@Test
public void getPredictEndpoint() {
HttpConnector connector = createHttpConnector();
Assert.assertEquals("https://test.com", connector.getPredictEndpoint(null));
Assert.assertEquals("https://test.com", connector.getActionEndpoint(PREDICT.name(), null));
}

@Test
public void getPredictHttpMethod() {
HttpConnector connector = createHttpConnector();
Assert.assertEquals("POST", connector.getPredictHttpMethod());
Assert.assertEquals("POST", connector.getActionHttpMethod(PREDICT.name()));
}

@Test
public void createPredictPayload_Invalid() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Some parameter placeholder not filled in payload: input");
HttpConnector connector = createHttpConnector();
String predictPayload = connector.createPredictPayload(null);
String predictPayload = connector.createPayload(PREDICT.name(), null);
connector.validatePayload(predictPayload);
}

Expand All @@ -173,7 +175,7 @@ public void createPredictPayload_InvalidJson() {
exceptionRule.expectMessage("Invalid payload: {\"input\": ${parameters.input} }");
String requestBody = "{\"input\": ${parameters.input} }";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
String predictPayload = connector.createPredictPayload(null);
String predictPayload = connector.createPayload(PREDICT.name(), null);
connector.validatePayload(predictPayload);
}

Expand All @@ -182,7 +184,7 @@ public void createPredictPayload() {
HttpConnector connector = createHttpConnector();
Map<String, String> parameters = new HashMap<>();
parameters.put("input", "test input value");
String predictPayload = connector.createPredictPayload(parameters);
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
Assert.assertEquals("{\"input\": \"test input value\"}", predictPayload);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.MLPostProcessFunction;
import org.opensearch.ml.common.connector.MLPreProcessFunction;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.rmi.Remote;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;

public class MLExecuteConnectorRequestTests {
private MLExecuteConnectorRequest mlExecuteConnectorRequest;
private MLInput mlInput;
private String connectorId;
private String action;

@Before
public void setUp(){
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("input", "hello")).build();
connectorId = "test_connector";
action = "execute";
mlInput = RemoteInferenceMLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.CONNECTOR).build();
mlExecuteConnectorRequest = MLExecuteConnectorRequest.builder().connectorId(connectorId).connectorAction(action).mlInput(mlInput).build();
}

@Test
public void writeToSuccess() throws IOException {
BytesStreamOutput output = new BytesStreamOutput();
mlExecuteConnectorRequest.writeTo(output);
MLExecuteConnectorRequest parsedRequest = new MLExecuteConnectorRequest(output.bytes().streamInput());
assertEquals(mlExecuteConnectorRequest.getConnectorId(), parsedRequest.getConnectorId());
assertEquals(mlExecuteConnectorRequest.getConnectorAction(), parsedRequest.getConnectorAction());
assertEquals(mlExecuteConnectorRequest.getMlInput().getAlgorithm(), parsedRequest.getMlInput().getAlgorithm());
assertEquals(mlExecuteConnectorRequest.getMlInput().getInputDataset().getInputDataType(), parsedRequest.getMlInput().getInputDataset().getInputDataType());
assertEquals("hello", ((RemoteInferenceInputDataSet)parsedRequest.getMlInput().getInputDataset()).getParameters().get("input"));
}

@Test
public void validateSuccess() {
assertNull(mlExecuteConnectorRequest.validate());
}

@Test
public void testConstructor() {
MLExecuteConnectorRequest executeConnectorRequest = new MLExecuteConnectorRequest(connectorId, action, mlInput);
assertTrue(executeConnectorRequest.isDispatchTask());
}

@Test
public void validateWithNullMLInputException() {
MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder()
.build();
ActionRequestValidationException exception = executeConnectorRequest.validate();
assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage());
}

@Test
public void validateWithNullMLInputDataSetException() {
MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder().mlInput(new MLInput())
.build();
ActionRequestValidationException exception = executeConnectorRequest.validate();
assertEquals("Validation Failed: 1: input data can't be null;", exception.getMessage());
}

@Test
public void fromActionRequestWithMLExecuteConnectorRequestSuccess() {
assertSame(MLExecuteConnectorRequest.fromActionRequest(mlExecuteConnectorRequest), mlExecuteConnectorRequest);
}

@Test
public void fromActionRequestWithNonMLExecuteConnectorRequestSuccess() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
mlExecuteConnectorRequest.writeTo(out);
}
};
MLExecuteConnectorRequest result = MLExecuteConnectorRequest.fromActionRequest(actionRequest);
assertNotSame(result, mlExecuteConnectorRequest);
assertEquals(mlExecuteConnectorRequest.getConnectorId(), result.getConnectorId());
assertEquals(mlExecuteConnectorRequest.getConnectorAction(), result.getConnectorAction());
assertEquals(mlExecuteConnectorRequest.getMlInput().getFunctionName(), result.getMlInput().getFunctionName());
}

@Test(expected = UncheckedIOException.class)
public void fromActionRequestIOException() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException();
}
};
MLExecuteConnectorRequest.fromActionRequest(actionRequest);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ public static RemoteInferenceInputDataSet processInput(
if (mlInput == null) {
throw new IllegalArgumentException("Input is null");
}
Optional<ConnectorAction> predictAction = connector.findAction(action);
if (predictAction.isEmpty()) {
throw new IllegalArgumentException("no predict action found");
Optional<ConnectorAction> connectorAction = connector.findAction(action);
if (connectorAction.isEmpty()) {
throw new IllegalArgumentException("no " + action + " action found");
}
RemoteInferenceInputDataSet inputData = processMLInput(action, mlInput, connector, parameters, scriptService);
escapeRemoteInferenceInputData(inputData);
Expand Down Expand Up @@ -197,11 +197,11 @@ public static ModelTensors processOutput(
throw new IllegalArgumentException("guardrails triggered for LLM output");
}
List<ModelTensor> modelTensors = new ArrayList<>();
Optional<ConnectorAction> predictAction = connector.findAction(action);
if (predictAction.isEmpty()) {
throw new IllegalArgumentException("no predict action found");
Optional<ConnectorAction> optionalAction = connector.findAction(action);
if (optionalAction.isEmpty()) {
throw new IllegalArgumentException("no " + action + " action found");
}
ConnectorAction connectorAction = predictAction.get();
ConnectorAction connectorAction = optionalAction.get();
String postProcessFunction = connectorAction.getPostProcessFunction();
postProcessFunction = fillProcessFunctionParameter(parameters, postProcessFunction);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ private Tuple<Integer, Integer> calculateChunkSize(String action, TextDocsInputD
return Tuple.tuple(textDocsLength / stepSize + 1, stepSize);
}
} else {
Optional<ConnectorAction> predictAction = getConnector().findAction(action);
if (predictAction.isEmpty()) {
throw new IllegalArgumentException("no predict action found");
Optional<ConnectorAction> connectorAction = getConnector().findAction(action);
if (connectorAction.isEmpty()) {
throw new IllegalArgumentException("no " + action + " action found");
}
String preProcessFunction = predictAction.get().getPreProcessFunction();
String preProcessFunction = connectorAction.get().getPreProcessFunction();
if (preProcessFunction != null && !MLPreProcessFunction.contains(preProcessFunction)) {
// user defined preprocess script, this case, the chunk size is always equals to text docs length.
return Tuple.tuple(textDocsLength, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.ml.engine.tools;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;

import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -36,13 +38,13 @@ public class ConnectorTool implements Tool {
public static final String TYPE = "ConnectorTool";
public static final String CONNECTOR_ID = "connector_id";
public static final String CONNECTOR_ACTION = "connector_action";
private static final String DEFAULT_DESCRIPTION = "This tool will invoke external service.";

@Setter
@Getter
private String name = IndexMappingTool.TYPE;
private String name = ConnectorTool.TYPE;
@Getter
@Setter
private String description = DEFAULT_DESCRIPTION;
private String description = Factory.DEFAULT_DESCRIPTION;
@Getter
private String version;
@Setter
Expand All @@ -56,8 +58,11 @@ public class ConnectorTool implements Tool {

public ConnectorTool(Client client, String connectorId, String connectorAction) {
this.client = client;
if (connectorId == null) {
throw new IllegalArgumentException("connector_id can't be null");
}
this.connectorId = connectorId;
this.connectorAction = connectorAction;
this.connectorAction = connectorAction == null ? PREDICT.name() : connectorAction;

outputParser = new Parser() {
@Override
Expand Down Expand Up @@ -103,7 +108,7 @@ public boolean validate(Map<String, String> parameters) {

public static class Factory implements Tool.Factory<ConnectorTool> {
public static final String TYPE = "ConnectorTool";
private static final String DEFAULT_DESCRIPTION = "This tool will invoke external service.";
public static final String DEFAULT_DESCRIPTION = "This tool will invoke external service.";
private Client client;
private static Factory INSTANCE;

Expand Down
Loading

0 comments on commit 3edc9a7

Please sign in to comment.