Skip to content

Commit

Permalink
refactor predictable: add method to check if model is ready (#1057)
Browse files Browse the repository at this point in the history
* refactor predictable: add method to check if model is ready

Signed-off-by: Yaliang Wu <[email protected]>

* fix failed ut

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored and zane-neo committed Sep 1, 2023
1 parent a1d78b7 commit 2450e9a
Show file tree
Hide file tree
Showing 19 changed files with 265 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ public void writeTo(StreamOutput out) throws IOException {
if (headers != null) {
out.writeBoolean(true);
out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString);
} else {
out.writeBoolean(false);
}
out.writeOptionalString(requestBody);
out.writeOptionalString(preProcessFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,22 @@

package org.opensearch.ml.common.connector;

import java.util.Arrays;
import java.util.Set;

public class ConnectorProtocols {

public static final String HTTP = "http";
public static final String AWS_SIGV4 = "aws_sigv4";

public static final Set<String> VALID_PROTOCOLS = Set.of(HTTP, AWS_SIGV4);

public static void validateProtocol(String protocol) {
if (protocol == null) {
throw new IllegalArgumentException("Connector protocol is null. Please use one of [" + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0]))+ "]");
}
if (!VALID_PROTOCOLS.contains(protocol)) {
throw new IllegalArgumentException("Unsupported connector protocol. Please use one of [" + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0]))+ "]");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP;
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
import static org.opensearch.ml.common.utils.StringUtils.isJson;

Expand All @@ -47,6 +48,7 @@ public class HttpConnector extends AbstractConnector {
public HttpConnector(String name, String description, String version, String protocol,
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
List<String> backendRoles, AccessMode accessMode) {
validateProtocol(protocol);
this.name = name;
this.description = description;
this.version = version;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,10 @@ public interface Predictable {
* Close resources like deployed model.
*/
void close();

/**
* Check if model ready to be used.
* @return
*/
boolean isModelReady();
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ public void close() {
}
}

@Override
public boolean isModelReady() {
if (predictors == null || modelHelper == null || modelId == null) {
return false;
}
return true;
}

public abstract Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig);

public abstract TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ public void close() {
this.libSVMAnomalyModel = null;
}

@Override
public boolean isModelReady() {
return libSVMAnomalyModel != null;
}

@Override
public MLOutput predict(MLInput mlInput) {
MLInputDataset inputDataset = mlInput.getInputDataset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ public void close() {
this.kMeansModel = null;
}

@Override
public boolean isModelReady() {
return kMeansModel != null;
}

@Override
public MLOutput predict(MLInput mlInput) {
DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ public void close() {
this.summary = null;
}

@Override
public boolean isModelReady() {
return summary != null;
}

@Override
public MLOutput predict(MLInput mlInput) {
Iterable<float[]> centroidsLst = Arrays.asList(summary.summaryPoints);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ public void close() {
forest = null;
}

@Override
public boolean isModelReady() {
return forest != null;
}

@Override
public MLOutput predict(MLInput mlInput) {
DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ public void close() {
this.forest = null;
}

@Override
public boolean isModelReady() {
return forest != null;
}

@Override
public MLOutput predict(MLInput mlInput) {
DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ public void close() {
this.regressionModel = null;
}

@Override
public boolean isModelReady() {
return regressionModel != null;
}

@Override
public MLOutput predict(MLInput mlInput) {
if (regressionModel == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ public void close() {
this.classificationModel = null;
}

@Override
public boolean isModelReady() {
return classificationModel != null;
}

@Override
public MLOutput predict(MLInput mlInput) {
DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@

package org.opensearch.ml.engine.algorithms.remote;

import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import com.google.common.annotations.VisibleForTesting;
import lombok.extern.log4j.Log4j2;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -19,11 +15,9 @@
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.algorithms.DLModel;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.script.ScriptService;
Expand All @@ -32,27 +26,51 @@

@Log4j2
@Function(FunctionName.REMOTE)
public class RemoteModel extends DLModel {
public class RemoteModel implements Predictable {

public static final String CLUSTER_SERVICE = "cluster_service";
public static final String SCRIPT_SERVICE = "script_service";
public static final String CLIENT = "client";
public static final String XCONTENT_REGISTRY = "xcontent_registry";

private RemoteConnectorExecutor connectorExecutor;

@VisibleForTesting
RemoteConnectorExecutor getConnectorExecutor() {
return this.connectorExecutor;
}

@Override
public MLOutput predict(MLInput mlInput, MLModel model) {
throw new IllegalArgumentException(
"Model not ready yet. Please run this first: POST /_plugins/_ml/models/" + model.getModelId() + "/_deploy"
);
}

@Override
public MLOutput predict(MLInput mlInput) {
if (!isModelReady()) {
throw new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models/<model_id>/_deploy");
}
try {
return predict(modelId, mlInput);
} catch (Throwable t) {
log.error("Failed to call remote model", t);
throw new MLException("Failed to call remote model. " + t.getMessage());
return connectorExecutor.executePredict(mlInput);
} catch (RuntimeException e) {
log.error("Failed to call remote model", e);
throw e;
} catch (Throwable e) {
log.error("Failed to call remote model", e);
throw new MLException(e);
}
}

@Override
public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
return connectorExecutor.executePredict(mlInput);
public void close() {
this.connectorExecutor = null;
}

@Override
public boolean isModelReady() {
return connectorExecutor != null;
}

@Override
Expand All @@ -65,21 +83,13 @@ public void initModel(MLModel model, Map<String, Object> params, Encryptor encry
this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE));
this.connectorExecutor.setClient((Client) params.get(CLIENT));
this.connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY));
} catch (Exception e) {
} catch (RuntimeException e) {
log.error("Failed to init remote model", e);
throw e;
} catch (Throwable e) {
log.error("Failed to init remote model", e);
throw new MLException(e);
}
}

@Override
public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) {
return null;
}

@Override
public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) {
return null;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ public void close() {
sampleParam = DEFAULT_SAMPLE_PARAM;
}

@Override
public boolean isModelReady() {
return true;
}

@Override
public MLOutput predict(MLInput mlInput) {
AtomicReference<Double> sum = new AtomicReference<>((double) 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void executePredict_RemoteInferenceInput_MissingCredential() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
AwsConnector.awsConnectorBuilder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
AwsConnector.awsConnectorBuilder().name("test connector").protocol("http").version("1").actions(Arrays.asList(predictAction)).build();
}

@Test
Expand All @@ -99,7 +99,7 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
.build();
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

Expand All @@ -124,7 +124,7 @@ public void executePredict_RemoteInferenceInput() throws IOException {
.build();
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService);
}

Expand Down Expand Up @@ -104,7 +104,7 @@ public void processOutput_NoPostprocessFunction() throws IOException {
.build();
Map<String, String> parameters = new HashMap<>();
parameters.put("key1", "value1");
Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
ModelTensors tensors = ConnectorUtils.processOutput("{\"response\": \"test response\"}", connector, scriptService, ImmutableMap.of());
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName());
Expand All @@ -126,7 +126,7 @@ public void processOutput_PostprocessFunction() throws IOException {
.build();
Map<String, String> parameters = new HashMap<>();
parameters.put("key1", "value1");
Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}";
ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of());
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Expand All @@ -153,7 +153,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(String request
.build();
Map<String, String> parameters = new HashMap<>();
parameters.put("key1", "value1");
Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService);
Assert.assertNotNull(remoteInferenceInputDataSet.getParameters());
Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void invokeRemoteModel_WrongHttpMethod() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
executor.invokeRemoteModel(null, null, null, null);
}
Expand All @@ -79,7 +79,7 @@ public void executePredict_RemoteInferenceInput() throws IOException {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
Expand All @@ -103,7 +103,7 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Expand All @@ -129,7 +129,7 @@ public void executePredict_TextDocsInput() throws IOException {
.postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
executor.setScriptService(scriptService);
when(httpClient.execute(any())).thenReturn(response);
Expand Down
Loading

0 comments on commit 2450e9a

Please sign in to comment.