From 2450e9a6e8989802615e1ada9294272d6d61f4d3 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Sun, 9 Jul 2023 14:18:01 -0700 Subject: [PATCH] refactor predictable: add method to check if model is ready (#1057) * refactor predictable: add method to check if model is ready Signed-off-by: Yaliang Wu * fix failed ut Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu --- .../ml/common/connector/ConnectorAction.java | 2 + .../common/connector/ConnectorProtocols.java | 14 ++ .../ml/common/connector/HttpConnector.java | 2 + .../org/opensearch/ml/engine/Predictable.java | 6 + .../ml/engine/algorithms/DLModel.java | 8 ++ .../algorithms/ad/AnomalyDetectionLibSVM.java | 5 + .../engine/algorithms/clustering/KMeans.java | 5 + .../algorithms/clustering/RCFSummarize.java | 5 + .../algorithms/rcf/BatchRandomCutForest.java | 5 + .../rcf/FixedInTimeRandomCutForest.java | 5 + .../regression/LinearRegression.java | 5 + .../regression/LogisticRegression.java | 5 + .../engine/algorithms/remote/RemoteModel.java | 64 +++++---- .../engine/algorithms/sample/SampleAlgo.java | 5 + .../remote/AwsConnectorExecutorTest.java | 6 +- .../algorithms/remote/ConnectorUtilsTest.java | 8 +- .../remote/HttpJsonConnectorExecutorTest.java | 8 +- .../algorithms/remote/RemoteModelTest.java | 132 ++++++++++++++++++ .../ml/task/MLPredictTaskRunner.java | 17 ++- 19 files changed, 265 insertions(+), 42 deletions(-) create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 926e9d02fc..ae43c10867 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -89,6 +89,8 @@ public void writeTo(StreamOutput out) throws IOException { if (headers != null) { out.writeBoolean(true); out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); } out.writeOptionalString(requestBody); out.writeOptionalString(preProcessFunction); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java index 0cb7785737..3d92d15bb7 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java @@ -5,8 +5,22 @@ package org.opensearch.ml.common.connector; +import java.util.Arrays; +import java.util.Set; + public class ConnectorProtocols { public static final String HTTP = "http"; public static final String AWS_SIGV4 = "aws_sigv4"; + + public static final Set VALID_PROTOCOLS = Set.of(HTTP, AWS_SIGV4); + + public static void validateProtocol(String protocol) { + if (protocol == null) { + throw new IllegalArgumentException("Connector protocol is null. Please use one of [" + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0]))+ "]"); + } + if (!VALID_PROTOCOLS.contains(protocol)) { + throw new IllegalArgumentException("Unsupported connector protocol. Please use one of [" + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0]))+ "]"); + } + } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index ef63895daf..fec51636b1 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -28,6 +28,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; +import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import static org.opensearch.ml.common.utils.StringUtils.isJson; @@ -47,6 +48,7 @@ public class HttpConnector extends AbstractConnector { public HttpConnector(String name, String description, String version, String protocol, Map parameters, Map credential, List actions, List backendRoles, AccessMode accessMode) { + validateProtocol(protocol); this.name = name; this.description = description; this.version = version; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java index 4f1823225f..76bf159d18 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java @@ -45,4 +45,10 @@ public interface Predictable { * Close resources like deployed model. */ void close(); + + /** + * Check if model ready to be used. + * @return + */ + boolean isModelReady(); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java index f0c763c12a..dbad1d982b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java @@ -157,6 +157,14 @@ public void close() { } } + @Override + public boolean isModelReady() { + if (predictors == null || modelHelper == null || modelId == null) { + return false; + } + return true; + } + public abstract Translator getTranslator(String engine, MLModelConfig modelConfig); public abstract TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetectionLibSVM.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetectionLibSVM.java index 9b77abe05f..ed83e449dd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetectionLibSVM.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetectionLibSVM.java @@ -84,6 +84,11 @@ public void close() { this.libSVMAnomalyModel = null; } + @Override + public boolean isModelReady() { + return libSVMAnomalyModel != null; + } + @Override public MLOutput predict(MLInput mlInput) { MLInputDataset inputDataset = mlInput.getInputDataset(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java index 4210b41fec..acbbf49076 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java @@ -98,6 +98,11 @@ public void close() { this.kMeansModel = null; } + @Override + public boolean isModelReady() { + return kMeansModel != null; + } + @Override public MLOutput predict(MLInput mlInput) { DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/RCFSummarize.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/RCFSummarize.java index 5df9f5536e..7b9304daae 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/RCFSummarize.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/RCFSummarize.java @@ -145,6 +145,11 @@ public void close() { this.summary = null; } + @Override + public boolean isModelReady() { + return summary != null; + } + @Override public MLOutput predict(MLInput mlInput) { Iterable centroidsLst = Arrays.asList(summary.summaryPoints); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java index 512e380c5b..84fe1a7779 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java @@ -82,6 +82,11 @@ public void close() { forest = null; } + @Override + public boolean isModelReady() { + return forest != null; + } + @Override public MLOutput predict(MLInput mlInput) { DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java index d26bf98c7b..889a486b8e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java @@ -110,6 +110,11 @@ public void close() { this.forest = null; } + @Override + public boolean isModelReady() { + return forest != null; + } + @Override public MLOutput predict(MLInput mlInput) { DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java index 2ed97ad2e7..590e5634ed 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java @@ -209,6 +209,11 @@ public void close() { this.regressionModel = null; } + @Override + public boolean isModelReady() { + return regressionModel != null; + } + @Override public MLOutput predict(MLInput mlInput) { if (regressionModel == null) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegression.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegression.java index 2e3767a0a9..087ec00240 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegression.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegression.java @@ -202,6 +202,11 @@ public void close() { this.classificationModel = null; } + @Override + public boolean isModelReady() { + return classificationModel != null; + } + @Override public MLOutput predict(MLInput mlInput) { DataFrame dataFrame = ((DataFrameInputDataset)mlInput.getInputDataset()).getDataFrame(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index 07a00dbff6..4449ee6996 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -5,11 +5,7 @@ package org.opensearch.ml.engine.algorithms.remote; -import ai.djl.modality.Input; -import ai.djl.modality.Output; -import ai.djl.translate.TranslateException; -import ai.djl.translate.Translator; -import ai.djl.translate.TranslatorFactory; +import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -19,11 +15,9 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.output.MLOutput; -import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.engine.MLEngineClassLoader; -import org.opensearch.ml.engine.algorithms.DLModel; +import org.opensearch.ml.engine.Predictable; import org.opensearch.ml.engine.annotation.Function; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.script.ScriptService; @@ -32,27 +26,51 @@ @Log4j2 @Function(FunctionName.REMOTE) -public class RemoteModel extends DLModel { +public class RemoteModel implements Predictable { public static final String CLUSTER_SERVICE = "cluster_service"; public static final String SCRIPT_SERVICE = "script_service"; public static final String CLIENT = "client"; public static final String XCONTENT_REGISTRY = "xcontent_registry"; + private RemoteConnectorExecutor connectorExecutor; + @VisibleForTesting + RemoteConnectorExecutor getConnectorExecutor() { + return this.connectorExecutor; + } + + @Override + public MLOutput predict(MLInput mlInput, MLModel model) { + throw new IllegalArgumentException( + "Model not ready yet. Please run this first: POST /_plugins/_ml/models/" + model.getModelId() + "/_deploy" + ); + } + @Override public MLOutput predict(MLInput mlInput) { + if (!isModelReady()) { + throw new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models//_deploy"); + } try { - return predict(modelId, mlInput); - } catch (Throwable t) { - log.error("Failed to call remote model", t); - throw new MLException("Failed to call remote model. " + t.getMessage()); + return connectorExecutor.executePredict(mlInput); + } catch (RuntimeException e) { + log.error("Failed to call remote model", e); + throw e; + } catch (Throwable e) { + log.error("Failed to call remote model", e); + throw new MLException(e); } } @Override - public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException { - return connectorExecutor.executePredict(mlInput); + public void close() { + this.connectorExecutor = null; + } + + @Override + public boolean isModelReady() { + return connectorExecutor != null; } @Override @@ -65,21 +83,13 @@ public void initModel(MLModel model, Map params, Encryptor encry this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE)); this.connectorExecutor.setClient((Client) params.get(CLIENT)); this.connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY)); - } catch (Exception e) { + } catch (RuntimeException e) { + log.error("Failed to init remote model", e); + throw e; + } catch (Throwable e) { log.error("Failed to init remote model", e); throw new MLException(e); } } - @Override - public Translator getTranslator(String engine, MLModelConfig modelConfig) { - return null; - } - - @Override - public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) { - return null; - } - - } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/SampleAlgo.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/SampleAlgo.java index 0b716ebb96..f3b1f3fac6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/SampleAlgo.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/SampleAlgo.java @@ -48,6 +48,11 @@ public void close() { sampleParam = DEFAULT_SAMPLE_PARAM; } + @Override + public boolean isModelReady() { + return true; + } + @Override public MLOutput predict(MLInput mlInput) { AtomicReference sum = new AtomicReference<>((double) 0); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index a811548629..8d6130566a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -80,7 +80,7 @@ public void executePredict_RemoteInferenceInput_MissingCredential() { .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - AwsConnector.awsConnectorBuilder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build(); + AwsConnector.awsConnectorBuilder().name("test connector").protocol("http").version("1").actions(Arrays.asList(predictAction)).build(); } @Test @@ -99,7 +99,7 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio .build(); Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); - Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); + Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); @@ -124,7 +124,7 @@ public void executePredict_RemoteInferenceInput() throws IOException { .build(); Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); - Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); + Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 10bff8e82b..9c3057b3a5 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -67,7 +67,7 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() { .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); } @@ -104,7 +104,7 @@ public void processOutput_NoPostprocessFunction() throws IOException { .build(); Map parameters = new HashMap<>(); parameters.put("key1", "value1"); - Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); ModelTensors tensors = ConnectorUtils.processOutput("{\"response\": \"test response\"}", connector, scriptService, ImmutableMap.of()); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); @@ -126,7 +126,7 @@ public void processOutput_PostprocessFunction() throws IOException { .build(); Map parameters = new HashMap<>(); parameters.put("key1", "value1"); - Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of()); Assert.assertEquals(1, tensors.getMlModelTensors().size()); @@ -153,7 +153,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(String request .build(); Map parameters = new HashMap<>(); parameters.put("key1", "value1"); - Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); Assert.assertNotNull(remoteInferenceInputDataSet.getParameters()); Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 4d1bbec749..8d04603d2a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -66,7 +66,7 @@ public void invokeRemoteModel_WrongHttpMethod() { .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); executor.invokeRemoteModel(null, null, null, null); } @@ -79,7 +79,7 @@ public void executePredict_RemoteInferenceInput() throws IOException { .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); @@ -103,7 +103,7 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() { .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); @@ -129,7 +129,7 @@ public void executePredict_TextDocsInput() throws IOException { .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java new file mode 100644 index 0000000000..bef3e1da71 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.remote; + +import com.google.common.collect.ImmutableMap; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.ConnectorProtocols; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; + +import java.util.Arrays; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +public class RemoteModelTest { + + @Mock + MLInput mlInput; + + @Mock + MLModel mlModel; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + RemoteModel remoteModel; + Encryptor encryptor; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + remoteModel = new RemoteModel(); + encryptor = spy(new EncryptorImpl("0000000000000001")); + } + + @Test + public void predict_ModelNotDeployed() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Model not ready yet"); + remoteModel.predict(mlInput, mlModel); + } + + @Test + public void predict_NullConnectorExecutor() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("Model not ready yet"); + remoteModel.predict(mlInput); + } + + @Test + public void predict_ModelDeployed_WrongInput() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("Wrong input type"); + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + remoteModel.predict(mlInput); + } + + @Test + public void initModel_RuntimeException() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Tag mismatch!"); + Connector connector = createConnector(null); + when(mlModel.getConnector()).thenReturn(connector); + doThrow(new IllegalArgumentException("Tag mismatch!")).when(encryptor).decrypt(any()); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + } + + @Test + public void initModel_NullHeader() { + Connector connector = createConnector(null); + when(mlModel.getConnector()).thenReturn(connector); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + Map decryptedHeaders = connector.getDecryptedHeaders(); + Assert.assertNull(decryptedHeaders); + } + + @Test + public void initModel_WithHeader() { + Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); + when(mlModel.getConnector()).thenReturn(connector); + remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + Map decryptedHeaders = connector.getDecryptedHeaders(); + RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); + Assert.assertNotNull(executor); + Assert.assertNull(decryptedHeaders); + Assert.assertNotNull(executor.getConnector().getDecryptedHeaders()); + Assert.assertEquals(1, executor.getConnector().getDecryptedHeaders().size()); + Assert.assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization")); + + remoteModel.close(); + Assert.assertNull(remoteModel.getConnectorExecutor()); + } + + private Connector createConnector(Map headers) { + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("wrong_method") + .url("http://test.com/mock") + .headers(headers) + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Connector connector = HttpConnector.builder() + .name("test connector") + .protocol(ConnectorProtocols.HTTP) + .version("1") + .credential(ImmutableMap.of("key", encryptor.encrypt("test_api_key"))) + .actions(Arrays.asList(predictAction)) + .build(); + return connector; + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 043428a483..20c96d8c6c 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -120,6 +120,7 @@ public void dispatchTask(MLPredictionTaskRequest request, TransportService trans ActionListener actionListener = ActionListener.wrap(node -> { if (clusterService.localNode().getId().equals(node.getId())) { log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId()); + request.setDispatchTask(false); executeTask(request, listener); } else { log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId()); @@ -129,8 +130,13 @@ public void dispatchTask(MLPredictionTaskRequest request, TransportService trans }, e -> { listener.onFailure(e); }); String[] workerNodes = mlModelManager.getWorkerNodes(modelId, true); if (workerNodes == null || workerNodes.length == 0) { - if (algorithm == FunctionName.TEXT_EMBEDDING) { - listener.onFailure(new IllegalArgumentException("model not deployed")); + if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) { + listener + .onFailure( + new IllegalArgumentException( + "Model not ready yet. Please run this first: POST /_plugins/_ml/models/" + modelId + "/_deploy" + ) + ); return; } else { workerNodes = nodeHelper.getEligibleNodeIds(); @@ -204,6 +210,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe try { Predictable predictor = mlModelManager.getPredictor(modelId); if (predictor != null) { + if (!predictor.isModelReady()) { + throw new IllegalArgumentException("Model not ready: " + modelId); + } MLOutput output = mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput)); if (output instanceof MLPredictionOutput) { ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name()); @@ -214,8 +223,8 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe MLTaskResponse response = MLTaskResponse.builder().output(output).build(); internalListener.onResponse(response); return; - } else if (algorithm == FunctionName.TEXT_EMBEDDING) { - throw new IllegalArgumentException("model not deployed"); + } else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) { + throw new IllegalArgumentException("Model not ready to be used: " + modelId); } } catch (Exception e) { handlePredictFailure(mlTask, internalListener, e, false);