Skip to content

Commit

Permalink
add connector tool
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jun 6, 2024
1 parent 12beac2 commit 126b3f2
Show file tree
Hide file tree
Showing 12 changed files with 966 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public enum FunctionName {
SPARSE_TOKENIZE,
TEXT_SIMILARITY,
QUESTION_ANSWERING,
AGENT;
AGENT,
CONNECTOR;

public static FunctionName from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.opensearch.action.ActionType;
import org.opensearch.ml.common.transport.MLTaskResponse;

public class MLExecuteConnectorAction extends ActionType<MLTaskResponse> {
public static final MLExecuteConnectorAction INSTANCE = new MLExecuteConnectorAction();
public static final String NAME = "cluster:admin/opensearch/ml/connectors/execute";

private MLExecuteConnectorAction() {
super(NAME, MLTaskResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.transport.MLTaskRequest;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(level = AccessLevel.PRIVATE)
@ToString
public class MLExecuteConnectorRequest extends MLTaskRequest {

String connectorId;
String connectorAction;
MLInput mlInput;

@Builder
public MLExecuteConnectorRequest(String connectorId, String connectorAction, MLInput mlInput, boolean dispatchTask) {
super(dispatchTask);
this.mlInput = mlInput;
this.connectorAction = connectorAction == null ? "predict" : connectorAction;
this.connectorId = connectorId;
}

public MLExecuteConnectorRequest(String connectorId, String connectorAction, MLInput mlInput) {
this(connectorId, connectorAction, mlInput, true);
}

public MLExecuteConnectorRequest(StreamInput in) throws IOException {
super(in);
this.connectorId = in.readString();
this.connectorAction = in.readString();
this.mlInput = new MLInput(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.connectorId);
out.writeString(this.connectorAction);
this.mlInput.writeTo(out);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (this.mlInput == null) {
exception = addValidationError("ML input can't be null", exception);
} else if (this.mlInput.getInputDataset() == null) {
exception = addValidationError("input data can't be null", exception);
}

return exception;
}


public static MLExecuteConnectorRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLExecuteConnectorRequest) {
return (MLExecuteConnectorRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLExecuteConnectorRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLPredictionTaskRequest", e);
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;

public class MLExecuteConnectorRequestTests {
private MLExecuteConnectorRequest mlExecuteConnectorRequest;
private MLInput mlInput;
private String connectorId;
private String action;

@Before
public void setUp(){
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("input", "hello")).build();
connectorId = "test_connector";
action = "execute";
mlInput = RemoteInferenceMLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.CONNECTOR).build();
mlExecuteConnectorRequest = MLExecuteConnectorRequest.builder().connectorId(connectorId).connectorAction(action).mlInput(mlInput).build();
}

@Test
public void writeToSuccess() throws IOException {
BytesStreamOutput output = new BytesStreamOutput();
mlExecuteConnectorRequest.writeTo(output);
MLExecuteConnectorRequest parsedRequest = new MLExecuteConnectorRequest(output.bytes().streamInput());
assertEquals(mlExecuteConnectorRequest.getConnectorId(), parsedRequest.getConnectorId());
assertEquals(mlExecuteConnectorRequest.getConnectorAction(), parsedRequest.getConnectorAction());
assertEquals(mlExecuteConnectorRequest.getMlInput().getAlgorithm(), parsedRequest.getMlInput().getAlgorithm());
assertEquals(mlExecuteConnectorRequest.getMlInput().getInputDataset().getInputDataType(), parsedRequest.getMlInput().getInputDataset().getInputDataType());
assertEquals("hello", ((RemoteInferenceInputDataSet)parsedRequest.getMlInput().getInputDataset()).getParameters().get("input"));
}

@Test
public void validateSuccess() {
assertNull(mlExecuteConnectorRequest.validate());
}

@Test
public void testConstructor() {
MLExecuteConnectorRequest executeConnectorRequest = new MLExecuteConnectorRequest(connectorId, action, mlInput);
assertTrue(executeConnectorRequest.isDispatchTask());
}

@Test
public void validateWithNullMLInputException() {
MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder()
.build();
ActionRequestValidationException exception = executeConnectorRequest.validate();
assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage());
}

@Test
public void validateWithNullMLInputDataSetException() {
MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder().mlInput(new MLInput())
.build();
ActionRequestValidationException exception = executeConnectorRequest.validate();
assertEquals("Validation Failed: 1: input data can't be null;", exception.getMessage());
}

@Test
public void fromActionRequestWithMLExecuteConnectorRequestSuccess() {
assertSame(MLExecuteConnectorRequest.fromActionRequest(mlExecuteConnectorRequest), mlExecuteConnectorRequest);
}

@Test
public void fromActionRequestWithNonMLExecuteConnectorRequestSuccess() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
mlExecuteConnectorRequest.writeTo(out);
}
};
MLExecuteConnectorRequest result = MLExecuteConnectorRequest.fromActionRequest(actionRequest);
assertNotSame(result, mlExecuteConnectorRequest);
assertEquals(mlExecuteConnectorRequest.getConnectorId(), result.getConnectorId());
assertEquals(mlExecuteConnectorRequest.getConnectorAction(), result.getConnectorAction());
assertEquals(mlExecuteConnectorRequest.getMlInput().getFunctionName(), result.getMlInput().getFunctionName());
}

@Test(expected = UncheckedIOException.class)
public void fromActionRequestIOException() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException();
}
};
MLExecuteConnectorRequest.fromActionRequest(actionRequest);
}
}
Loading

0 comments on commit 126b3f2

Please sign in to comment.