Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor predictable: add method to check if model is ready #1057

Merged
merged 2 commits into from
Jul 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ public void writeTo(StreamOutput out) throws IOException {
if (headers != null) {
out.writeBoolean(true);
out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString);
} else {
out.writeBoolean(false);
}
out.writeOptionalString(requestBody);
out.writeOptionalString(preProcessFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,22 @@

package org.opensearch.ml.common.connector;

import java.util.Arrays;
import java.util.Set;

public class ConnectorProtocols {

public static final String HTTP = "http";
public static final String AWS_SIGV4 = "aws_sigv4";

public static final Set<String> VALID_PROTOCOLS = Set.of(HTTP, AWS_SIGV4);

public static void validateProtocol(String protocol) {
if (protocol == null) {
throw new IllegalArgumentException("Connector protocol is null. Please use one of [" + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0]))+ "]");
}
if (!VALID_PROTOCOLS.contains(protocol)) {
throw new IllegalArgumentException("Unsupported connector protocol. Please use one of [" + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0]))+ "]");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP;
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
import static org.opensearch.ml.common.utils.StringUtils.isJson;

Expand All @@ -47,6 +48,7 @@ public class HttpConnector extends AbstractConnector {
public HttpConnector(String name, String description, String version, String protocol,
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
List<String> backendRoles, AccessMode accessMode) {
validateProtocol(protocol);
this.name = name;
this.description = description;
this.version = version;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,10 @@ public interface Predictable {
* Close resources like deployed model.
*/
void close();

/**
* Check if model ready to be used.
* @return
*/
boolean isModelReady();
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ public void close() {
}
}

@Override
public boolean isModelReady() {
if (predictors == null || modelHelper == null || modelId == null) {
return false;
}
return true;
}

public abstract Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig);

public abstract TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ public void close() {
this.libSVMAnomalyModel = null;
}

@Override
public boolean isModelReady() {
return libSVMAnomalyModel != null;
}

@Override
public MLOutput predict(MLInput mlInput) {
MLInputDataset inputDataset = mlInput.getInputDataset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ public void close() {
this.kMeansModel = null;
}

@Override
public boolean isModelReady() {
return kMeansModel != null;
}

@Override
public MLOutput predict(MLInput mlInput) {
DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ public void close() {
this.summary = null;
}

@Override
public boolean isModelReady() {
return summary != null;
}

@Override
public MLOutput predict(MLInput mlInput) {
Iterable<float[]> centroidsLst = Arrays.asList(summary.summaryPoints);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ public void close() {
forest = null;
}

@Override
public boolean isModelReady() {
return forest != null;
}

@Override
public MLOutput predict(MLInput mlInput) {
DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ public void close() {
this.forest = null;
}

@Override
public boolean isModelReady() {
return forest != null;
}

@Override
public MLOutput predict(MLInput mlInput) {
DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ public void close() {
this.regressionModel = null;
}

@Override
public boolean isModelReady() {
return regressionModel != null;
}

@Override
public MLOutput predict(MLInput mlInput) {
if (regressionModel == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ public void close() {
this.classificationModel = null;
}

@Override
public boolean isModelReady() {
return classificationModel != null;
}

@Override
public MLOutput predict(MLInput mlInput) {
DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@

package org.opensearch.ml.engine.algorithms.remote;

import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import com.google.common.annotations.VisibleForTesting;
import lombok.extern.log4j.Log4j2;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -19,11 +15,9 @@
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.algorithms.DLModel;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.script.ScriptService;
Expand All @@ -32,27 +26,51 @@

@Log4j2
@Function(FunctionName.REMOTE)
public class RemoteModel extends DLModel {
public class RemoteModel implements Predictable {

public static final String CLUSTER_SERVICE = "cluster_service";
public static final String SCRIPT_SERVICE = "script_service";
public static final String CLIENT = "client";
public static final String XCONTENT_REGISTRY = "xcontent_registry";

private RemoteConnectorExecutor connectorExecutor;

@VisibleForTesting
RemoteConnectorExecutor getConnectorExecutor() {
return this.connectorExecutor;
}

@Override
public MLOutput predict(MLInput mlInput, MLModel model) {
throw new IllegalArgumentException(
"Model not ready yet. Please run this first: POST /_plugins/_ml/models/" + model.getModelId() + "/_deploy"
);
}

@Override
public MLOutput predict(MLInput mlInput) {
if (!isModelReady()) {
throw new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models/<model_id>/_deploy");
}
try {
return predict(modelId, mlInput);
} catch (Throwable t) {
log.error("Failed to call remote model", t);
throw new MLException("Failed to call remote model. " + t.getMessage());
return connectorExecutor.executePredict(mlInput);
} catch (RuntimeException e) {
log.error("Failed to call remote model", e);
throw e;
} catch (Throwable e) {
log.error("Failed to call remote model", e);
throw new MLException(e);
}
}

@Override
public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
return connectorExecutor.executePredict(mlInput);
public void close() {
this.connectorExecutor = null;
}

@Override
public boolean isModelReady() {
return connectorExecutor != null;
}

@Override
Expand All @@ -65,21 +83,13 @@ public void initModel(MLModel model, Map<String, Object> params, Encryptor encry
this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE));
this.connectorExecutor.setClient((Client) params.get(CLIENT));
this.connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY));
} catch (Exception e) {
} catch (RuntimeException e) {
log.error("Failed to init remote model", e);
throw e;
} catch (Throwable e) {
log.error("Failed to init remote model", e);
throw new MLException(e);
}
}

@Override
public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) {
return null;
}

@Override
public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) {
return null;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ public void close() {
sampleParam = DEFAULT_SAMPLE_PARAM;
}

@Override
public boolean isModelReady() {
return true;
}

@Override
public MLOutput predict(MLInput mlInput) {
AtomicReference<Double> sum = new AtomicReference<>((double) 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void executePredict_RemoteInferenceInput_MissingCredential() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
AwsConnector.awsConnectorBuilder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
AwsConnector.awsConnectorBuilder().name("test connector").protocol("http").version("1").actions(Arrays.asList(predictAction)).build();
}

@Test
Expand All @@ -99,7 +99,7 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
.build();
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

Expand All @@ -124,7 +124,7 @@ public void executePredict_RemoteInferenceInput() throws IOException {
.build();
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService);
}

Expand Down Expand Up @@ -104,7 +104,7 @@ public void processOutput_NoPostprocessFunction() throws IOException {
.build();
Map<String, String> parameters = new HashMap<>();
parameters.put("key1", "value1");
Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
ModelTensors tensors = ConnectorUtils.processOutput("{\"response\": \"test response\"}", connector, scriptService, ImmutableMap.of());
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName());
Expand All @@ -126,7 +126,7 @@ public void processOutput_PostprocessFunction() throws IOException {
.build();
Map<String, String> parameters = new HashMap<>();
parameters.put("key1", "value1");
Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}";
ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of());
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Expand All @@ -153,7 +153,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(String request
.build();
Map<String, String> parameters = new HashMap<>();
parameters.put("key1", "value1");
Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService);
Assert.assertNotNull(remoteInferenceInputDataSet.getParameters());
Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void invokeRemoteModel_WrongHttpMethod() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
executor.invokeRemoteModel(null, null, null, null);
}
Expand All @@ -79,7 +79,7 @@ public void executePredict_RemoteInferenceInput() throws IOException {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
Expand All @@ -103,7 +103,7 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Expand All @@ -129,7 +129,7 @@ public void executePredict_TextDocsInput() throws IOException {
.postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
executor.setScriptService(scriptService);
when(httpClient.execute(any())).thenReturn(response);
Expand Down
Loading