From 9a7075d0b37198e4e1999204ba7c97f1f31e3ea1 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Fri, 17 Nov 2023 14:49:42 +0530 Subject: [PATCH 1/4] Backport multiple PRs to main from 2.x (#1652) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix parameter name in preprocess function; fix remote model function … (#1362) * fix parameter name in preprocess function; fix remote model function name Signed-off-by: Yaliang Wu * fix failed unit test Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu * throw exception when model group not found during update request (#1447) Signed-off-by: Bhavana Ramaram * add status code to model tensor (#1443) (#1453) Signed-off-by: Yaliang Wu * register new versions to a model group based on the name provided (#1452) Signed-off-by: Bhavana Ramaram * fixing metrics correlation algorithm (#1448) * fixing metrics correlation algorithm Signed-off-by: Dhrubo Saha * if model version fails to register, update model group accordingly (#1463) * if model version fails to register, update model group accordingly Signed-off-by: Bhavana Ramaram * Update Model API (#1350) * Update Model API POC Signed-off-by: Sicheng Song * Using GetRequest to get model Signed-off-by: Sicheng Song * Finalize model update API Signed-off-by: Sicheng Song * Fix compile Signed-off-by: Sicheng Song * Fix compileTest Signed-off-by: Sicheng Song * Add Unit Test Cases for Update Model API Signed-off-by: Sicheng Song * Tune back test coverage thereshold Signed-off-by: Sicheng Song * Add more unit tests on Update model API Signed-off-by: Sicheng Song * Add unit test for TransportUpdateModelAction class Signed-off-by: Sicheng Song * Fix a test error Signed-off-by: Sicheng Song * Change exception thrown to failure response Signed-off-by: Sicheng Song * Move the function judgement to the outter block Signed-off-by: Sicheng Song * Check if model is undeployed before update model Signed-off-by: Sicheng Song * Add more unit test for update model API Signed-off-by: Sicheng Song * Fix unit test due to blocking java 11 CI workflow Signed-off-by: Sicheng Song * Enabling auto bumping model version during registering to a new model group and address reviewers' other concern Signed-off-by: Sicheng Song * Autobump new model groups' latest version when register to a new model Signed-off-by: Sicheng Song * Change the REST API method from POST to PUT Signed-off-by: Sicheng Song * Change the update REST API endpoint Signed-off-by: Sicheng Song --------- Signed-off-by: Sicheng Song * Add a setting to control the update connector API (#1465) * Add a setting to control the update connector API Signed-off-by: Sicheng Song * Enabling the update connnector setting in unit test Signed-off-by: Sicheng Song * Enabling the update connnector setting in corresponding unit test Signed-off-by: Sicheng Song --------- Signed-off-by: Sicheng Song * fix update connector API (#1484) * fix update connector API Signed-off-by: Yaliang Wu * Performance enhacement for predict action by caching model info (#1472) (#1508) * Performance enhacement for predict action by caching model info Signed-off-by: zane-neo * Add context.restore() to avoid missing info Signed-off-by: zane-neo --------- Signed-off-by: zane-neo (cherry picked from commit a985f6ec6dc280072b7045dcf4851959aa575c54) Co-authored-by: zane-neo * fix failed ut from PR 1472 (#1479) (#1510) * fix failed ut from PR 1472 Signed-off-by: Yaliang Wu * exclude class for low coverage Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu (cherry picked from commit da5d82942385c34544016cf361b517c2bb3d36c4) Co-authored-by: Yaliang Wu * [Backport to 2.11] throw exception if remote model doesn't return 2xx status code; fix p… (#1477) (#1509) * throw exception if remote model doesn't return 2xx status code; fix p… (#1473) * throw exception if remote model doesn't return 2xx status code; fix predict runner Signed-off-by: Yaliang Wu * fix kmeans model deploy bug Signed-off-by: Yaliang Wu * support multiple docs for remote embedding model Signed-off-by: Yaliang Wu * fix ut Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu * fix wrong class Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu (cherry picked from commit 201c8a89b07126c5da9ed2e743c7f1b0e4806e12) Co-authored-by: Yaliang Wu * fix no worker node exception for remote embedding model (#1482) (#1511) * fix no worker node exception for remote embedding model Signed-off-by: Yaliang Wu * only add model info to cache if model cache exist Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu (cherry picked from commit 6f83b9fee002026d7a8d0fa3550fe8cf80b30371) Co-authored-by: Yaliang Wu * fix for delete model group API throwing incorrect error when model index not created (#1485) (#1486) (#1512) * fix for delete model group API throwing incorrect error when model index not created Signed-off-by: Bhavana Ramaram (cherry picked from commit 60ef0fd6dfeda18729f1dc2ec6ea9c0418c6ff69) Co-authored-by: Bhavana Ramaram (cherry picked from commit 55446819b7686e14cf9e1d10edf7956ed57148c7) Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> * fix no worker node error on multi-node cluster (#1487) (#1513) Signed-off-by: Yaliang Wu (cherry picked from commit cea1cd675cd95e37c29a3bd88c1cd9d58e81b20a) Co-authored-by: Yaliang Wu * add prefix to show the error is from remote service (#1499) (#1515) Signed-off-by: Yaliang Wu (cherry picked from commit 3897ad179437e683033d6918ebb6b4edf439dd4d) Co-authored-by: Yaliang Wu * fix multiple docs support (#1516) Signed-off-by: Yaliang Wu * adding another fix issue to the release note (#1498) (#1514) Signed-off-by: Dhrubo Saha (cherry picked from commit 440155c5c242113cd264b59818d1927b498c1480) Co-authored-by: Dhrubo Saha * add bedrockURL to trusted connector regex list (#1461) Signed-off-by: Bhavana Ramaram * return parsing exception 400 for parsing errors Signed-off-by: Xun Zhang * add more ut in restupdateconnector Signed-off-by: Xun Zhang * fix format violations Signed-off-by: Bhavana Ramaram * Fix model/connector update API to address security concern (#1595) * Fix model/connector update API to address appsec concern Signed-off-by: Sicheng Song * Fix compile and build failure Signed-off-by: Sicheng Song * Improve unit test coverage Signed-off-by: Sicheng Song * Fix spotless Signed-off-by: Sicheng Song * Merge update connector feature flag to remote inference feature flag Signed-off-by: Sicheng Song * Fix compile Signed-off-by: Sicheng Song * Fix exception status Signed-off-by: Sicheng Song * Keep fixing exception status Signed-off-by: Sicheng Song * Spotless fix Signed-off-by: Sicheng Song * Add UT on parsing exception Signed-off-by: Sicheng Song --------- Signed-off-by: Sicheng Song * change XContentFactory to MediaTypeRegistry builder in MLRegisterModelInputTest class Signed-off-by: Bhavana Ramaram --------- Signed-off-by: Yaliang Wu Signed-off-by: Bhavana Ramaram Signed-off-by: Dhrubo Saha Signed-off-by: Sicheng Song Signed-off-by: Xun Zhang Co-authored-by: Yaliang Wu Co-authored-by: Dhrubo Saha Co-authored-by: Sicheng Song Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Co-authored-by: zane-neo Co-authored-by: Xun Zhang --- .../org/opensearch/ml/common/CommonValue.java | 1 + .../org/opensearch/ml/common/MLModel.java | 6 +- .../ml/common/connector/Connector.java | 2 + .../ml/common/connector/HttpConnector.java | 33 + .../model/MetricsCorrelationModelConfig.java | 5 + .../ml/common/output/model/ModelTensors.java | 9 + .../connector/MLCreateConnectorInput.java | 27 +- .../connector/MLUpdateConnectorRequest.java | 18 +- .../transport/model/MLUpdateModelAction.java | 18 + .../transport/model/MLUpdateModelInput.java | 155 ++++ .../transport/model/MLUpdateModelRequest.java | 75 ++ .../register/MLRegisterModelInput.java | 31 +- .../MLRegisterModelMetaInput.java | 18 +- .../MetricsCorrelationModelConfigTests.java | 84 ++ .../MLUpdateConnectorRequestTests.java | 31 +- .../model/MLUpdateModelInputTest.java | 163 ++++ .../model/MLUpdateModelRequestTest.java | 121 +++ .../register/MLRegisterModelInputTest.java | 55 +- .../MLRegisterModelMetaInputTest.java | 2 +- .../MLRegisterModelMetaRequestTest.java | 2 +- .../ml/engine/algorithms/DLModelExecute.java | 2 +- .../MetricsCorrelation.java | 26 +- .../remote/AwsConnectorExecutor.java | 6 + .../algorithms/remote/ConnectorUtils.java | 6 +- .../remote/HttpJsonConnectorExecutor.java | 10 + .../remote/RemoteConnectorExecutor.java | 28 +- .../MetricsCorrelationTest.java | 121 ++- .../remote/AwsConnectorExecutorTest.java | 53 +- .../remote/HttpJsonConnectorExecutorTest.java | 42 +- plugin/build.gradle | 1 + .../DeleteConnectorTransportAction.java | 11 +- .../GetConnectorTransportAction.java | 20 +- .../UpdateConnectorTransportAction.java | 49 +- .../DeleteModelGroupTransportAction.java | 33 +- .../TransportUpdateModelGroupAction.java | 61 +- .../models/UpdateModelTransportAction.java | 399 +++++++++ .../TransportPredictionTaskAction.java | 101 ++- .../TransportRegisterModelAction.java | 72 +- .../TransportRegisterModelMetaAction.java | 74 +- .../helper/ConnectorAccessControlHelper.java | 60 +- .../org/opensearch/ml/model/MLModelCache.java | 20 +- .../ml/model/MLModelCacheHelper.java | 16 + .../ml/model/MLModelGroupManager.java | 68 +- .../opensearch/ml/model/MLModelManager.java | 63 +- .../ml/plugin/MachineLearningPlugin.java | 6 + .../ml/rest/RestMLUpdateConnectorAction.java | 21 +- .../ml/rest/RestMLUpdateModelAction.java | 75 ++ .../ml/settings/MLCommonsSettings.java | 3 +- .../ml/settings/MLFeatureEnabledSetting.java | 1 + .../ml/task/MLPredictTaskRunner.java | 28 +- .../DeleteConnectorTransportActionTests.java | 14 +- .../GetConnectorTransportActionTests.java | 2 +- ... UpdateConnectorTransportActionTests.java} | 156 +++- .../TransportUpdateModelGroupActionTests.java | 26 +- .../UpdateModelTransportActionTests.java | 847 ++++++++++++++++++ .../TransportRegisterModelActionTests.java | 124 ++- ...TransportRegisterModelMetaActionTests.java | 91 +- .../ConnectorAccessControlHelperTests.java | 6 +- .../ml/model/MLModelCacheHelperTests.java | 12 + .../ml/model/MLModelGroupManagerTests.java | 88 +- .../RestMLUpdateConnectorActionTests.java | 33 +- .../ml/rest/RestMLUpdateModelActionTests.java | 191 ++++ .../ml/task/MLPredictTaskRunnerTests.java | 28 +- ...search-ml-common.release-notes-2.11.0.0.md | 1 + 64 files changed, 3595 insertions(+), 356 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java rename plugin/src/test/java/org/opensearch/ml/action/connector/{TransportUpdateConnectorActionTests.java => UpdateConnectorTransportActionTests.java} (69%) create mode 100644 plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 3b981bcf20..84c3a96712 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -18,6 +18,7 @@ public class CommonValue { public static Integer NO_SCHEMA_VERSION = 0; + public static final String REMOTE_SERVICE_ERROR = "Error from remote service: "; public static final String USER = "user"; public static final String META = "_meta"; public static final String SCHEMA_VERSION_FIELD = "schema_version"; diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index ae28066d5a..78f6f4ac60 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -191,7 +191,11 @@ public MLModel(StreamInput input) throws IOException{ modelContentSizeInBytes = input.readOptionalLong(); modelContentHash = input.readOptionalString(); if (input.readBoolean()) { - modelConfig = new TextEmbeddingModelConfig(input); + if (algorithm.equals(FunctionName.METRICS_CORRELATION)) { + modelConfig = new MetricsCorrelationModelConfig(input); + } else { + modelConfig = new TextEmbeddingModelConfig(input); + } } createdTime = input.readOptionalInstant(); lastUpdateTime = input.readOptionalInstant(); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index b3f9aafad8..0652a83421 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -30,6 +30,7 @@ import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.gson; @@ -69,6 +70,7 @@ public interface Connector extends ToXContentObject, Writeable { void writeTo(StreamOutput out) throws IOException; + void update(MLCreateConnectorInput updateContent, Function function); void parseResponse(T orElse, List modelTensors, boolean b) throws IOException; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 96f6c018ec..ef0e4bf4a1 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -34,6 +34,7 @@ import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import static org.opensearch.ml.common.utils.StringUtils.isJson; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; @Log4j2 @NoArgsConstructor @@ -248,6 +249,38 @@ public void writeTo(StreamOutput out) throws IOException { } } + @Override + public void update(MLCreateConnectorInput updateContent, Function function) { + if (updateContent.getName() != null) { + this.name = updateContent.getName(); + } + if (updateContent.getDescription() != null) { + this.description = updateContent.getDescription(); + } + if (updateContent.getVersion() != null) { + this.version = updateContent.getVersion(); + } + if (updateContent.getProtocol() != null) { + this.protocol = updateContent.getProtocol(); + } + if (updateContent.getParameters() != null && updateContent.getParameters().size() > 0) { + this.parameters = updateContent.getParameters(); + } + if (updateContent.getCredential() != null && updateContent.getCredential().size() > 0) { + this.credential = updateContent.getCredential(); + encrypt(function); + } + if (updateContent.getActions() != null) { + this.actions = updateContent.getActions(); + } + if (updateContent.getBackendRoles() != null) { + this.backendRoles = updateContent.getBackendRoles(); + } + if (updateContent.getAccess() != null) { + this.access = updateContent.getAccess(); + } + } + @Override public T createPredictPayload(Map parameters) { Optional predictAction = findPredictAction(); diff --git a/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java index 4f26e4b4d2..e1c9203cae 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java @@ -8,6 +8,7 @@ import lombok.Builder; import lombok.Getter; import lombok.Setter; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; @@ -28,6 +29,10 @@ public MetricsCorrelationModelConfig(String modelType, String allConfig) { super(modelType, allConfig); } + public MetricsCorrelationModelConfig(StreamInput in) throws IOException{ + super(in); + } + @Override public String getWriteableName() { return PARSE_FIELD_NAME; diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java index 9073345550..03b0ce5fca 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java @@ -7,6 +7,7 @@ import lombok.Builder; import lombok.Getter; +import lombok.Setter; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; @@ -24,7 +25,10 @@ @Getter public class ModelTensors implements Writeable, ToXContentObject { public static final String OUTPUT_FIELD = "output"; + public static final String STATUS_CODE_FIELD = "status_code"; private List mlModelTensors; + @Setter + private Integer statusCode; @Builder public ModelTensors(List mlModelTensors) { @@ -41,6 +45,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endArray(); } + if (statusCode != null) { + builder.field(STATUS_CODE_FIELD, statusCode); + } builder.endObject(); return builder; } @@ -53,6 +60,7 @@ public ModelTensors(StreamInput in) throws IOException { mlModelTensors.add(new ModelTensor(in)); } } + statusCode = in.readOptionalInt(); } @Override @@ -66,6 +74,7 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalInt(statusCode); } public void filter(ModelResultFilter resultFilter) { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index 9d5f7c88de..9d9879daec 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -56,6 +56,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable { private Boolean addAllBackendRoles; private AccessMode access; private boolean dryRun = false; + private boolean updateConnector = false; @Builder(toBuilder = true) public MLCreateConnectorInput(String name, @@ -68,9 +69,10 @@ public MLCreateConnectorInput(String name, List backendRoles, Boolean addAllBackendRoles, AccessMode access, - boolean dryRun + boolean dryRun, + boolean updateConnector ) { - if (!dryRun) { + if (!dryRun && !updateConnector) { if (name == null) { throw new IllegalArgumentException("Connector name is null"); } @@ -92,9 +94,14 @@ public MLCreateConnectorInput(String name, this.addAllBackendRoles = addAllBackendRoles; this.access = access; this.dryRun = dryRun; + this.updateConnector = updateConnector; } public static MLCreateConnectorInput parse(XContentParser parser) throws IOException { + return parse(parser, false); + } + + public static MLCreateConnectorInput parse(XContentParser parser, boolean updateConnector) throws IOException { String name = null; String description = null; String version = null; @@ -159,7 +166,7 @@ public static MLCreateConnectorInput parse(XContentParser parser) throws IOExcep break; } } - return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun); + return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun, updateConnector); } @Override @@ -201,10 +208,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public void writeTo(StreamOutput output) throws IOException { - output.writeString(name); + output.writeOptionalString(name); output.writeOptionalString(description); - output.writeString(version); - output.writeString(protocol); + output.writeOptionalString(version); + output.writeOptionalString(protocol); if (parameters != null) { output.writeBoolean(true); output.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString); @@ -240,13 +247,14 @@ public void writeTo(StreamOutput output) throws IOException { output.writeBoolean(false); } output.writeBoolean(dryRun); + output.writeBoolean(updateConnector); } public MLCreateConnectorInput(StreamInput input) throws IOException { - name = input.readString(); + name = input.readOptionalString(); description = input.readOptionalString(); - version = input.readString(); - protocol = input.readString(); + version = input.readOptionalString(); + protocol = input.readOptionalString(); if (input.readBoolean()) { parameters = input.readMap(s -> s.readString(), s -> s.readString()); } @@ -268,5 +276,6 @@ public MLCreateConnectorInput(StreamInput input) throws IOException { this.access = input.readEnum(AccessMode.class); } dryRun = input.readBoolean(); + updateConnector = input.readBoolean(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java index ced3646d13..089180cdc5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java @@ -19,17 +19,16 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; -import java.util.Map; import static org.opensearch.action.ValidateActions.addValidationError; @Getter public class MLUpdateConnectorRequest extends ActionRequest { String connectorId; - Map updateContent; + MLCreateConnectorInput updateContent; @Builder - public MLUpdateConnectorRequest(String connectorId, Map updateContent) { + public MLUpdateConnectorRequest(String connectorId, MLCreateConnectorInput updateContent) { this.connectorId = connectorId; this.updateContent = updateContent; } @@ -37,14 +36,14 @@ public MLUpdateConnectorRequest(String connectorId, Map updateCo public MLUpdateConnectorRequest(StreamInput in) throws IOException { super(in); this.connectorId = in.readString(); - this.updateContent = in.readMap(); + this.updateContent = new MLCreateConnectorInput(in); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(this.connectorId); - out.writeMap(this.getUpdateContent()); + this.updateContent.writeTo(out); } @Override @@ -55,14 +54,17 @@ public ActionRequestValidationException validate() { exception = addValidationError("ML connector id can't be null", exception); } + if (updateContent == null) { + exception = addValidationError("Update connector content can't be null", exception); + } + return exception; } public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId) throws IOException { - Map dataAsMap = null; - dataAsMap = parser.map(); + MLCreateConnectorInput updateContent = MLCreateConnectorInput.parse(parser, true); - return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(dataAsMap).build(); + return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(updateContent).build(); } public static MLUpdateConnectorRequest fromActionRequest(ActionRequest actionRequest) { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelAction.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelAction.java new file mode 100644 index 0000000000..2d584a0e73 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import org.opensearch.action.ActionType; +import org.opensearch.action.update.UpdateResponse; + +public class MLUpdateModelAction extends ActionType { + public static MLUpdateModelAction INSTANCE = new MLUpdateModelAction(); + public static final String NAME = "cluster:admin/opensearch/ml/models/update"; + + private MLUpdateModelAction() { + super(NAME, UpdateResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java new file mode 100644 index 0000000000..ca0a2f70d4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import lombok.Data; +import lombok.Builder; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; + +import java.io.IOException; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.connector.Connector.createConnector; + +@Data +public class MLUpdateModelInput implements ToXContentObject, Writeable { + + public static final String MODEL_ID_FIELD = "model_id"; // mandatory + public static final String DESCRIPTION_FIELD = "description"; // optional + public static final String MODEL_VERSION_FIELD = "model_version"; // optional + public static final String MODEL_NAME_FIELD = "name"; // optional + public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // optional + public static final String MODEL_CONFIG_FIELD = "model_config"; // optional + public static final String CONNECTOR_ID_FIELD = "connector_id"; // optional + + @Getter + private String modelId; + private String description; + private String version; + private String name; + private String modelGroupId; + private MLModelConfig modelConfig; + private String connectorId; + + @Builder(toBuilder = true) + public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, MLModelConfig modelConfig, String connectorId) { + this.modelId = modelId; + this.description = description; + this.version = version; + this.name = name; + this.modelGroupId = modelGroupId; + this.modelConfig = modelConfig; + this.connectorId = connectorId; + } + + public MLUpdateModelInput(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.description = in.readOptionalString(); + this.version = in.readOptionalString(); + this.name = in.readOptionalString(); + this.modelGroupId = in.readOptionalString(); + if (in.readBoolean()) { + modelConfig = new TextEmbeddingModelConfig(in); + } + this.connectorId = in.readOptionalString(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID_FIELD, modelId); + if (name != null) { + builder.field(MODEL_NAME_FIELD, name); + } + if (description != null) { + builder.field(DESCRIPTION_FIELD, description); + } + if (version != null) { + builder.field(MODEL_VERSION_FIELD, version); + } + if (modelGroupId != null) { + builder.field(MODEL_GROUP_ID_FIELD, modelGroupId); + } + if (modelConfig != null) { + builder.field(MODEL_CONFIG_FIELD, modelConfig); + } + if (connectorId != null) { + builder.field(CONNECTOR_ID_FIELD, connectorId); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeOptionalString(description); + out.writeOptionalString(version); + out.writeOptionalString(name); + out.writeOptionalString(modelGroupId); + if (modelConfig != null) { + out.writeBoolean(true); + modelConfig.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(connectorId); + } + + public static MLUpdateModelInput parse(XContentParser parser) throws IOException { + String modelId = null; + String description = null; + String version = null; + String name = null; + String modelGroupId = null; + MLModelConfig modelConfig = null; + String connectorId = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case MODEL_ID_FIELD: + modelId = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case MODEL_NAME_FIELD: + name = parser.text(); + break; + case MODEL_VERSION_FIELD: + version = parser.text(); + break; + case MODEL_GROUP_ID_FIELD: + modelGroupId = parser.text(); + break; + case MODEL_CONFIG_FIELD: + modelConfig = TextEmbeddingModelConfig.parse(parser); + break; + case CONNECTOR_ID_FIELD: + connectorId = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + // Model ID can only be set through RestRequest. Model version can only be set automatically. + return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, modelConfig, connectorId); + } +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java new file mode 100644 index 0000000000..b589f71ed4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLUpdateModelRequest extends ActionRequest { + + MLUpdateModelInput updateModelInput; + + @Builder + public MLUpdateModelRequest(MLUpdateModelInput updateModelInput) { + this.updateModelInput = updateModelInput; + } + + public MLUpdateModelRequest(StreamInput in) throws IOException { + super(in); + updateModelInput = new MLUpdateModelInput(in); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (updateModelInput == null) { + exception = addValidationError("Update Model Input can't be null", exception); + } + + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + this.updateModelInput.writeTo(out); + } + + public static MLUpdateModelRequest fromActionRequest(ActionRequest actionRequest){ + if (actionRequest instanceof MLUpdateModelRequest) { + return (MLUpdateModelRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput in = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLUpdateModelRequest(in); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLUpdateModelRequest", e); + } + } +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index e79a09c5b2..a871332f95 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -18,6 +18,7 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import java.io.IOException; @@ -52,7 +53,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { public static final String ACCESS_MODE_FIELD = "access_mode"; public static final String BACKEND_ROLES_FIELD = "backend_roles"; public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles"; - + public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group"; private FunctionName functionName; private String modelName; private String modelGroupId; @@ -72,6 +73,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private List backendRoles; private Boolean addAllBackendRoles; private AccessMode accessMode; + private Boolean doesVersionCreateModelGroup; @Builder(toBuilder = true) public MLRegisterModelInput(FunctionName functionName, @@ -89,7 +91,8 @@ public MLRegisterModelInput(FunctionName functionName, String connectorId, List backendRoles, Boolean addAllBackendRoles, - AccessMode accessMode + AccessMode accessMode, + Boolean doesVersionCreateModelGroup ) { if (functionName == null) { this.functionName = FunctionName.TEXT_EMBEDDING; @@ -122,6 +125,7 @@ public MLRegisterModelInput(FunctionName functionName, this.backendRoles = backendRoles; this.addAllBackendRoles = addAllBackendRoles; this.accessMode = accessMode; + this.doesVersionCreateModelGroup = doesVersionCreateModelGroup; } @@ -137,7 +141,11 @@ public MLRegisterModelInput(StreamInput in) throws IOException { this.modelFormat = in.readEnum(MLModelFormat.class); } if (in.readBoolean()) { - this.modelConfig = new TextEmbeddingModelConfig(in); + if (this.functionName.equals(FunctionName.METRICS_CORRELATION)) { + this.modelConfig = new MetricsCorrelationModelConfig(in); + } else { + this.modelConfig = new TextEmbeddingModelConfig(in); + } } this.deployModel = in.readBoolean(); this.modelNodeIds = in.readOptionalStringArray(); @@ -152,6 +160,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException { if (in.readBoolean()) { this.accessMode = in.readEnum(AccessMode.class); } + this.doesVersionCreateModelGroup = in.readOptionalBoolean(); } @Override @@ -197,6 +206,7 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalBoolean(doesVersionCreateModelGroup); } @Override @@ -244,6 +254,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (accessMode != null) { builder.field(ACCESS_MODE_FIELD, accessMode); } + if (doesVersionCreateModelGroup != null) { + builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup); + } builder.endObject(); return builder; } @@ -262,6 +275,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName List backendRoles = new ArrayList<>(); Boolean addAllBackendRoles = null; AccessMode accessMode = null; + Boolean doesVersionCreateModelGroup = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -313,12 +327,15 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName case ACCESS_MODE_FIELD: accessMode = AccessMode.from(parser.text()); break; + case DOES_VERSION_CREATE_MODEL_GROUP: + doesVersionCreateModelGroup = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode); + return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup); } public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException { @@ -337,6 +354,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo List backendRoles = new ArrayList<>(); AccessMode accessMode = null; Boolean addAllBackendRoles = null; + Boolean doesVersionCreateModelGroup = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -395,11 +413,14 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo case ACCESS_MODE_FIELD: accessMode = AccessMode.from(parser.text()); break; + case DOES_VERSION_CREATE_MODEL_GROUP: + doesVersionCreateModelGroup = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode); + return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java index d23f00caf7..ecb03d9bb6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java @@ -46,6 +46,8 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional public static final String ACCESS_MODE = "access_mode"; //optional public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional + public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group"; + private FunctionName functionName; private String name; @@ -65,11 +67,13 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ private List backendRoles; private AccessMode accessMode; private Boolean isAddAllBackendRoles; + private Boolean doesVersionCreateModelGroup; @Builder(toBuilder = true) public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List backendRoles, AccessMode accessMode, - Boolean isAddAllBackendRoles) { + Boolean isAddAllBackendRoles, + Boolean doesVersionCreateModelGroup) { if (name == null) { throw new IllegalArgumentException("model name is null"); } @@ -103,6 +107,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m this.backendRoles = backendRoles; this.accessMode = accessMode; this.isAddAllBackendRoles = isAddAllBackendRoles; + this.doesVersionCreateModelGroup = doesVersionCreateModelGroup; } public MLRegisterModelMetaInput(StreamInput in) throws IOException{ @@ -128,6 +133,7 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{ accessMode = in.readEnum(AccessMode.class); } this.isAddAllBackendRoles = in.readOptionalBoolean(); + this.doesVersionCreateModelGroup = in.readOptionalBoolean(); } @Override @@ -171,6 +177,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalBoolean(isAddAllBackendRoles); + out.writeOptionalBoolean(doesVersionCreateModelGroup); } @Override @@ -206,6 +213,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (isAddAllBackendRoles != null) { builder.field(ADD_ALL_BACKEND_ROLES, isAddAllBackendRoles); } + if (doesVersionCreateModelGroup != null) { + builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup); + } builder.endObject(); return builder; } @@ -225,6 +235,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc List backendRoles = null; AccessMode accessMode = null; Boolean isAddAllBackendRoles = null; + Boolean doesVersionCreateModelGroup = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -277,12 +288,15 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc case ADD_ALL_BACKEND_ROLES: isAddAllBackendRoles = parser.booleanValue(); break; + case DOES_VERSION_CREATE_MODEL_GROUP: + doesVersionCreateModelGroup = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles); + return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup); } } diff --git a/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java new file mode 100644 index 0000000000..4700039939 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; + +import java.io.IOException; +import java.util.function.Function; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +public class MetricsCorrelationModelConfigTests { + + MetricsCorrelationModelConfig config; + Function function; + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() { + config = MetricsCorrelationModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); + function = parser -> { + try { + return MetricsCorrelationModelConfig.parse(parser); + } catch (IOException e) { + throw new RuntimeException("Failed to parse MetricsCorrelationModelConfig", e); + } + }; + } + + @Test + public void toXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + config.toXContent(builder, EMPTY_PARAMS); + String configContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent); + } + + @Test + public void nullFields_ModelType() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("model type is null"); + config = MetricsCorrelationModelConfig.builder() + .build(); + } + + @Test + public void parse() throws IOException { + String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; + TestHelper.testParseFromString(config, content, function); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(config); + } + + public void readInputStream(MetricsCorrelationModelConfig config) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + config.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MetricsCorrelationModelConfig parsedConfig = new MetricsCorrelationModelConfig(streamInput); + assertEquals(config.getModelType(), parsedConfig.getModelType()); + assertEquals(config.getAllConfig(), parsedConfig.getAllConfig()); + assertEquals(config.getWriteableName(), parsedConfig.getWriteableName()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java index e017009983..44e970f95c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java @@ -7,38 +7,37 @@ import org.junit.Before; import org.junit.Test; -import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.rest.RestRequest; +import org.opensearch.search.SearchModule; import java.io.IOException; import java.io.UncheckedIOException; -import java.util.Map; +import java.util.Collections; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; -import static org.mockito.Mockito.when; +import static org.junit.Assert.assertTrue; public class MLUpdateConnectorRequestTests { private String connectorId; - private Map updateContent; + private MLCreateConnectorInput updateContent; private MLUpdateConnectorRequest mlUpdateConnectorRequest; - @Mock - XContentParser parser; - @Before public void setUp() { MockitoAnnotations.openMocks(this); this.connectorId = "test-connector_id"; - this.updateContent = Map.of("description", "new description"); + this.updateContent = MLCreateConnectorInput.builder().description("new description").updateConnector(true).build(); mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() .connectorId(connectorId) .updateContent(updateContent) @@ -64,18 +63,20 @@ public void validate_Exception_NullConnectorId() { MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.builder().build(); Exception exception = updateConnectorRequest.validate(); - assertEquals("Validation Failed: 1: ML connector id can't be null;", exception.getMessage()); + assertEquals("Validation Failed: 1: ML connector id can't be null;2: Update connector content can't be null;", exception.getMessage()); } @Test public void parse_success() throws IOException { - RestRequest.Method method = RestRequest.Method.POST; - final Map updatefields = Map.of("version", "new version", "description", "new description"); - when(parser.map()).thenReturn(updatefields); - + String jsonStr = "{\"version\":\"new version\",\"description\":\"new description\"}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId); assertEquals(updateConnectorRequest.getConnectorId(), connectorId); - assertEquals(updateConnectorRequest.getUpdateContent(), updatefields); + assertTrue(updateConnectorRequest.getUpdateContent().isUpdateConnector()); + assertEquals("new version", updateConnectorRequest.getUpdateContent().getVersion()); + assertEquals("new description", updateConnectorRequest.getUpdateContent().getDescription()); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java new file mode 100644 index 0000000000..6bafe81692 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -0,0 +1,163 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.function.Consumer; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; + +public class MLUpdateModelInputTest { + + private MLUpdateModelInput updateModelInput; + private final String expectedInputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; + private final String expectedInputStrWithNullField = "{\"model_id\":\"test-model_id\",\"name\":null,\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; + private final String expectedOutputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; + private final String expectedInputStrWithIllegalField = "{\"model_id\":\"test-model_id\",\"description\":\"description\",\"model_version\":\"2\",\"name\":\"name\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\",\"illegal_field\":\"This field need to be skipped.\"}"; + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() throws Exception { + + MLModelConfig config = TextEmbeddingModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + updateModelInput = MLUpdateModelInput.builder() + .modelId("test-model_id") + .modelGroupId("modelGroupId") + .version("2") + .name("name") + .description("description") + .modelConfig(config) + .connectorId("test-connector_id") + .build(); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(updateModelInput, parsedInput -> { + assertEquals("test-model_id", parsedInput.getModelId()); + assertEquals(updateModelInput.getName(), parsedInput.getName()); + }); + } + + @Test + public void readInputStream_SuccessWithNullFields() throws IOException { + updateModelInput.setModelConfig(null); + readInputStream(updateModelInput, parsedInput -> { + assertNull(parsedInput.getModelConfig()); + }); + } + + @Test + public void testToXContent() throws Exception { + String jsonStr = serializationWithToXContent(updateModelInput); + assertEquals(expectedInputStr, jsonStr); + } + + @Test + public void testToXContent_Incomplete() throws Exception { + String expectedIncompleteInputStr = + "{\"model_id\":\"test-model_id\"}"; + updateModelInput.setDescription(null); + updateModelInput.setVersion(null); + updateModelInput.setName(null); + updateModelInput.setModelGroupId(null); + updateModelInput.setModelConfig(null); + updateModelInput.setConnectorId(null); + String jsonStr = serializationWithToXContent(updateModelInput); + assertEquals(expectedIncompleteInputStr, jsonStr); + } + + @Test + public void parse_Success() throws Exception { + testParseFromJsonString(expectedInputStr, parsedInput -> { + assertEquals("name", parsedInput.getName()); + }); + } + + @Test + public void parse_WithNullFieldWithoutModel() throws Exception { + exceptionRule.expect(IllegalStateException.class); + testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { + try { + assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void parse_WithIllegalFieldWithoutModel() throws Exception { + testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { + try { + assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + private void testParseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + parser.nextToken(); + MLUpdateModelInput parsedInput = MLUpdateModelInput.parse(parser); + verify.accept(parsedInput); + } + + private void readInputStream(MLUpdateModelInput input, Consumer verify) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + input.writeTo(bytesStreamOutput); + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLUpdateModelInput parsedInput = new MLUpdateModelInput(streamInput); + verify.accept(parsedInput); + } + + private String serializationWithToXContent(MLUpdateModelInput input) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + return builder.toString(); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java new file mode 100644 index 0000000000..cadf865b1c --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.model; + +import org.junit.Before; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; + +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + + +public class MLUpdateModelRequestTest { + + private MLUpdateModelRequest updateModelRequest; + + @Before + public void setUp(){ + MockitoAnnotations.openMocks(this); + + MLModelConfig config = TextEmbeddingModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + MLUpdateModelInput updateModelInput = MLUpdateModelInput.builder() + .modelId("test-model_id") + .modelGroupId("modelGroupId") + .name("name") + .description("description") + .modelConfig(config) + .build(); + + updateModelRequest = MLUpdateModelRequest.builder() + .updateModelInput(updateModelInput) + .build(); + + } + + @Test + public void writeTo_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + updateModelRequest.writeTo(bytesStreamOutput); + MLUpdateModelRequest parsedUpdateRequest = new MLUpdateModelRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals("test-model_id", parsedUpdateRequest.getUpdateModelInput().getModelId()); + assertEquals("name", parsedUpdateRequest.getUpdateModelInput().getName()); + } + + @Test + public void validate_Success() { + assertNull(updateModelRequest.validate()); + } + + @Test + public void validate_Exception_NullModelInput() { + MLUpdateModelRequest updateModelRequest = MLUpdateModelRequest.builder().build(); + Exception exception = updateModelRequest.validate(); + + assertEquals("Validation Failed: 1: Update Model Input can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequest_Success() { + assertSame(MLUpdateModelRequest.fromActionRequest(updateModelRequest), updateModelRequest); + } + + @Test + public void fromActionRequest_Success_fromActionRequest() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + updateModelRequest.writeTo(out); + } + }; + MLUpdateModelRequest request = MLUpdateModelRequest.fromActionRequest(actionRequest); + assertNotSame(request, updateModelRequest); + assertEquals(updateModelRequest.getUpdateModelInput().getName(), request.getUpdateModelInput().getName()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLUpdateModelRequest.fromActionRequest(actionRequest); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java index 24a409bd44..6de3788243 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java @@ -22,6 +22,7 @@ import org.opensearch.ml.common.connector.HttpConnectorTest; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.search.SearchModule; @@ -74,7 +75,6 @@ public void setUp() throws Exception { .deployModel(true) .modelNodeIds(new String[]{"modelNodeIds" }) .build(); - } @Test @@ -257,6 +257,59 @@ public void readInputStream_WithInternalConnector() throws IOException { }); } + @Test + public void testMCorrInput() throws IOException { + String testString = "{\"function_name\":\"METRICS_CORRELATION\",\"name\":\"METRICS_CORRELATION\",\"version\":\"1.0.0b1\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"TORCH_SCRIPT\",\"model_config\":{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; + + MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); + + MLRegisterModelInput mcorrInput = MLRegisterModelInput.builder() + .functionName(FunctionName.METRICS_CORRELATION) + .modelName(FunctionName.METRICS_CORRELATION.name()) + .version("1.0.0b1") + .modelGroupId(modelGroupId) + .url(url) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(mcorrConfig) + .deployModel(true) + .modelNodeIds(new String[]{"modelNodeIds" }) + .build(); + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + mcorrInput.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + assertEquals(testString, jsonStr); + } + + @Test + public void readInputStream_MCorr() throws IOException { + MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); + + MLRegisterModelInput mcorrInput = MLRegisterModelInput.builder() + .functionName(FunctionName.METRICS_CORRELATION) + .modelName(FunctionName.METRICS_CORRELATION.name()) + .version("1.0.0b1") + .modelGroupId(modelGroupId) + .url(url) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(mcorrConfig) + .deployModel(true) + .modelNodeIds(new String[]{"modelNodeIds" }) + .build(); + readInputStream(mcorrInput, parsedInput -> { + assertEquals(parsedInput.getModelConfig().getModelType(), mcorrConfig.getModelType()); + assertEquals(parsedInput.getModelConfig().getAllConfig(), mcorrConfig.getAllConfig()); + assertEquals(parsedInput.getFunctionName(), FunctionName.METRICS_CORRELATION); + assertEquals(parsedInput.getModelName(), FunctionName.METRICS_CORRELATION.name()); + assertEquals(parsedInput.getModelGroupId(), modelGroupId); + }); + } + private void readInputStream(MLRegisterModelInput input, Consumer verify) throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); input.writeTo(bytesStreamOutput); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java index c9ace159ee..61e57d4ac6 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java @@ -43,7 +43,7 @@ public void setup() { config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "model_group_id", "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null); + "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java index 0c3a432d94..d7039780f0 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java @@ -33,7 +33,7 @@ public void setUp() { config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id", "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null); + "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null); } @Test diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java index 4e6bd6bd69..fab052da19 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java @@ -120,7 +120,7 @@ public void close() { * @param modelId id of the model * @param modelName name of the model * @param version version of the model - * @param engine engine where model will be run. For now we are supporting only pytorch engine only. + * @param engine engine where model will be run. For now, we are supporting only pytorch engine only. */ private void loadModel(File modelZipFile, String modelId, String modelName, String version, String engine) { try { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java index 96aad3168e..ec2fc1d141 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java @@ -121,7 +121,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { if (modelId == null) { boolean hasModelGroupIndex = clusterService.state().getMetadata().hasIndex(ML_MODEL_GROUP_INDEX); - if (!hasModelGroupIndex) { // Create model group index if doesn't exist + if (!hasModelGroupIndex) { // Create model group index if it doesn't exist try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { CreateIndexRequest request = new CreateIndexRequest(ML_MODEL_GROUP_INDEX).mapping(ML_MODEL_GROUP_INDEX_MAPPING); CreateIndexResponse createIndexResponse = client.admin().indices().create(request).actionGet(1000); @@ -176,7 +176,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { ); } }, e -> { log.error("Failed to get model", e); }); - client.get(getModelRequest, ActionListener.runBefore(listener, () -> context.restore())); + client.get(getModelRequest, ActionListener.runBefore(listener, context::restore)); } } } else { @@ -197,10 +197,24 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { waitUntil(() -> { if (modelId != null) { MLModelState modelState = getModel(modelId).getModelState(); - return modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED; + if (modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED) { + log.info("Model deployed: " + modelState); + return true; + } else if (modelState == MLModelState.UNDEPLOYED || modelState == MLModelState.DEPLOY_FAILED) { + log.info("Model not deployed: " + modelState); + deployModel( + modelId, + ActionListener + .wrap( + deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId()).getModelId(), + e -> log.error("Metrics correlation model didn't get deployed to the index successfully", e) + ) + ); + return false; + } } return false; - }, 10, TimeUnit.SECONDS); + }, 120, TimeUnit.SECONDS); Output djlOutput; try { @@ -253,7 +267,7 @@ void registerModel(ActionListener listener) throws Inte log.error("Failed to Register Model", e); listener.onFailure(e); })); - }, e -> { listener.onFailure(e); }), () -> context.restore())); + }, listener::onFailure), context::restore)); } catch (IOException e) { throw new MLException(e); } @@ -322,6 +336,8 @@ public static boolean waitUntil(BooleanSupplier breakSupplier, long maxWaitTime, } sum += timeInMillis; timeInMillis = Math.min(AWAIT_BUSY_THRESHOLD, timeInMillis * 2); + + log.info("Waiting... Time elapsed: " + sum + "ms"); } timeInMillis = maxTimeInMillis - sum; try { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 2c18b363f8..1472e9bbc9 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -5,6 +5,7 @@ package org.opensearch.ml.engine.algorithms.remote; +import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput; import static software.amazon.awssdk.http.SdkHttpMethod.POST; @@ -86,6 +87,7 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S HttpExecuteResponse response = AccessController.doPrivileged((PrivilegedExceptionAction) () -> { return httpClient.prepareRequest(executeRequest).call(); }); + int statusCode = response.httpResponse().statusCode(); AbortableInputStream body = null; if (response.responseBody().isPresent()) { @@ -104,8 +106,12 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST); } String modelResponse = responseBuilder.toString(); + if (statusCode < 200 || statusCode >= 300) { + throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode)); + } ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); + tensors.setStatusCode(statusCode); tensorOutputs.add(tensors); } catch (RuntimeException exception) { log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 92a44a3d91..88c43a969e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -115,7 +115,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput( docs.add(null); } } - if (preProcessFunction.contains("${parameters")) { + if (preProcessFunction.contains("${parameters.")) { StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); preProcessFunction = substitutor.replace(preProcessFunction); } @@ -186,7 +186,9 @@ public static ModelTensors processOutput( // execute user defined painless script. Optional processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse); String response = processedResponse.orElse(modelResponse); - boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent(); + boolean scriptReturnModelTensor = postProcessFunction != null + && processedResponse.isPresent() + && org.opensearch.ml.common.utils.StringUtils.isJson(response); if (responseFilter == null) { connector.parseResponse(response, modelTensors, scriptReturnModelTensor); } else { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 5337fd9948..d08b2186ae 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -5,6 +5,7 @@ package org.opensearch.ml.engine.algorithms.remote; +import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput; @@ -23,6 +24,8 @@ import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.util.EntityUtils; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.exception.MLException; @@ -54,6 +57,7 @@ public HttpJsonConnectorExecutor(Connector connector) { public void invokeRemoteModel(MLInput mlInput, Map parameters, String payload, List tensorOutputs) { try { AtomicReference responseRef = new AtomicReference<>(""); + AtomicReference statusCodeRef = new AtomicReference<>(); HttpUriRequest request; switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) { @@ -98,12 +102,18 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S String responseBody = EntityUtils.toString(responseEntity); EntityUtils.consume(responseEntity); responseRef.set(responseBody); + statusCodeRef.set(response.getStatusLine().getStatusCode()); } return null; }); String modelResponse = responseRef.get(); + Integer statusCode = statusCodeRef.get(); + if (statusCode < 200 || statusCode >= 300) { + throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode)); + } ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); + tensors.setStatusCode(statusCode); tensorOutputs.add(tensors); } catch (RuntimeException e) { log.error("Fail to execute http connector", e); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 0ff8f9a91e..aa471ba3fe 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -32,15 +32,25 @@ default ModelTensorOutput executePredict(MLInput mlInput) { if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); - List textDocs = new ArrayList<>(textDocsInputDataSet.getDocs()); - preparePayloadAndInvokeRemoteModel( - MLInput - .builder() - .algorithm(FunctionName.TEXT_EMBEDDING) - .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()) - .build(), - tensorOutputs - ); + int processedDocs = 0; + while (processedDocs < textDocsInputDataSet.getDocs().size()) { + List textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size()); + List tempTensorOutputs = new ArrayList<>(); + preparePayloadAndInvokeRemoteModel( + MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()) + .build(), + tempTensorOutputs + ); + int tensorCount = 0; + if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) { + tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size(); + } + processedDocs += Math.max(tensorCount, 1); + tensorOutputs.addAll(tempTensorOutputs); + } } else { preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java index 02132687e3..223cb22289 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java @@ -18,6 +18,8 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.engine.algorithms.DLModel.ML_ENGINE; import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_HELPER; import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_ZIP_FILE; @@ -31,9 +33,12 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.lucene.search.TotalHits; import org.junit.Before; @@ -43,16 +48,27 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.Version; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.action.ActionFuture; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; @@ -97,19 +113,25 @@ import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.suggest.Suggest; +import org.opensearch.threadpool.ThreadPool; -//TODO: fix mockito error: Cannot mock/spy class org.opensearch.common.settings.Settings final class +import com.google.common.collect.ImmutableMap; -@Ignore public class MetricsCorrelationTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @Mock Client client; - @Mock Settings settings; + @Mock private ClusterService clusterService; + + @Mock + ThreadPool threadPool; + + ThreadContext threadContext; + @Mock SearchRequest searchRequest; SearchResponse searchResponse; @@ -142,6 +164,8 @@ public class MetricsCorrelationTest { private final String modelId = "modelId"; private final String modelGroupId = "modelGroupId"; + final String USER_STRING = "myuser|role1,role2|myTenant"; + MLTask mlTask; Map params = new HashMap<>(); @@ -180,6 +204,16 @@ public void setUp() throws IOException, URISyntaxException { MockitoAnnotations.openMocks(this); metricsCorrelation = spy(new MetricsCorrelation(client, settings, clusterService)); + + settings = Settings.builder().build(); + ClusterState testClusterState = setupTestClusterState(); + when(clusterService.state()).thenReturn(testClusterState); + + threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + List inputData = new ArrayList<>(); inputData.add(new float[] { -1.0f, 2.0f, 3.0f }); inputData.add(new float[] { -1.0f, 2.0f, 3.0f }); @@ -261,16 +295,27 @@ public void setUp() throws IOException, URISyntaxException { extendedInput = MetricsCorrelationInput.builder().inputData(extendedInputData).build(); } + @Ignore @Test public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteException { - metricsCorrelation.initModel(model, params); MLModelGetResponse response = new MLModelGetResponse(model); ActionFuture mockedFuture = mock(ActionFuture.class); when(client.execute(any(MLModelGetAction.class), any(MLModelGetRequest.class))).thenReturn(mockedFuture); when(mockedFuture.actionGet(anyLong())).thenReturn(response); doAnswer(invocation -> { - MLModel smallModel = model.toBuilder().modelConfig(modelConfig).modelState(MLModelState.DEPLOYED).build(); + + MLModel smallModel = MLModel + .builder() + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .name(FunctionName.METRICS_CORRELATION.name()) + .modelId(modelId) + .modelGroupId(modelGroupId) + .algorithm(FunctionName.METRICS_CORRELATION) + .version(MCORR_ML_VERSION) + .modelConfig(modelConfig) + .modelState(MLModelState.UNDEPLOYED) + .build(); MLModelGetResponse responseTemp = new MLModelGetResponse(smallModel); ActionFuture mockedFutureTemp = mock(ActionFuture.class); MLTaskGetResponse taskResponse = new MLTaskGetResponse(mlTask); @@ -278,8 +323,8 @@ public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteExceptio when(client.execute(any(MLTaskGetAction.class), any(MLTaskGetRequest.class))).thenReturn(mockedFutureResponse); when(mockedFutureResponse.actionGet(anyLong())).thenReturn(taskResponse); when(mockedFutureTemp.actionGet(anyLong())).thenReturn(responseTemp); - metricsCorrelation.initModel(smallModel, params); + smallModel.toBuilder().modelState(MLModelState.DEPLOYED).build(); return null; }).when(client).execute(any(MLDeployModelAction.class), any(MLDeployModelRequest.class), isA(ActionListener.class)); @@ -289,6 +334,7 @@ public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteExceptio assertNull(mlModelOutputs.get(0).getMCorrModelTensors()); } + @Ignore @Test public void testExecuteWithModelInIndexAndEmptyOutput() throws ExecuteException, URISyntaxException { Map params = new HashMap<>(); @@ -349,6 +395,7 @@ public void testExecuteWithModelInIndexAndOneEvent() throws ExecuteException, UR assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } + @Ignore @Test public void testExecuteWithNoModelIndexAndOneEvent() throws ExecuteException, URISyntaxException { Map params = new HashMap<>(); @@ -389,6 +436,7 @@ public void testExecuteWithNoModelIndexAndOneEvent() throws ExecuteException, UR assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } + @Ignore @Test public void testExecuteWithModelInIndexAndInvokeDeployAndOneEvent() throws ExecuteException, URISyntaxException { Map params = new HashMap<>(); @@ -435,6 +483,7 @@ public void testExecuteWithModelInIndexAndInvokeDeployAndOneEvent() throws Execu assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } + @Ignore @Test public void testExecuteWithNoModelInIndexAndOneEvent() throws ExecuteException, URISyntaxException { Map params = new HashMap<>(); @@ -476,6 +525,7 @@ public void testExecuteWithNoModelInIndexAndOneEvent() throws ExecuteException, assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } + // working @Test public void testGetModel() { ActionFuture mockedFuture = mock(ActionFuture.class); @@ -508,6 +558,7 @@ public static XContentBuilder builder() throws IOException { return XContentBuilder.builder(XContentType.JSON.xContent()); } + // working @Test public void testSearchRequest() { String expectedIndex = CommonValue.ML_MODEL_INDEX; @@ -546,6 +597,7 @@ public void testSearchRequest() { assertEquals(MLModel.MODEL_VERSION_FIELD, versionQueryBuilder.fieldName()); } + @Ignore @Test public void testRegisterModel() throws InterruptedException { doAnswer(invocation -> { @@ -711,4 +763,61 @@ private SearchResponse createEmptySearchModelResponse() throws IOException { SearchResponse.Clusters.EMPTY ); } + + public static ClusterState setupTestClusterState() { + Set roleSet = new HashSet<>(); + roleSet.add(DiscoveryNodeRole.DATA_ROLE); + DiscoveryNode node = new DiscoveryNode( + "node", + new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()), + new HashMap<>(), + roleSet, + Version.CURRENT + ); + Metadata metadata = new Metadata.Builder() + .indices( + ImmutableMap + .builder() + .put( + ML_MODEL_INDEX, + IndexMetadata + .builder("test") + .settings( + Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id) + ) + .build() + ) + .put( + ML_MODEL_GROUP_INDEX, + IndexMetadata + .builder(ML_MODEL_GROUP_INDEX) + .settings( + Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id) + ) + .build() + ) + .build() + ) + .build(); + return new ClusterState( + new ClusterName("test cluster"), + 123l, + "111111", + metadata, + null, + DiscoveryNodes.builder().add(node).build(), + null, + Map.of(), + 0, + false + ); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 6ce1b00df6..b35f9b0eac 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -6,6 +6,7 @@ package org.opensearch.ml.engine.algorithms.remote; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; @@ -49,6 +50,7 @@ import software.amazon.awssdk.http.ExecutableHttpRequest; import software.amazon.awssdk.http.HttpExecuteResponse; import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.SdkHttpResponse; public class AwsConnectorExecutorTest { @@ -101,6 +103,49 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio exceptionRule.expectMessage("No response from model"); when(response.responseBody()).thenReturn(Optional.empty()); when(httpRequest.call()).thenReturn(response); + SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); + when(httpResponse.statusCode()).thenReturn(200); + when(response.httpResponse()).thenReturn(httpResponse); + when(httpClient.prepareRequest(any())).thenReturn(httpRequest); + + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map credential = ImmutableMap + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); + Connector connector = AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol("http") + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(predictAction)) + .build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + + MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); + executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + } + + @Test + public void executePredict_RemoteInferenceInput_InvalidToken() throws IOException { + exceptionRule.expect(OpenSearchStatusException.class); + exceptionRule.expectMessage("{\"message\":\"The security token included in the request is invalid\"}"); + String jsonString = "{\"message\":\"The security token included in the request is invalid\"}"; + InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes()); + AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); + when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); + when(httpRequest.call()).thenReturn(response); + SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); + when(httpResponse.statusCode()).thenReturn(403); + when(response.httpResponse()).thenReturn(httpResponse); when(httpClient.prepareRequest(any())).thenReturn(httpRequest); ConnectorAction predictAction = ConnectorAction @@ -135,6 +180,9 @@ public void executePredict_RemoteInferenceInput() throws IOException { InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes()); AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); + SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); + when(httpResponse.statusCode()).thenReturn(200); + when(response.httpResponse()).thenReturn(httpResponse); when(httpRequest.call()).thenReturn(response); when(httpClient.prepareRequest(any())).thenReturn(httpRequest); @@ -177,6 +225,9 @@ public void executePredict_TextDocsInferenceInput() throws IOException { AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); when(httpRequest.call()).thenReturn(response); + SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); + when(httpResponse.statusCode()).thenReturn(200); + when(response.httpResponse()).thenReturn(httpResponse); when(httpClient.prepareRequest(any())).thenReturn(httpRequest); ConnectorAction predictAction = ConnectorAction @@ -202,7 +253,7 @@ public void executePredict_TextDocsInferenceInput() throws IOException { connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); - MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input", "test input data")).build(); + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build(); ModelTensorOutput modelTensorOutput = executor .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 7a52d621f5..e91110b74e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -13,9 +13,12 @@ import java.util.Arrays; import org.apache.http.HttpEntity; +import org.apache.http.ProtocolVersion; +import org.apache.http.StatusLine; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.message.BasicStatusLine; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -23,6 +26,7 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.ingest.TestTemplateService; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; @@ -99,6 +103,8 @@ public void executePredict_RemoteInferenceInput() throws IOException { when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); when(response.getEntity()).thenReturn(entity); + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); ModelTensorOutput modelTensorOutput = executor @@ -125,6 +131,8 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); when(response.getEntity()).thenReturn(entity); + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); Connector connector = HttpConnector .builder() .name("test connector") @@ -137,7 +145,7 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); Assert @@ -147,6 +155,35 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti ); } + @Test + public void executePredict_TextDocsInput_LimitExceed() throws IOException { + exceptionRule.expect(OpenSearchStatusException.class); + exceptionRule.expectMessage("{\"message\": \"Too many requests\"}"); + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .build(); + when(httpClient.execute(any())).thenReturn(response); + HttpEntity entity = new StringEntity("{\"message\": \"Too many requests\"}"); + when(response.getEntity()).thenReturn(entity); + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 429, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + when(executor.getHttpClient()).thenReturn(httpClient); + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); + executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + } + @Test public void executePredict_TextDocsInput() throws IOException { String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; @@ -202,6 +239,8 @@ public void executePredict_TextDocsInput() throws IOException { + " \"total_tokens\": 5\n" + " }\n" + "}"; + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); HttpEntity entity = new StringEntity(modelResponse); when(response.getEntity()).thenReturn(entity); when(executor.getHttpClient()).thenReturn(httpClient); @@ -209,6 +248,7 @@ public void executePredict_TextDocsInput() throws IOException { ModelTensorOutput modelTensorOutput = executor .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); Assert .assertArrayEquals( diff --git a/plugin/build.gradle b/plugin/build.gradle index b389360e0b..b8a4d47d22 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -272,6 +272,7 @@ List jacocoExclusions = [ 'org.opensearch.ml.action.deploy.TransportDeployModelAction', 'org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction', 'org.opensearch.ml.action.prediction.TransportPredictionTaskAction', + 'org.opensearch.ml.action.prediction.TransportPredictionTaskAction.1', 'org.opensearch.ml.action.tasks.GetTaskTransportAction', 'org.opensearch.ml.action.tasks.SearchTaskTransportAction', 'org.opensearch.ml.model.MLModelManager', diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java index 276f29259e..856fccb848 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java @@ -8,6 +8,10 @@ import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.delete.DeleteRequest; @@ -77,11 +81,16 @@ protected void doExecute(Task task, ActionRequest request, ActionListener modelIds = new ArrayList<>(); + for (SearchHit hit : searchHits) { + modelIds.add(hit.getId()); + } actionListener .onFailure( new MLValidationException( searchHits.length - + " models are still using this connector, please delete or update the models first!" + + " models are still using this connector, please delete or update the models first: " + + Arrays.toString(modelIds.toArray(new String[0])) ) ); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java index b03e6028fa..7b953c341a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java @@ -10,6 +10,7 @@ import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; @@ -19,11 +20,11 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.connector.Connector; -import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse; @@ -79,7 +80,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { log.error("Failed to get connector index", e); - actionListener.onFailure(new IllegalArgumentException("Fail to find connector")); + actionListener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); } else { log.error("Failed to get ML connector " + connectorId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java index d8a1d88a01..066ca5f8a7 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java @@ -7,7 +7,13 @@ import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.search.SearchRequest; @@ -16,16 +22,22 @@ import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; +import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.search.SearchHit; @@ -38,12 +50,14 @@ import lombok.extern.log4j.Log4j2; @Log4j2 -@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@FieldDefaults(level = AccessLevel.PRIVATE) public class UpdateConnectorTransportAction extends HandledTransportAction { Client client; ConnectorAccessControlHelper connectorAccessControlHelper; MLModelManager mlModelManager; + MLEngine mlEngine; + volatile List trustedConnectorEndpointsRegex; @Inject public UpdateConnectorTransportAction( @@ -51,25 +65,35 @@ public UpdateConnectorTransportAction( ActionFilters actionFilters, Client client, ConnectorAccessControlHelper connectorAccessControlHelper, - MLModelManager mlModelManager + MLModelManager mlModelManager, + Settings settings, + ClusterService clusterService, + MLEngine mlEngine ) { super(MLUpdateConnectorAction.NAME, transportService, actionFilters, MLUpdateConnectorRequest::new); this.client = client; this.connectorAccessControlHelper = connectorAccessControlHelper; this.mlModelManager = mlModelManager; + this.mlEngine = mlEngine; + trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, it -> trustedConnectorEndpointsRegex = it); } @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { MLUpdateConnectorRequest mlUpdateConnectorAction = MLUpdateConnectorRequest.fromActionRequest(request); String connectorId = mlUpdateConnectorAction.getConnectorId(); - UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId); - updateRequest.doc(mlUpdateConnectorAction.getUpdateContent()); - updateRequest.docAsUpsert(true); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(hasPermission -> { + connectorAccessControlHelper.getConnector(client, connectorId, ActionListener.wrap(connector -> { + boolean hasPermission = connectorAccessControlHelper.validateConnectorAccess(client, connector); if (Boolean.TRUE.equals(hasPermission)) { + connector.update(mlUpdateConnectorAction.getUpdateContent(), mlEngine::encrypt); + connector.validateConnectorURL(trustedConnectorEndpointsRegex); + UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId); + updateRequest.doc(connector.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); updateUndeployedConnector(connectorId, updateRequest, listener, context); } else { listener @@ -107,10 +131,17 @@ private void updateUndeployedConnector( client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); } else { log.error(searchHits.length + " models are still using this connector, please undeploy the models first!"); + List modelIds = new ArrayList<>(); + for (SearchHit hit : searchHits) { + modelIds.add(hit.getId()); + } listener .onFailure( - new MLValidationException( - searchHits.length + " models are still using this connector, please undeploy the models first!" + new OpenSearchStatusException( + searchHits.length + + " models are still using this connector, please undeploy the models first: " + + Arrays.toString(modelIds.toArray(new String[0])), + RestStatus.BAD_REQUEST ) ); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java index d7b9d8748b..63e43f8cb2 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java @@ -25,7 +25,6 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; -import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteRequest; @@ -84,26 +83,14 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (mlModels == null || mlModels.getHits().getTotalHits() == null || mlModels.getHits().getTotalHits().value == 0) { - client.delete(deleteRequest, new ActionListener() { - @Override - public void onResponse(DeleteResponse deleteResponse) { - log.debug("Completed Delete Model Group Request, task id:{} deleted", modelGroupId); - wrappedListener.onResponse(deleteResponse); - } - - @Override - public void onFailure(Exception e) { - log.error("Failed to delete ML Model Group " + modelGroupId, e); - wrappedListener.onFailure(e); - } - }); + deleteModelGroup(deleteRequest, modelGroupId, wrappedListener); } else { throw new MLValidationException("Cannot delete the model group when it has associated model versions"); } }, e -> { if (e instanceof IndexNotFoundException) { - wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + deleteModelGroup(deleteRequest, modelGroupId, wrappedListener); } else { log.error("Failed to search models with the specified Model Group Id " + modelGroupId, e); wrappedListener.onFailure(e); @@ -116,4 +103,20 @@ public void onFailure(Exception e) { })); } } + + private void deleteModelGroup(DeleteRequest deleteRequest, String modelGroupId, ActionListener actionListener) { + client.delete(deleteRequest, new ActionListener() { + @Override + public void onResponse(DeleteResponse deleteResponse) { + log.debug("Completed Delete Model Group Request, task id:{} deleted", modelGroupId); + actionListener.onResponse(deleteResponse); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to delete ML Model Group " + modelGroupId, e); + actionListener.onFailure(e); + } + }); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index 494f197857..9d6dc32b86 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -9,7 +9,6 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.utils.MLExceptionUtils.logException; -import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.Map; @@ -90,39 +89,41 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); - client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { - if (modelGroup.isExists()) { - try ( - XContentParser parser = MLNodeUtils - .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, modelGroup.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLModelGroup mlModelGroup = MLModelGroup.parse(parser); + GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); + client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { + if (modelGroup.isExists()) { + try ( + XContentParser parser = MLNodeUtils + .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, modelGroup.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModelGroup mlModelGroup = MLModelGroup.parse(parser); + if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) { validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup); - updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, wrappedListener, user); + } else { + validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput); } - } else { - wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND)); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); - } else { - logException("Failed to get model group", e, log); + updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, wrappedListener, user); + } catch (Exception e) { + log.error("Failed to parse ml model group" + modelGroup.getId(), e); wrappedListener.onFailure(e); } - })); - } catch (Exception e) { - logException("Failed to Update model group", e, log); - listener.onFailure(e); - } - } else { - validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput); - updateModelGroup(modelGroupId, new HashMap<>(), updateModelGroupInput, listener, user); + } else { + wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND)); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + } else { + logException("Failed to get model group", e, log); + wrappedListener.onFailure(e); + } + })); + } catch (Exception e) { + logException("Failed to Update model group", e, log); + listener.onFailure(e); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java new file mode 100644 index 0000000000..e1583abb44 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -0,0 +1,399 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.models; + +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.FunctionName.REMOTE; +import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING; + +import java.io.IOException; +import java.time.Instant; +import java.util.Map; +import java.util.Objects; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.exception.MLValidationException; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.transport.model.MLUpdateModelAction; +import org.opensearch.ml.common.transport.model.MLUpdateModelInput; +import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelGroupManager; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class UpdateModelTransportAction extends HandledTransportAction { + Client client; + ModelAccessControlHelper modelAccessControlHelper; + ConnectorAccessControlHelper connectorAccessControlHelper; + MLModelManager mlModelManager; + MLModelGroupManager mlModelGroupManager; + + @Inject + public UpdateModelTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ConnectorAccessControlHelper connectorAccessControlHelper, + ModelAccessControlHelper modelAccessControlHelper, + MLModelManager mlModelManager, + MLModelGroupManager mlModelGroupManager + ) { + super(MLUpdateModelAction.NAME, transportService, actionFilters, MLUpdateModelRequest::new); + this.client = client; + this.modelAccessControlHelper = modelAccessControlHelper; + this.connectorAccessControlHelper = connectorAccessControlHelper; + this.mlModelManager = mlModelManager; + this.mlModelGroupManager = mlModelGroupManager; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLUpdateModelRequest updateModelRequest = MLUpdateModelRequest.fromActionRequest(request); + MLUpdateModelInput updateModelInput = updateModelRequest.getUpdateModelInput(); + String modelId = updateModelInput.getModelId(); + User user = RestActionUtils.getUserContext(client); + + String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + FunctionName functionName = mlModel.getAlgorithm(); + MLModelState mlModelState = mlModel.getModelState(); + if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { + if (hasPermission) { + if (isModelDeployed(mlModelState)) { + updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, actionListener, context); + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "ML Model " + + modelId + + " is in deploying or deployed state, please undeploy the models first!", + RestStatus.FORBIDDEN + ) + ); + } + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model, model ID " + modelId, + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); + actionListener.onFailure(exception); + })); + } else { + actionListener + .onFailure( + new MLValidationException( + "User doesn't have privilege to perform this operation on this function category: " + + functionName.toString() + ) + ); + } + }, + e -> actionListener + .onFailure( + new OpenSearchStatusException( + "Failed to find model to update with the provided model id: " + modelId, + RestStatus.NOT_FOUND + ) + ) + )); + } catch (Exception e) { + log.error("Failed to update ML model for " + modelId, e); + actionListener.onFailure(e); + } + } + + private void updateRemoteOrTextEmbeddingModel( + String modelId, + MLUpdateModelInput updateModelInput, + MLModel mlModel, + User user, + ActionListener actionListener, + ThreadContext.StoredContext context + ) { + String newModelGroupId = (Strings.hasLength(updateModelInput.getModelGroupId()) + && !Objects.equals(updateModelInput.getModelGroupId(), mlModel.getModelGroupId())) ? updateModelInput.getModelGroupId() : null; + String relinkConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null; + + if (mlModel.getAlgorithm() == TEXT_EMBEDDING) { + if (relinkConnectorId == null) { + updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateModelInput, actionListener, context); + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "Trying to update the connector or connector_id field on a local model", + RestStatus.BAD_REQUEST + ) + ); + } + } else { + // mlModel.getAlgorithm() == REMOTE + if (relinkConnectorId == null) { + updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateModelInput, actionListener, context); + } else { + updateModelWithRelinkStandAloneConnector( + modelId, + newModelGroupId, + relinkConnectorId, + mlModel, + user, + updateModelInput, + actionListener, + context + ); + } + } + } + + private void updateModelWithRelinkStandAloneConnector( + String modelId, + String newModelGroupId, + String relinkConnectorId, + MLModel mlModel, + User user, + MLUpdateModelInput updateModelInput, + ActionListener actionListener, + ThreadContext.StoredContext context + ) { + if (Strings.hasLength(mlModel.getConnectorId())) { + connectorAccessControlHelper + .validateConnectorAccess(client, relinkConnectorId, ActionListener.wrap(hasRelinkConnectorPermission -> { + if (hasRelinkConnectorPermission) { + updateModelWithRegisteringToAnotherModelGroup( + modelId, + newModelGroupId, + user, + updateModelInput, + actionListener, + context + ); + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "You don't have permission to update the connector, connector id: " + relinkConnectorId, + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", relinkConnectorId, exception); + actionListener.onFailure(exception); + })); + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "This remote does not have a connector_id field, maybe it uses an internal connector.", + RestStatus.BAD_REQUEST + ) + ); + } + } + + private void updateModelWithRegisteringToAnotherModelGroup( + String modelId, + String newModelGroupId, + User user, + MLUpdateModelInput updateModelInput, + ActionListener actionListener, + ThreadContext.StoredContext context + ) { + UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId); + if (newModelGroupId != null) { + modelAccessControlHelper.validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasRelinkPermission -> { + if (hasRelinkPermission) { + mlModelGroupManager.getModelGroupResponse(newModelGroupId, ActionListener.wrap(newModelGroupResponse -> { + updateRequestConstructor( + modelId, + newModelGroupId, + updateRequest, + updateModelInput, + newModelGroupResponse, + actionListener, + context + ); + }, + exception -> actionListener + .onFailure( + new OpenSearchStatusException( + "Failed to find the model group with the provided model group id in the update model input, MODEL_GROUP_ID: " + + newModelGroupId, + RestStatus.NOT_FOUND + ) + ) + )); + } else { + actionListener + .onFailure( + new OpenSearchStatusException( + "User Doesn't have privilege to re-link this model to the target model group due to no access to the target model group with model group ID " + + newModelGroupId, + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); + actionListener.onFailure(exception); + })); + } else { + updateRequestConstructor(modelId, updateRequest, updateModelInput, actionListener, context); + } + } + + private void updateRequestConstructor( + String modelId, + UpdateRequest updateRequest, + MLUpdateModelInput updateModelInput, + ActionListener actionListener, + ThreadContext.StoredContext context + ) { + try { + updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); + updateRequest.docAsUpsert(true); + client.update(updateRequest, getUpdateResponseListener(modelId, actionListener, context)); + } catch (IOException e) { + log.error("Failed to build update request."); + actionListener.onFailure(e); + } + } + + private void updateRequestConstructor( + String modelId, + String newModelGroupId, + UpdateRequest updateRequest, + MLUpdateModelInput updateModelInput, + GetResponse newModelGroupResponse, + ActionListener actionListener, + ThreadContext.StoredContext context + ) { + Map newModelGroupSourceMap = newModelGroupResponse.getSourceAsMap(); + String updatedVersion = incrementLatestVersion(newModelGroupSourceMap); + updateModelInput.setVersion(updatedVersion); + UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest( + newModelGroupSourceMap, + newModelGroupId, + newModelGroupResponse.getSeqNo(), + newModelGroupResponse.getPrimaryTerm(), + Integer.parseInt(updatedVersion) + ); + try { + updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); + updateRequest.docAsUpsert(true); + client.update(updateModelGroupRequest, ActionListener.wrap(r -> { + client.update(updateRequest, getUpdateResponseListener(modelId, actionListener, context)); + }, e -> { + log + .error( + "Failed to register ML model with model ID {} to the new model group with model group ID {}", + modelId, + newModelGroupId + ); + actionListener.onFailure(e); + })); + } catch (IOException e) { + log.error("Failed to build update request."); + actionListener.onFailure(e); + } + } + + private ActionListener getUpdateResponseListener( + String modelId, + ActionListener actionListener, + ThreadContext.StoredContext context + ) { + return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { + log.info("Model id:{} failed update", modelId); + actionListener.onResponse(updateResponse); + return; + } + log.info("Successfully update ML model with model ID {}", modelId); + actionListener.onResponse(updateResponse); + }, exception -> { + log.error("Failed to update ML model: " + modelId, exception); + actionListener.onFailure(exception); + }), context::restore); + } + + private String incrementLatestVersion(Map modelGroupSourceMap) { + return Integer.toString((int) modelGroupSourceMap.get(MLModelGroup.LATEST_VERSION_FIELD) + 1); + } + + private UpdateRequest createUpdateModelGroupRequest( + Map modelGroupSourceMap, + String modelGroupId, + long seqNo, + long primaryTerm, + int updatedVersion + ) { + modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion); + modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); + UpdateRequest updateModelGroupRequest = new UpdateRequest(); + + updateModelGroupRequest + .index(ML_MODEL_GROUP_INDEX) + .id(modelGroupId) + .setIfSeqNo(seqNo) + .setIfPrimaryTerm(primaryTerm) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .doc(modelGroupSourceMap); + + return updateModelGroupRequest; + } + + private Boolean isModelDeployed(MLModelState mlModelState) { + return !mlModelState.equals(MLModelState.LOADED) + && !mlModelState.equals(MLModelState.LOADING) + && !mlModelState.equals(MLModelState.PARTIALLY_LOADED) + && !mlModelState.equals(MLModelState.DEPLOYED) + && !mlModelState.equals(MLModelState.DEPLOYING) + && !mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 26a5a66de3..63be5e2423 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -16,6 +16,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; @@ -87,42 +88,72 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); - mlModelManager.getModel(modelId, ActionListener.wrap(mlModel -> { - FunctionName functionName = mlModel.getAlgorithm(); - mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName); - modelAccessControlHelper - .validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { - wrappedListener - .onFailure( - new MLValidationException("User Doesn't have privilege to perform this operation on this model") - ); - } else { - String requestId = mlPredictionTaskRequest.getRequestID(); - log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId()); - long startTime = System.nanoTime(); - mlPredictTaskRunner - .run( - functionName, - mlPredictionTaskRequest, - transportService, - ActionListener.runAfter(wrappedListener, () -> { - long endTime = System.nanoTime(); - double durationInMs = (endTime - startTime) / 1e6; - modelCacheHelper.addPredictRequestDuration(modelId, durationInMs); - log.debug("completed predict request " + requestId + " for model " + modelId); - }) - ); - } - }, e -> { - log.error("Failed to Validate Access for ModelId " + modelId, e); - wrappedListener.onFailure(e); - })); - }, e -> { - log.error("Failed to find model " + modelId, e); - wrappedListener.onFailure(e); - })); + MLModel cachedMlModel = modelCacheHelper.getModelInfo(modelId); + ActionListener modelActionListener = new ActionListener<>() { + @Override + public void onResponse(MLModel mlModel) { + context.restore(); + modelCacheHelper.setModelInfo(modelId, mlModel); + FunctionName functionName = mlModel.getAlgorithm(); + mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName); + modelAccessControlHelper + .validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new MLValidationException("User Doesn't have privilege to perform this operation on this model") + ); + } else { + executePredict(mlPredictionTaskRequest, wrappedListener, modelId); + } + }, e -> { + log.error("Failed to Validate Access for ModelId " + modelId, e); + wrappedListener.onFailure(e); + })); + } + @Override + public void onFailure(Exception e) { + log.error("Failed to find model " + modelId, e); + wrappedListener.onFailure(e); + } + }; + + if (cachedMlModel != null) { + modelActionListener.onResponse(cachedMlModel); + } else { + // For multi-node cluster, the function name is null in cache, so should always get model first. + mlModelManager.getModel(modelId, modelActionListener); + } } } + + private void executePredict( + MLPredictionTaskRequest mlPredictionTaskRequest, + ActionListener wrappedListener, + String modelId + ) { + String requestId = mlPredictionTaskRequest.getRequestID(); + log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId()); + long startTime = System.nanoTime(); + // For remote text embedding model, neural search will set mlPredictionTaskRequest.getMlInput().getAlgorithm() as + // TEXT_EMBEDDING. In ml-commons we should always use the real function name of model: REMOTE. So we try to get + // from model cache first. + FunctionName functionName = modelCacheHelper + .getOptionalFunctionName(modelId) + .orElse(mlPredictionTaskRequest.getMlInput().getAlgorithm()); + mlPredictTaskRunner + .run( + // This is by design to NOT use mlPredictionTaskRequest.getMlInput().getAlgorithm() here + functionName, + mlPredictionTaskRequest, + transportService, + ActionListener.runAfter(wrappedListener, () -> { + long endTime = System.nanoTime(); + double durationInMs = (endTime - startTime) / 1e6; + modelCacheHelper.addPredictRequestDuration(modelId, durationInMs); + log.debug("completed predict request " + requestId + " for model " + modelId); + }) + ); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 23006f5464..98927a5e5a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -17,6 +17,7 @@ import java.util.List; import java.util.regex.Pattern; +import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.util.Strings; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.ActionRequest; @@ -136,17 +137,76 @@ public TransportRegisterModelAction( @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { - User user = RestActionUtils.getUserContext(client); MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request); MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput(); + if (StringUtils.isEmpty(registerModelInput.getModelGroupId())) { + mlModelGroupManager.validateUniqueModelGroupName(registerModelInput.getModelName(), ActionListener.wrap(modelGroups -> { + if (modelGroups != null + && modelGroups.getHits().getTotalHits() != null + && modelGroups.getHits().getTotalHits().value != 0) { + String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId(); + registerModelInput.setModelGroupId(modelGroupIdOfTheNameProvided); + checkUserAccess(registerModelInput, listener, true); + } else { + doRegister(registerModelInput, listener); + } + }, e -> { + log.error("Failed to search model group index", e); + listener.onFailure(e); + })); + } else { + checkUserAccess(registerModelInput, listener, false); + } + } + + private void checkUserAccess( + MLRegisterModelInput registerModelInput, + ActionListener listener, + Boolean isModelNameAlreadyExisting + ) { + User user = RestActionUtils.getUserContext(client); modelAccessControlHelper .validateModelGroupAccess(user, registerModelInput.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { - log.error("You don't have permissions to perform this operation on this model."); - listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); - } else { + if (access) { doRegister(registerModelInput, listener); + return; + } + // if the user does not have access, we need to check three more conditions before throwing exception. + // if we are checking the access based on the name provided in the input, we let user know the name is already used by a + // model group they do not have access to. + if (isModelNameAlreadyExisting) { + // This case handles when user is using the same pre-trained model already registered by another user on the cluster. + // The only way here is for the user to first create model group and use its ID in the request + if (registerModelInput.getUrl() == null + && registerModelInput.getFunctionName() != FunctionName.REMOTE + && registerModelInput.getConnectorId() == null) { + listener + .onFailure( + new IllegalArgumentException( + "Without a model group ID, the system will use the model name {" + + registerModelInput.getModelName() + + "} to create a new model group. However, this name is taken by another group with id {" + + registerModelInput.getModelGroupId() + + "} you can't access. To register this pre-trained model, create a new model group and use its ID in your request." + ) + ); + } else { + listener + .onFailure( + new IllegalArgumentException( + "The name {" + + registerModelInput.getModelName() + + "} you provided is unavailable because it is used by another model group with id {" + + registerModelInput.getModelGroupId() + + "} to which you do not have access. Please provide a different name." + ) + ); + } + return; } + // if user does not have access to the model group ID provided in the input, we let user know they do not have access to the + // specified model group + listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); }, listener::onFailure)); } @@ -196,12 +256,14 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(registerModelInput); mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> { registerModelInput.setModelGroupId(modelGroupId); + registerModelInput.setDoesVersionCreateModelGroup(true); registerModel(registerModelInput, listener); }, e -> { logException("Failed to create Model Group", e, log); listener.onFailure(e); })); } else { + registerModelInput.setDoesVersionCreateModelGroup(false); registerModel(registerModelInput, listener); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java index 01d8abb96c..a730de712f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java @@ -63,25 +63,52 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + if (modelGroups != null + && modelGroups.getHits().getTotalHits() != null + && modelGroups.getHits().getTotalHits().value != 0) { + String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId(); + mlUploadInput.setModelGroupId(modelGroupIdOfTheNameProvided); + checkUserAccess(mlUploadInput, listener, true); + } else { + createModelGroup(mlUploadInput, listener); + } + }, e -> { + log.error("Failed to search model group index", e); + listener.onFailure(e); + })); + } else { + checkUserAccess(mlUploadInput, listener, false); + } + } + private void checkUserAccess( + MLRegisterModelMetaInput mlUploadInput, + ActionListener listener, + Boolean isModelNameAlreadyExisting + ) { + + User user = RestActionUtils.getUserContext(client); modelAccessControlHelper.validateModelGroupAccess(user, mlUploadInput.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { + if (access) { + createModelGroup(mlUploadInput, listener); + return; + } + if (isModelNameAlreadyExisting) { + listener + .onFailure( + new IllegalArgumentException( + "The name {" + + mlUploadInput.getName() + + "} you provided is unavailable because it is used by another model group with id {" + + mlUploadInput.getModelGroupId() + + "} to which you do not have access. Please provide a different name." + ) + ); + } else { log.error("You don't have permissions to perform this operation on this model."); listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); - } else { - if (StringUtils.isEmpty(mlUploadInput.getModelGroupId())) { - MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(mlUploadInput); - mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> { - mlUploadInput.setModelGroupId(modelGroupId); - registerModelMeta(mlUploadInput, listener); - }, e -> { - logException("Failed to create Model Group", e, log); - listener.onFailure(e); - })); - } else { - registerModelMeta(mlUploadInput, listener); - } } }, e -> { logException("Failed to validate model access", e, log); @@ -89,6 +116,23 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + if (StringUtils.isEmpty(mlUploadInput.getModelGroupId())) { + MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(mlUploadInput); + mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> { + mlUploadInput.setModelGroupId(modelGroupId); + mlUploadInput.setDoesVersionCreateModelGroup(true); + registerModelMeta(mlUploadInput, listener); + }, e -> { + logException("Failed to create Model Group", e, log); + listener.onFailure(e); + })); + } else { + mlUploadInput.setDoesVersionCreateModelGroup(false); + registerModelMeta(mlUploadInput, listener); + } + } + private MLRegisterModelGroupInput createRegisterModelGroupRequest(MLRegisterModelMetaInput mlUploadInput) { return MLRegisterModelGroupInput .builder() diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java index 09fd80818f..b1096e7e38 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java @@ -11,6 +11,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -19,6 +20,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.BoolQueryBuilder; @@ -30,7 +32,6 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.connector.AbstractConnector; import org.opensearch.ml.common.connector.Connector; -import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.search.builder.SearchSourceBuilder; @@ -64,35 +65,48 @@ public void validateConnectorAccess(Client client, String connectorId, ActionLis listener.onResponse(true); return; } - GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); - client.get(getRequest, ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try ( - XContentParser parser = MLNodeUtils - .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Connector connector = Connector.createConnector(parser); - boolean hasPermission = hasPermission(user, connector); - wrappedListener.onResponse(hasPermission); - } catch (Exception e) { - log.error("Failed to parse connector:" + connectorId); - wrappedListener.onFailure(e); - } - } else { - wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find connector:" + connectorId)); - } - }, e -> { - log.error("Fail to get connector", e); - wrappedListener.onFailure(new IllegalStateException("Fail to get connector:" + connectorId)); - })); + getConnector(client, connectorId, ActionListener.wrap(connector -> { + boolean hasPermission = hasPermission(user, connector); + wrappedListener.onResponse(hasPermission); + }, e -> { wrappedListener.onFailure(e); })); } catch (Exception e) { log.error("Failed to validate Access for connector:" + connectorId, e); listener.onFailure(e); } + } + + public boolean validateConnectorAccess(Client client, Connector connector) { + User user = RestActionUtils.getUserContext(client); + if (isAdmin(user) || accessControlNotEnabled(user)) { + return true; + } + return hasPermission(user, connector); + } + public void getConnector(Client client, String connectorId, ActionListener listener) { + GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try ( + XContentParser parser = MLNodeUtils + .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector connector = Connector.createConnector(parser); + listener.onResponse(connector); + } catch (Exception e) { + log.error("Failed to parse connector:" + connectorId); + listener.onFailure(e); + } + } else { + listener.onFailure(new OpenSearchStatusException("Failed to find connector:" + connectorId, RestStatus.NOT_FOUND)); + } + }, e -> { + log.error("Failed to get connector", e); + listener.onFailure(new OpenSearchStatusException("Failed to get connector:" + connectorId, RestStatus.NOT_FOUND)); + })); } public boolean skipConnectorAccessControl(User user) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java index 19ca890df5..5fd7d71ce0 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java @@ -14,6 +14,7 @@ import java.util.stream.DoubleStream; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.engine.MLExecutable; import org.opensearch.ml.engine.Predictable; @@ -34,6 +35,7 @@ public class MLModelCache { private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) MLExecutable executor; private final Set targetWorkerNodes; private final Set workerNodes; + private MLModel modelInfo; private final Queue modelInferenceDurationQueue; private final Queue predictRequestDurationQueue; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Long memSizeEstimationCPU; @@ -77,12 +79,16 @@ public String[] getTargetWorkerNodes() { * @param isFromUndeploy */ public void removeWorkerNode(String nodeId, boolean isFromUndeploy) { - if ((deployToAllNodes != null && deployToAllNodes) || isFromUndeploy) { + if (this.isDeployToAllNodes() || isFromUndeploy) { targetWorkerNodes.remove(nodeId); } if (isFromUndeploy) deployToAllNodes = false; workerNodes.remove(nodeId); + // when the model is not deployed to any node, we should remove the modelInfo from cache + if (targetWorkerNodes.isEmpty() || workerNodes.isEmpty()) { + modelInfo = null; + } } public void removeWorkerNodes(Set removedNodes, boolean isFromUndeploy) { @@ -92,6 +98,9 @@ public void removeWorkerNodes(Set removedNodes, boolean isFromUndeploy) if (isFromUndeploy) deployToAllNodes = false; workerNodes.removeAll(removedNodes); + if (targetWorkerNodes.isEmpty() || workerNodes.isEmpty()) { + modelInfo = null; + } } /** @@ -112,6 +121,14 @@ public String[] getWorkerNodes() { return workerNodes.toArray(new String[0]); } + public void setModelInfo(MLModel modelInfo) { + this.modelInfo = modelInfo; + } + + public MLModel getCachedModelInfo() { + return modelInfo; + } + public void syncWorkerNode(Set workerNodes) { this.workerNodes.clear(); this.workerNodes.addAll(workerNodes); @@ -129,6 +146,7 @@ public void clear() { modelState = null; functionName = null; workerNodes.clear(); + modelInfo = null; modelInferenceDurationQueue.clear(); predictRequestDurationQueue.clear(); if (predictor != null) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 74dbc26d61..553ffeb664 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -18,6 +18,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; @@ -429,6 +430,21 @@ public boolean getDeployToAllNodes(String modelId) { return mlModelCache.isDeployToAllNodes(); } + public void setModelInfo(String modelId, MLModel mlModel) { + MLModelCache mlModelCache = modelCaches.get(modelId); + if (mlModelCache != null) { + mlModelCache.setModelInfo(mlModel); + } + } + + public MLModel getModelInfo(String modelId) { + MLModelCache mlModelCache = modelCaches.get(modelId); + if (mlModelCache == null) { + return null; + } + return mlModelCache.getCachedModelInfo(); + } + private MLModelCache getExistingModelCache(String modelId) { MLModelCache modelCache = modelCaches.get(modelId); if (modelCache == null) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index e833910d58..83523729e4 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -11,6 +11,8 @@ import java.util.HashSet; import java.util.Iterator; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -30,6 +32,7 @@ import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.indices.MLIndicesHandler; @@ -146,15 +149,13 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us AccessMode modelAccessMode = input.getModelAccessMode(); Boolean isAddAllBackendRoles = input.getIsAddAllBackendRoles(); if (modelAccessMode == null) { - if (modelAccessMode == null) { - if (!CollectionUtils.isEmpty(input.getBackendRoles()) && Boolean.TRUE.equals(isAddAllBackendRoles)) { - throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time."); - } else if (Boolean.TRUE.equals(isAddAllBackendRoles) || !CollectionUtils.isEmpty(input.getBackendRoles())) { - input.setModelAccessMode(AccessMode.RESTRICTED); - modelAccessMode = AccessMode.RESTRICTED; - } else { - input.setModelAccessMode(AccessMode.PRIVATE); - } + if (!CollectionUtils.isEmpty(input.getBackendRoles()) && Boolean.TRUE.equals(isAddAllBackendRoles)) { + throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time."); + } else if (Boolean.TRUE.equals(isAddAllBackendRoles) || !CollectionUtils.isEmpty(input.getBackendRoles())) { + input.setModelAccessMode(AccessMode.RESTRICTED); + modelAccessMode = AccessMode.RESTRICTED; + } else { + input.setModelAccessMode(AccessMode.PRIVATE); } } if ((AccessMode.PUBLIC == modelAccessMode || AccessMode.PRIVATE == modelAccessMode) @@ -184,20 +185,47 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us } public void validateUniqueModelGroupName(String name, ActionListener listener) throws IllegalArgumentException { - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name)); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); - SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name)); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); + SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder); + + client + .search( + searchRequest, + ActionListener.runBefore(ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> { + if (e instanceof IndexNotFoundException) { + listener.onResponse(null); + } else { + log.error("Failed to search model group index", e); + listener.onFailure(e); + } + }), () -> context.restore()) + ); + } catch (Exception e) { + log.error("Failed to search model group index", e); + listener.onFailure(e); + } + } - client.search(searchRequest, ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> { - if (e instanceof IndexNotFoundException) { - listener.onResponse(null); + /** + * Get model group from model group index. + * + * @param modelGroupId model group id + * @param listener action listener + */ + public void getModelGroupResponse(String modelGroupId, ActionListener listener) { + GetRequest getRequest = new GetRequest(); + getRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupId); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + listener.onResponse(r); } else { - log.error("Failed to search model group index", e); - listener.onFailure(e); + listener.onFailure(new MLResourceNotFoundException("Failed to find model group with ID: " + modelGroupId)); } - })); + }, e -> { listener.onFailure(e); })); } private void validateSecurityDisabledOrModelAccessControlDisabled(MLRegisterModelGroupInput input) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 1f766ae13a..dd1deac4ab 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -296,6 +296,11 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput log.debug("Index model meta doc successfully {}", modelName); wrappedListener.onResponse(response.getId()); }, e -> { + deleteOrUpdateModelGroup( + mlRegisterModelMetaInput.getModelGroupId(), + mlRegisterModelMetaInput.getDoesVersionCreateModelGroup(), + version + ); log.error("Failed to index model meta doc", e); wrappedListener.onFailure(e); })); @@ -328,10 +333,6 @@ public void registerMLRemoteModel( String modelGroupId = mlRegisterModelInput.getModelGroupId(); GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); - if (Strings.isBlank(modelGroupId)) { - indexRemoteModel(mlRegisterModelInput, mlTask, "1", listener); - } - client.get(getModelGroupRequest, ActionListener.wrap(getModelGroupResponse -> { if (getModelGroupResponse.isExists()) { Map modelGroupSourceMap = getModelGroupResponse.getSourceAsMap(); @@ -399,9 +400,6 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa String modelGroupId = registerModelInput.getModelGroupId(); GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); - if (Strings.isBlank(modelGroupId)) { - uploadModel(registerModelInput, mlTask, "1"); - } try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.get(getModelGroupRequest, ActionListener.runBefore(ActionListener.wrap(modelGroup -> { if (modelGroup.isExists()) { @@ -723,7 +721,8 @@ private void registerModel( modelId, modelSizeInBytes, chunkFiles, - hashValue + hashValue, + version ); } else { deleteFileQuietly(file); @@ -735,7 +734,7 @@ private void registerModel( handleException(functionName, taskId, e); deleteFileQuietly(file); // remove model doc as failed to upload model - deleteModel(modelId); + deleteModel(modelId, registerModelInput, version); semaphore.release(); deleteFileQuietly(mlEngine.getRegisterModelPath(modelId)); })); @@ -743,7 +742,7 @@ private void registerModel( }, e -> { log.error("Failed to index chunk file", e); deleteFileQuietly(mlEngine.getRegisterModelPath(modelId)); - deleteModel(modelId); + deleteModel(modelId, registerModelInput, version); handleException(functionName, taskId, e); }) ); @@ -792,7 +791,8 @@ private void updateModelRegisterStateAsDone( String modelId, Long modelSizeInBytes, List chunkFiles, - String hashValue + String hashValue, + String version ) { FunctionName functionName = registerModelInput.getFunctionName(); deleteFileQuietly(mlEngine.getRegisterModelPath(modelId)); @@ -818,7 +818,7 @@ private void updateModelRegisterStateAsDone( }, e -> { log.error("Failed to update model", e); handleException(functionName, taskId, e); - deleteModel(modelId); + deleteModel(modelId, registerModelInput, version); })); } @@ -831,7 +831,7 @@ private void deployModelAfterRegistering(MLRegisterModelInput registerModelInput client.execute(MLDeployModelAction.INSTANCE, request, listener); } - private void deleteModel(String modelId) { + private void deleteModel(String modelId, MLRegisterModelInput registerModelInput, String modelVersion) { DeleteRequest deleteRequest = new DeleteRequest(); deleteRequest.index(ML_MODEL_INDEX).id(modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.delete(deleteRequest); @@ -840,6 +840,38 @@ private void deleteModel(String modelId) { .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) .setAbortOnVersionConflict(false); client.execute(DeleteByQueryAction.INSTANCE, deleteChunksRequest); + deleteOrUpdateModelGroup(registerModelInput.getModelGroupId(), registerModelInput.getDoesVersionCreateModelGroup(), modelVersion); + } + + private void deleteOrUpdateModelGroup(String modelGroupID, Boolean doesVersionCreateModelGroup, String modelVersion) { + // This checks if model group is created when registering the version. If yes, model group is deleted since the version registration + // had failed. Else model group latest version is decremented by 1 + if (doesVersionCreateModelGroup) { + DeleteRequest deleteModelGroupRequest = new DeleteRequest(); + deleteModelGroupRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupID).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.delete(deleteModelGroupRequest); + } else { + updateLatestVersionInModelGroup( + modelGroupID, + Integer.parseInt(modelVersion) - 1, + ActionListener + .wrap(r -> log.debug("model group updated, response {}", r), e -> log.error("Failed to update model group", e)) + ); + } + } + + private void updateLatestVersionInModelGroup(String modelGroupID, Integer latestVersion, ActionListener listener) { + Map updatedFields = new HashMap<>(); + updatedFields.put(MLModelGroup.LATEST_VERSION_FIELD, latestVersion); + updatedFields.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); + UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_GROUP_INDEX, modelGroupID); + updateRequest.doc(updatedFields); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); + } catch (Exception e) { + listener.onFailure(e); + } } private void handleException(FunctionName functionName, String taskId, Exception e) { @@ -910,8 +942,8 @@ public void deployModel( CLUSTER_SERVICE, clusterService ); - // deploy remote model or model trained by built-in algorithm like kmeans - if (mlModel.getConnector() != null) { + // deploy remote model with internal connector or model trained by built-in algorithm like kmeans + if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) { setupPredictable(modelId, mlModel, params); wrappedListener.onResponse("successful"); return; @@ -920,6 +952,7 @@ public void deployModel( GetRequest getConnectorRequest = new GetRequest(); FetchSourceContext fetchContext = new FetchSourceContext(true, null, null); getConnectorRequest.index(ML_CONNECTOR_INDEX).id(mlModel.getConnectorId()).fetchSourceContext(fetchContext); + // get connector and deploy remote model with standalone connector client.get(getConnectorRequest, ActionListener.wrap(getResponse -> { if (getResponse != null && getResponse.isExists()) { try ( diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 70e87eb860..ad3b4dfc44 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -53,6 +53,7 @@ import org.opensearch.ml.action.models.DeleteModelTransportAction; import org.opensearch.ml.action.models.GetModelTransportAction; import org.opensearch.ml.action.models.SearchModelTransportAction; +import org.opensearch.ml.action.models.UpdateModelTransportAction; import org.opensearch.ml.action.prediction.TransportPredictionTaskAction; import org.opensearch.ml.action.profile.MLProfileAction; import org.opensearch.ml.action.profile.MLProfileTransportAction; @@ -100,6 +101,7 @@ import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelGetAction; import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.ml.common.transport.model.MLUpdateModelAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction; @@ -166,6 +168,7 @@ import org.opensearch.ml.rest.RestMLTrainingAction; import org.opensearch.ml.rest.RestMLUndeployModelAction; import org.opensearch.ml.rest.RestMLUpdateConnectorAction; +import org.opensearch.ml.rest.RestMLUpdateModelAction; import org.opensearch.ml.rest.RestMLUpdateModelGroupAction; import org.opensearch.ml.rest.RestMLUploadModelChunkAction; import org.opensearch.ml.rest.RestMemoryCreateConversationAction; @@ -282,6 +285,7 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(MLUndeployModelsAction.INSTANCE, TransportUndeployModelsAction.class), new ActionHandler<>(MLRegisterModelMetaAction.INSTANCE, TransportRegisterModelMetaAction.class), new ActionHandler<>(MLUploadModelChunkAction.INSTANCE, TransportUploadModelChunkAction.class), + new ActionHandler<>(MLUpdateModelAction.INSTANCE, UpdateModelTransportAction.class), new ActionHandler<>(MLForwardAction.INSTANCE, TransportForwardAction.class), new ActionHandler<>(MLSyncUpAction.INSTANCE, TransportSyncUpOnNodeAction.class), new ActionHandler<>(MLRegisterModelGroupAction.INSTANCE, TransportRegisterModelGroupAction.class), @@ -536,6 +540,7 @@ public List getRestHandlers( RestMLRegisterModelGroupAction restMLCreateModelGroupAction = new RestMLRegisterModelGroupAction(); RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction(); RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction(); + RestMLUpdateModelAction restMLUpdateModelAction = new RestMLUpdateModelAction(); RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction(); RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting); RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction(); @@ -557,6 +562,7 @@ public List getRestHandlers( restMLGetModelAction, restMLDeleteModelAction, restMLSearchModelAction, + restMLUpdateModelAction, restMLGetTaskAction, restMLDeleteTaskAction, restMLSearchTaskAction, diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java index a74ed27ecc..b6e3822318 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java @@ -15,6 +15,7 @@ import java.util.List; import java.util.Locale; +import org.opensearch.OpenSearchParseException; import org.opensearch.client.node.NodeClient; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; @@ -43,12 +44,7 @@ public String getName() { @Override public List routes() { return ImmutableList - .of( - new Route( - RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/connectors/_update/{%s}", ML_BASE_URI, PARAMETER_CONNECTOR_ID) - ) - ); + .of(new Route(RestRequest.Method.PUT, String.format(Locale.ROOT, "%s/connectors/{%s}", ML_BASE_URI, PARAMETER_CONNECTOR_ID))); } @Override @@ -65,14 +61,17 @@ private MLUpdateConnectorRequest getRequest(RestRequest request) throws IOExcept } if (!request.hasContent()) { - throw new IOException("Failed to update connector: Request body is empty"); + throw new OpenSearchParseException("Failed to update connector: Request body is empty"); } String connectorId = getParameterId(request, PARAMETER_CONNECTOR_ID); - XContentParser parser = request.contentParser(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - - return MLUpdateConnectorRequest.parse(parser, connectorId); + try { + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + return MLUpdateConnectorRequest.parse(parser, connectorId); + } catch (IllegalStateException illegalStateException) { + throw new OpenSearchParseException(illegalStateException.getMessage()); + } } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java new file mode 100644 index 0000000000..79959cbf26 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.OpenSearchParseException; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.transport.model.MLUpdateModelAction; +import org.opensearch.ml.common.transport.model.MLUpdateModelInput; +import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +public class RestMLUpdateModelAction extends BaseRestHandler { + + private static final String ML_UPDATE_MODEL_ACTION = "ml_update_model_action"; + + @Override + public String getName() { + return ML_UPDATE_MODEL_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of(new Route(RestRequest.Method.PUT, String.format(Locale.ROOT, "%s/models/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLUpdateModelRequest updateModelRequest = getRequest(request); + return channel -> client.execute(MLUpdateModelAction.INSTANCE, updateModelRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLUpdateModelRequest from a RestRequest + * + * @param request RestRequest + * @return MLUpdateModelRequest + */ + private MLUpdateModelRequest getRequest(RestRequest request) throws IOException { + if (!request.hasContent()) { + throw new OpenSearchParseException("Model update request has empty body"); + } + + String modelId = getParameterId(request, PARAMETER_MODEL_ID); + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + try { + MLUpdateModelInput input = MLUpdateModelInput.parse(parser); + // Model ID can only be set here. Model version can only be set automatically. + input.setModelId(modelId); + input.setVersion(null); + return new MLUpdateModelRequest(input); + } catch (IllegalStateException e) { + throw new OpenSearchParseException(e.getMessage()); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 4d62bb504f..bf200a3b02 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -127,7 +127,8 @@ private MLCommonsSettings() {} .of( "^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", - "^https://api\\.cohere\\.ai/.*$" + "^https://api\\.cohere\\.ai/.*$", + "^https://bedrock-runtime\\..*[a-z0-9-]\\.amazonaws\\.com/.*$" ), Function.identity(), Setting.Property.NodeScope, diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java index 3977c5e932..0a1a00ac4d 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java @@ -18,6 +18,7 @@ public class MLFeatureEnabledSetting { public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) { isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings); + clusterService .getClusterSettings() .addSettingsUpdateConsumer(ML_COMMONS_REMOTE_INFERENCE_ENABLED, it -> isRemoteInferenceEnabled = it); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index e6b6be2c62..348618773a 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -119,8 +119,6 @@ public void dispatchTask( ActionListener listener ) { String modelId = request.getModelId(); - MLInput input = request.getMlInput(); - FunctionName algorithm = input.getAlgorithm(); try { ActionListener actionListener = ActionListener.wrap(node -> { if (clusterService.localNode().getId().equals(node.getId())) { @@ -133,9 +131,9 @@ public void dispatchTask( transportService.sendRequest(node, getTransportActionName(), request, getResponseHandler(listener)); } }, e -> { listener.onFailure(e); }); - String[] workerNodes = mlModelManager.getWorkerNodes(modelId, algorithm, true); + String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true); if (workerNodes == null || workerNodes.length == 0) { - if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) { + if (functionName == FunctionName.TEXT_EMBEDDING || functionName == FunctionName.REMOTE) { listener .onFailure( new IllegalArgumentException( @@ -144,7 +142,7 @@ public void dispatchTask( ); return; } else { - workerNodes = nodeHelper.getEligibleNodeIds(algorithm); + workerNodes = nodeHelper.getEligibleNodeIds(functionName); } } mlTaskDispatcher.dispatchPredictTask(workerNodes, actionListener); @@ -215,9 +213,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe FunctionName algorithm = mlInput.getAlgorithm(); // run predict if (modelId != null) { - try { - Predictable predictor = mlModelManager.getPredictor(modelId); - if (predictor != null) { + Predictable predictor = mlModelManager.getPredictor(modelId); + if (predictor != null) { + try { if (!predictor.isModelReady()) { throw new IllegalArgumentException("Model not ready: " + modelId); } @@ -231,11 +229,12 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe MLTaskResponse response = MLTaskResponse.builder().output(output).build(); internalListener.onResponse(response); return; - } else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) { - throw new IllegalArgumentException("Model not ready to be used: " + modelId); + } catch (Exception e) { + handlePredictFailure(mlTask, internalListener, e, false, modelId); + return; } - } catch (Exception e) { - handlePredictFailure(mlTask, internalListener, e, false, modelId); + } else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) { + throw new IllegalArgumentException("Model not ready to be used: " + modelId); } // search model by model id. @@ -254,6 +253,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe GetResponse getResponse = r; String algorithmName = getResponse.getSource().get(ALGORITHM_FIELD).toString(); MLModel mlModel = MLModel.parse(xContentParser, algorithmName); + mlModel.setModelId(modelId); User resourceUser = mlModel.getUser(); User requestUser = getUserContext(client); if (!checkUserPermissions(requestUser, resourceUser, modelId)) { @@ -265,7 +265,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe return; } // run predict - mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.isAsync()); + if (mlTaskManager.contains(mlTask.getTaskId())) { + mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.isAsync()); + } MLOutput output = mlEngine.predict(mlInput, mlModel); if (output instanceof MLPredictionOutput) { ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name()); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java index e3a0cdc058..977ae66603 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java @@ -44,6 +44,7 @@ import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.utils.TestHelper; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; @@ -180,7 +181,7 @@ public void testDeleteConnector_BlockedByModel() throws IOException { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( - "1 models are still using this connector, please delete or update the models first!", + "1 models are still using this connector, please delete or update the models first: [model_ID]", argumentCaptor.getValue().getMessage() ); } @@ -291,8 +292,17 @@ private SearchResponse getEmptySearchResponse() { return searchResponse; } - private SearchResponse getNonEmptySearchResponse() { + private SearchResponse getNonEmptySearchResponse() throws IOException { SearchHit[] hits = new SearchHit[1]; + String modelContent = "{\n" + + " \"created_time\": 1684981986069,\n" + + " \"last_updated_time\": 1684981986069,\n" + + " \"_id\": \"model_ID\",\n" + + " \"name\": \"test_model\",\n" + + " \"description\": \"This is an example description\"\n" + + " }"; + SearchHit model = SearchHit.fromXContent(TestHelper.parser(modelContent)); + hits[0] = model; SearchHits searchHits = new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f); SearchResponseSections searchSections = new SearchResponseSections( searchHits, diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java index a7fb34a4b5..c2cbf81cf1 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java @@ -150,7 +150,7 @@ public void testGetConnector_IndexNotFoundException() { getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Fail to find connector", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to find connector", argumentCaptor.getValue().getMessage()); } public void testGetConnector_RuntimeException() { diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java similarity index 69% rename from plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java rename to plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java index fc6020474a..e1bbcfa881 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java @@ -7,20 +7,20 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.io.IOException; +import java.nio.file.Path; +import java.util.Arrays; import java.util.List; -import java.util.Map; +import java.util.UUID; import org.apache.lucene.search.TotalHits; import org.junit.Before; +import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -41,7 +41,14 @@ import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; +import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.utils.TestHelper; @@ -54,8 +61,9 @@ import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; -public class TransportUpdateConnectorActionTests extends OpenSearchTestCase { +public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { private UpdateConnectorTransportAction transportUpdateConnectorAction; @@ -100,6 +108,8 @@ public class TransportUpdateConnectorActionTests extends OpenSearchTestCase { private SearchResponse searchResponse; + private MLEngine mlEngine; + private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); @@ -122,7 +132,12 @@ public void setup() throws IOException { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); String connector_id = "test_connector_id"; - Map updateContent = Map.of("version", "2", "description", "updated description"); + MLCreateConnectorInput updateContent = MLCreateConnectorInput + .builder() + .updateConnector(true) + .version("2") + .description("updated description") + .build(); when(updateRequest.getConnectorId()).thenReturn(connector_id); when(updateRequest.getUpdateContent()).thenReturn(updateContent); @@ -139,25 +154,57 @@ public void setup() throws IOException { SearchResponse.Clusters.EMPTY ); + Encryptor encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor); + transportUpdateConnectorAction = new UpdateConnectorTransportAction( transportService, actionFilters, client, connectorAccessControlHelper, - mlModelManager + mlModelManager, + settings, + clusterService, + mlEngine ); when(mlModelManager.getAllModelIds()).thenReturn(new String[] {}); shardId = new ShardId(new Index("indexName", "uuid"), 1); updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); - } - public void test_execute_connectorAccessControl_success() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); + ActionListener listener = invocation.getArgument(2); + Connector connector = HttpConnector + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(ImmutableMap.of("api_key", "credential_value")) + .parameters(ImmutableMap.of("param1", "value1")) + .actions( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("https://api.openai.com/v1/chat/completions") + .headers(ImmutableMap.of("Authorization", "Bearer ${credential.api_key}")) + .requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") + .build() + ) + ) + .build(); + // Connector connector = mock(HttpConnector.class); + // doNothing().when(connector).update(any(), any()); + listener.onResponse(connector); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class)); + } + + @Test + public void test_execute_connectorAccessControl_success() { + doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -175,12 +222,9 @@ public void test_execute_connectorAccessControl_success() { verify(actionListener).onResponse(updateResponse); } + @Test public void test_execute_connectorAccessControl_NoPermission() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(false); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + doReturn(false).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -191,12 +235,11 @@ public void test_execute_connectorAccessControl_NoPermission() { ); } + @Test public void test_execute_connectorAccessControl_AccessError() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Connector Access Control Error")); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + doThrow(new RuntimeException("Connector Access Control Error")) + .when(connectorAccessControlHelper) + .validateConnectorAccess(any(Client.class), any(Connector.class)); transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -204,10 +247,11 @@ public void test_execute_connectorAccessControl_AccessError() { assertEquals("Connector Access Control Error", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_connectorAccessControl_Exception() { doThrow(new RuntimeException("exception in access control")) .when(connectorAccessControlHelper) - .validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + .validateConnectorAccess(any(Client.class), any(Connector.class)); transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -215,12 +259,9 @@ public void test_execute_connectorAccessControl_Exception() { assertEquals("exception in access control", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_UpdateWrongStatus() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -239,12 +280,9 @@ public void test_execute_UpdateWrongStatus() { verify(actionListener).onResponse(updateResponse); } + @Test public void test_execute_UpdateException() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -264,12 +302,9 @@ public void test_execute_UpdateException() { assertEquals("update document failure", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_SearchResponseNotEmpty() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -280,15 +315,14 @@ public void test_execute_SearchResponseNotEmpty() { transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("1 models are still using this connector, please undeploy the models first!", argumentCaptor.getValue().getMessage()); + assertTrue( + argumentCaptor.getValue().getMessage().contains("1 models are still using this connector, please undeploy the models first") + ); } + @Test public void test_execute_SearchResponseError() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -302,12 +336,38 @@ public void test_execute_SearchResponseError() { assertEquals("Error in Search Request", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_SearchIndexNotFoundError() { + doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); + doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); + ActionListener listener = invocation.getArgument(2); + Connector connector = HttpConnector + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(ImmutableMap.of("api_key", "credential_value")) + .parameters(ImmutableMap.of("param1", "value1")) + .actions( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("https://api.openai.com/v1/chat/completions") + .headers(ImmutableMap.of("Authorization", "Bearer ${credential.api_key}")) + .requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") + .build() + ) + ) + .build(); + // Connector connector = mock(HttpConnector.class); + // doNothing().when(connector).update(any(), any()); + listener.onResponse(connector); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java index 1a67977291..a99488d0e9 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java @@ -38,6 +38,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; @@ -367,6 +368,20 @@ public void test_FailedToGetModelGroupException() { assertEquals("Failed to get model group", argumentCaptor.getValue().getMessage()); } + public void test_ModelGroupIndexNotFoundException() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new IndexNotFoundException("Fail to find model group")); + return null; + }).when(client).get(any(), any()); + + MLUpdateModelGroupRequest actionRequest = prepareRequest(null, AccessMode.RESTRICTED, null); + transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Fail to find model group", argumentCaptor.getValue().getMessage()); + } + public void test_FailedToUpdatetModelGroupException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -414,15 +429,16 @@ public void test_ModelGroupNameNotUnique() throws IOException { } public void test_ExceptionSecurityDisabledCluster() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule - .expectMessage( - "You cannot specify model access control parameters because the Security plugin or model access control is disabled on your cluster." - ); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false); MLUpdateModelGroupRequest actionRequest = prepareRequest(null, null, true); transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "You cannot specify model access control parameters because the Security plugin or model access control is disabled on your cluster.", + argumentCaptor.getValue().getMessage() + ); } private MLUpdateModelGroupRequest prepareRequest(List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java new file mode 100644 index 0000000000..85dfaa552a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -0,0 +1,847 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.models; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Arrays; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.get.GetResult; +import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.transport.model.MLUpdateModelInput; +import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelGroupManager; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class UpdateModelTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + Task task; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + MLUpdateModelInput mockUpdateModelInput; + + @Mock + MLUpdateModelRequest mockUpdateModelRequest; + + @Mock + MLModel mockModel; + + @Mock + MLModelManager mlModelManager; + + @Mock + MLModelGroupManager mlModelGroupManager; + + @Mock + private ModelAccessControlHelper modelAccessControlHelper; + + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private ShardId shardId; + + UpdateResponse updateResponse; + + UpdateModelTransportAction transportUpdateModelAction; + + MLUpdateModelRequest updateLocalModelRequest; + + MLUpdateModelInput updateLocalModelInput; + + MLUpdateModelRequest updateRemoteModelRequest; + + MLUpdateModelInput updateRemoteModelInput; + + MLModel mlModelWithNullFunctionName; + + MLModel localModel; + + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + updateLocalModelInput = MLUpdateModelInput + .builder() + .modelId("test_model_id") + .name("updated_test_name") + .description("updated_test_description") + .modelGroupId("updated_test_model_group_id") + .build(); + updateLocalModelRequest = MLUpdateModelRequest.builder().updateModelInput(updateLocalModelInput).build(); + updateRemoteModelInput = MLUpdateModelInput + .builder() + .modelId("test_model_id") + .name("updated_test_name") + .description("updated_test_description") + .modelGroupId("updated_test_model_group_id") + .connectorId("updated_test_connector_id") + .build(); + updateRemoteModelRequest = MLUpdateModelRequest.builder().updateModelInput(updateRemoteModelInput).build(); + + mlModelWithNullFunctionName = MLModel + .builder() + .modelId("test_model_id") + .name("test_name") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .build(); + + Settings settings = Settings.builder().build(); + + transportUpdateModelAction = spy( + new UpdateModelTransportAction( + transportService, + actionFilters, + client, + connectorAccessControlHelper, + modelAccessControlHelper, + mlModelManager, + mlModelGroupManager + ) + ); + + localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + shardId = new ShardId(new Index("indexName", "uuid"), 1); + updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("test_model_group_id"), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }) + .when(connectorAccessControlHelper) + .validateConnectorAccess(any(Client.class), eq("updated_test_connector_id"), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(localModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + MLModelGroup modelGroup = MLModelGroup + .builder() + .modelGroupId("updated_test_model_group_id") + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); + + GetResponse getResponse = prepareGetResponse(modelGroup); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(mlModelGroupManager).getModelGroupResponse(eq("updated_test_model_group_id"), isA(ActionListener.class)); + } + + @Test + public void testUpdateLocalModelSuccess() { + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelStateLoadedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.LOADED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStateLoadingException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.LOADING).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStatePartiallyLoadedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.PARTIALLY_LOADED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStateDeployedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.DEPLOYED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStateDeployingException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.DEPLOYING).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelStatePartiallyDeployedException() { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.PARTIALLY_DEPLOYED).when(mockModel).getModelState(); + + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "ML Model mockId is in deploying or deployed state, please undeploy the models first!", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithoutRegisterToNewModelGroupSuccess() { + updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateRemoteModelWithLocalInformationSuccess() { + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateRemoteModelWithRemoteInformationSuccess() { + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateRemoteModelWithNoStandAloneConnectorFound() { + MLModel remoteModelWithInternalConnector = prepareUnsupportedMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModelWithInternalConnector); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "This remote does not have a connector_id field, maybe it uses an internal connector.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControlNoPermission() { + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(false); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "You don't have permission to update the connector, connector id: updated_test_connector_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControlOtherException() { + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(remoteModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener + .onFailure( + new RuntimeException("Any other connector access control Exception occurred. Please check log for more details.") + ); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other connector access control Exception occurred. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithModelAccessControlNoPermission() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User doesn't have privilege to perform this operation on this model, model ID test_model_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithModelAccessControlOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener + .onFailure( + new RuntimeException( + "Any other model access control Exception occurred during update the model. Please check log for more details." + ) + ); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other model access control Exception occurred during update the model. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlNoPermission() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User Doesn't have privilege to re-link this model to the target model group due to no access to the target model group with model group ID updated_test_model_group_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener + .onFailure( + new RuntimeException( + "Any other model access control Exception occurred during re-linking the model group. Please check log for more details." + ) + ); + return null; + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other model access control Exception occurred during re-linking the model group. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithRegisterToNewModelGroupNotFound() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new MLResourceNotFoundException("Model group not found with MODEL_GROUP_ID: updated_test_model_group_id")); + return null; + }).when(mlModelGroupManager).getModelGroupResponse(eq("updated_test_model_group_id"), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Failed to find the model group with the provided model group id in the update model input, MODEL_GROUP_ID: updated_test_model_group_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithModelNotFound() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(null); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find model to update with the provided model id: test_model_id", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateModelWithFunctionNameFieldNotFound() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mlModelWithNullFunctionName); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + } + + @Test + public void testUpdateLocalModelWithRemoteInformation() { + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Trying to update the connector or connector_id field on a local model", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateLocalModelWithUnsupportedFunction() { + MLModel localModelWithUnsupportedFunction = prepareUnsupportedMLModel(FunctionName.KMEANS); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(localModelWithUnsupportedFunction); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User doesn't have privilege to perform this operation on this function category: KMEANS", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateRequestDocIOException() throws IOException { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.REGISTERED).when(mockModel).getModelState(); + + doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any()); + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IOException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred during building update request.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IOException { + doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); + doReturn("mockId").when(mockUpdateModelInput).getModelId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockModel); + return null; + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + + doReturn("test_model_group_id").when(mockModel).getModelGroupId(); + doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); + doReturn(MLModelState.REGISTERED).when(mockModel).getModelState(); + + doReturn("mockUpdateModelGroupId").when(mockUpdateModelInput).getModelGroupId(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("mockUpdateModelGroupId"), any(), isA(ActionListener.class)); + + MLModelGroup modelGroup = MLModelGroup + .builder() + .modelGroupId("updated_test_model_group_id") + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); + + GetResponse getResponse = prepareGetResponse(modelGroup); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(mlModelGroupManager).getModelGroupResponse(eq("mockUpdateModelGroupId"), isA(ActionListener.class)); + + doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any()); + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IOException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred during building update request.", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testGetUpdateResponseListenerWithVersionBumpWrongStatus() { + UpdateResponse updateWrongResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateWrongResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateWrongResponse); + } + + @Test + public void testGetUpdateResponseListenerWithVersionBumpOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onFailure( + new RuntimeException( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details." + ) + ); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testGetUpdateResponseListenerWrongStatus() { + UpdateResponse updateWrongResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateWrongResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateWrongResponse); + } + + @Test + public void testGetUpdateResponseListenerOtherException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onFailure( + new RuntimeException( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details." + ) + ); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + // TODO: Add UT to make sure that version incremented successfully. + + private MLModel prepareMLModel(FunctionName functionName) throws IllegalArgumentException { + MLModel mlModel; + switch (functionName) { + case TEXT_EMBEDDING: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.TEXT_EMBEDDING) + .build(); + return mlModel; + case REMOTE: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.REMOTE) + .connectorId("test_connector_id") + .build(); + return mlModel; + default: + throw new IllegalArgumentException("Please choose from FunctionName.TEXT_EMBEDDING and FunctionName.REMOTE"); + } + } + + private MLModel prepareUnsupportedMLModel(FunctionName unsupportedCase) throws IllegalArgumentException { + MLModel mlModel; + switch (unsupportedCase) { + case REMOTE: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.REMOTE) + .connector(HttpConnector.builder().name("test_connector").protocol("http").build()) + .build(); + return mlModel; + case KMEANS: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.KMEANS) + .build(); + return mlModel; + default: + throw new IllegalArgumentException("Please choose from FunctionName.REMOTE and FunctionName.KMEANS"); + } + } + + private GetResponse prepareGetResponse(MLModelGroup mlModelGroup) throws IOException { + XContentBuilder content = mlModelGroup.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + return new GetResponse(getResult); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index 1a8384f45f..ac1f09dea1 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -18,9 +18,11 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; +import java.io.IOException; import java.util.List; import java.util.Map; +import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; @@ -30,6 +32,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; @@ -61,6 +64,9 @@ import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -144,7 +150,7 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase { private ConnectorAccessControlHelper connectorAccessControlHelper; @Before - public void setup() { + public void setup() throws IOException { MockitoAnnotations.openMocks(this); settings = Settings .builder() @@ -199,6 +205,13 @@ public void setup() { return null; }).when(mlTaskDispatcher).dispatch(any(), any()); + SearchResponse searchResponse = createModelGroupSearchResponse(0); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + when(clusterService.localNode()).thenReturn(node2); when(node2.getId()).thenReturn("node2Id"); @@ -461,6 +474,97 @@ public void test_execute_registerRemoteModel_withInternalConnector_predictEndpoi ); } + public void test_ModelNameAlreadyExists() throws IOException { + when(node1.getId()).thenReturn("NodeId1"); + when(node2.getId()).thenReturn("NodeId2"); + MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class); + doAnswer(invocation -> { + ActionListenerResponseHandler handler = invocation.getArgument(3); + handler.handleResponse(forwardResponse); + return null; + }).when(transportService).sendRequest(any(), any(), any(), any()); + SearchResponse searchResponse = createModelGroupSearchResponse(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + transportRegisterModelAction.doExecute(task, prepareRequest("http://test_url", null), actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + } + + public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException { + SearchResponse searchResponse = createModelGroupSearchResponse(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + MLRegisterModelInput registerModelInput = MLRegisterModelInput + .builder() + .modelName("huggingface/sentence-transformers/all-MiniLM-L12-v2") + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .version("1") + .build(); + + transportRegisterModelAction.doExecute(task, new MLRegisterModelRequest(registerModelInput), actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Without a model group ID, the system will use the model name {huggingface/sentence-transformers/all-MiniLM-L12-v2} to create a new model group. However, this name is taken by another group with id {model_group_ID} you can't access. To register this pre-trained model, create a new model group and use its ID in your request.", + argumentCaptor.getValue().getMessage() + + ); + } + + public void test_FailureWhenSearchingModelGroupName() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Runtime exception")); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Runtime exception", argumentCaptor.getValue().getMessage()); + } + + public void test_NoAccessWhenModelNameAlreadyExists() throws IOException { + + SearchResponse searchResponse = createModelGroupSearchResponse(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "The name {Test Model} you provided is unavailable because it is used by another model group with id {model_group_ID} to which you do not have access. Please provide a different name.", + argumentCaptor.getValue().getMessage() + ); + } + private MLRegisterModelRequest prepareRequest(String url, String modelGroupID) { MLRegisterModelInput registerModelInput = MLRegisterModelInput .builder() @@ -485,4 +589,22 @@ private MLRegisterModelRequest prepareRequest(String url, String modelGroupID) { return new MLRegisterModelRequest(registerModelInput); } + private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException { + + SearchResponse searchResponse = mock(SearchResponse.class); + String modelContent = "{\n" + + " \"created_time\": 1684981986069,\n" + + " \"access\": \"public\",\n" + + " \"latest_version\": 0,\n" + + " \"last_updated_time\": 1684981986069,\n" + + " \"_id\": \"model_group_ID\",\n" + + " \"name\": \"Test Model\",\n" + + " \"description\": \"This is an example description\"\n" + + " }"; + SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); + SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN); + when(searchResponse.getHits()).thenReturn(hits); + return searchResponse; + } + } diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java index 26b2f3f091..f7eb64c8eb 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java @@ -7,13 +7,18 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.io.IOException; + +import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; @@ -30,6 +35,9 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -67,7 +75,7 @@ public class TransportRegisterModelMetaActionTests extends OpenSearchTestCase { private ModelAccessControlHelper modelAccessControlHelper; @Before - public void setup() { + public void setup() throws IOException { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -93,6 +101,13 @@ public void setup() { return null; }).when(mlModelManager).registerModelMeta(any(), any()); + SearchResponse searchResponse = createModelGroupSearchResponse(0); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); } @@ -169,10 +184,64 @@ public void test_ValidationFailedException() { assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage()); } + public void testDoExecute_ModelNameAlreadyExists() throws IOException { + + SearchResponse searchResponse = createModelGroupSearchResponse(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + MLRegisterModelMetaRequest actionRequest = prepareRequest(null); + action.doExecute(task, actionRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelMetaResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + } + + public void testDoExecute_NoAccessWhenModelNameAlreadyExists() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + SearchResponse searchResponse = createModelGroupSearchResponse(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + MLRegisterModelMetaRequest actionRequest = prepareRequest(null); + action.doExecute(task, actionRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "The name {Test Model} you provided is unavailable because it is used by another model group with id {model_group_ID} to which you do not have access. Please provide a different name.", + argumentCaptor.getValue().getMessage() + ); + } + + public void test_FailureWhenSearchingModelGroupName() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Runtime exception")); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + MLRegisterModelMetaRequest actionRequest = prepareRequest(null); + action.doExecute(task, actionRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Runtime exception", argumentCaptor.getValue().getMessage()); + } + private MLRegisterModelMetaRequest prepareRequest(String modelGroupID) { MLRegisterModelMetaInput input = MLRegisterModelMetaInput .builder() - .name("Model Name") + .name("Test Model") .modelGroupId(modelGroupID) .description("Custom Model Test") .modelFormat(MLModelFormat.TORCH_SCRIPT) @@ -195,4 +264,22 @@ private MLRegisterModelMetaRequest prepareRequest(String modelGroupID) { return new MLRegisterModelMetaRequest(input); } + private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException { + + SearchResponse searchResponse = mock(SearchResponse.class); + String modelContent = "{\n" + + " \"created_time\": 1684981986069,\n" + + " \"access\": \"public\",\n" + + " \"latest_version\": 0,\n" + + " \"last_updated_time\": 1684981986069,\n" + + " \"_id\": \"model_group_ID\",\n" + + " \"name\": \"Test Model\",\n" + + " \"description\": \"This is an example description\"\n" + + " }"; + SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); + SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN); + when(searchResponse.getHits()).thenReturn(hits); + return searchResponse; + } + } diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java index 7f48d9f32c..30c9f6191c 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java @@ -22,6 +22,7 @@ import org.junit.Before; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -42,7 +43,6 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.connector.ConnectorProtocols; import org.opensearch.ml.common.connector.HttpConnector; -import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -207,7 +207,7 @@ public void test_validateConnectorAccess_connectorNotFound_return_false() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); - verify(actionListener, times(1)).onFailure(any(MLResourceNotFoundException.class)); + verify(actionListener, times(1)).onFailure(any(OpenSearchStatusException.class)); } public void test_validateConnectorAccess_searchConnectorException_return_false() { @@ -222,7 +222,7 @@ public void test_validateConnectorAccess_searchConnectorException_return_false() threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); - verify(actionListener).onFailure(any(IllegalStateException.class)); + verify(actionListener).onFailure(any(OpenSearchStatusException.class)); } public void test_skipConnectorAccessControl_userIsNull_return_true() { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java index 78d67b3ce1..603c315a12 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java @@ -5,6 +5,7 @@ package org.opensearch.ml.model; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -26,6 +27,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; @@ -251,6 +253,7 @@ public void testSyncWorkerNodes_ModelState() { cacheHelper.syncWorkerNodes(modelWorkerNodes); assertEquals(2, cacheHelper.getAllModels().length); assertEquals(0, cacheHelper.getWorkerNodes(modelId2).length); + assertNull(cacheHelper.getModelInfo(modelId2)); assertArrayEquals(new String[] { newNodeId }, cacheHelper.getWorkerNodes(modelId)); } @@ -323,6 +326,15 @@ public void test_removeWorkerNodes_with_deployToAllNodesStatus_isTrue() { cacheHelper.removeWorkerNodes(ImmutableSet.of(nodeId), false); cacheHelper.removeWorkerNode(modelId, nodeId, false); assertEquals(0, cacheHelper.getWorkerNodes(modelId).length); + assertNull(cacheHelper.getModelInfo(modelId)); + } + + public void test_setModelInfo_success() { + cacheHelper.initModelState(modelId, MLModelState.DEPLOYED, FunctionName.TEXT_EMBEDDING, targetWorkerNodes, true); + MLModel model = mock(MLModel.class); + when(model.getModelId()).thenReturn("mockId"); + cacheHelper.setModelInfo(modelId, model); + assertEquals("mockId", cacheHelper.getModelInfo(modelId).getModelId()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index f7eb759026..ccedef9bc1 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -23,35 +23,37 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; -import org.opensearch.ml.action.model_group.TransportRegisterModelGroupAction; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.TestHelper; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; -import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; public class MLModelGroupManagerTests extends OpenSearchTestCase { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - @Mock - private TransportService transportService; - @Mock private MLIndicesHandler mlIndicesHandler; @@ -61,26 +63,23 @@ public class MLModelGroupManagerTests extends OpenSearchTestCase { @Mock private ThreadPool threadPool; - @Mock - private Task task; - @Mock private Client client; - @Mock - private ActionFilters actionFilters; @Mock private ActionListener actionListener; + @Mock + private ActionListener modelGroupListener; + @Mock private IndexResponse indexResponse; ThreadContext threadContext; - private TransportRegisterModelGroupAction transportRegisterModelGroupAction; - @Mock private ModelAccessControlHelper modelAccessControlHelper; + @Mock private MLModelGroupManager mlModelGroupManager; @@ -335,6 +334,61 @@ public void test_ExceptionInitModelGroupIndexIfAbsent() { assertEquals("Index Not Found", argumentCaptor.getValue().getMessage()); } + public void test_SuccessGetModelGroup() throws IOException { + MLModelGroup modelGroup = MLModelGroup + .builder() + .modelGroupId("testModelGroupID") + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); + + GetResponse getResponse = prepareGetResponse(modelGroup); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); + verify(modelGroupListener).onResponse(getResponse); + } + + public void test_OtherExceptionGetModelGroup() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onFailure( + new RuntimeException("Any other Exception occurred during getting the model group. Please check log for more details.") + ); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(modelGroupListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other Exception occurred during getting the model group. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + + public void test_NotFoundGetModelGroup() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(modelGroupListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find model group with ID: testModelGroupID", argumentCaptor.getValue().getMessage()); + } + private MLRegisterModelGroupInput prepareRequest(List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { return MLRegisterModelGroupInput .builder() @@ -363,4 +417,10 @@ private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOE return searchResponse; } + private GetResponse prepareGetResponse(MLModelGroup mlModelGroup) throws IOException { + XContentBuilder content = mlModelGroup.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + return new GetResponse(getResult); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java index 814402fb66..1c6a3d2ae7 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java @@ -14,7 +14,6 @@ import static org.mockito.Mockito.when; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; -import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -25,6 +24,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; @@ -98,8 +98,8 @@ public void testRoutes() { assertNotNull(routes); assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(0); - assertEquals(RestRequest.Method.POST, route.getMethod()); - assertEquals("/_plugins/_ml/connectors/_update/{connector_id}", route.getPath()); + assertEquals(RestRequest.Method.PUT, route.getMethod()); + assertEquals("/_plugins/_ml/connectors/{connector_id}", route.getPath()); } public void testUpdateConnectorRequest() throws Exception { @@ -109,12 +109,19 @@ public void testUpdateConnectorRequest() throws Exception { verify(client, times(1)).execute(eq(MLUpdateConnectorAction.INSTANCE), argumentCaptor.capture(), any()); MLUpdateConnectorRequest updateConnectorRequest = argumentCaptor.getValue(); assertEquals("test_connectorId", updateConnectorRequest.getConnectorId()); - assertEquals("This is test description", updateConnectorRequest.getUpdateContent().get("description")); - assertEquals("2", updateConnectorRequest.getUpdateContent().get("version")); + assertEquals("This is test description", updateConnectorRequest.getUpdateContent().getDescription()); + assertEquals("2", updateConnectorRequest.getUpdateContent().getVersion()); + } + + public void testUpdateConnectorRequestWithParsingException() throws Exception { + exceptionRule.expect(OpenSearchParseException.class); + exceptionRule.expectMessage("Can't get text on a VALUE_NULL"); + RestRequest request = getRestRequestWithNullValue(); + restMLUpdateConnectorAction.handleRequest(request, channel, client); } public void testUpdateConnectorRequestWithEmptyContent() throws Exception { - exceptionRule.expect(IOException.class); + exceptionRule.expect(OpenSearchParseException.class); exceptionRule.expectMessage("Failed to update connector: Request body is empty"); RestRequest request = getRestRequestWithEmptyContent(); restMLUpdateConnectorAction.handleRequest(request, channel, client); @@ -151,6 +158,20 @@ private RestRequest getRestRequest() { return request; } + private RestRequest getRestRequestWithNullValue() { + RestRequest.Method method = RestRequest.Method.POST; + String requestContent = "{\"version\":\"2\",\"description\":null}"; + Map params = new HashMap<>(); + params.put("connector_id", "test_connectorId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/connectors/_update/{connector_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + private RestRequest getRestRequestWithEmptyContent() { RestRequest.Method method = RestRequest.Method.POST; Map params = new HashMap<>(); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java new file mode 100644 index 0000000000..28687d1c9c --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java @@ -0,0 +1,191 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.model.MLUpdateModelAction; +import org.opensearch.ml.common.transport.model.MLUpdateModelInput; +import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import com.google.gson.Gson; + +public class RestMLUpdateModelActionTests extends OpenSearchTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private RestMLUpdateModelAction restMLUpdateModelAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + restMLUpdateModelAction = new RestMLUpdateModelAction(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLUpdateModelAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + @Test + public void testConstructor() { + RestMLUpdateModelAction UpdateModelAction = new RestMLUpdateModelAction(); + assertNotNull(UpdateModelAction); + } + + @Test + public void testGetName() { + String actionName = restMLUpdateModelAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_update_model_action", actionName); + } + + @Test + public void testRoutes() { + List routes = restMLUpdateModelAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.PUT, route.getMethod()); + assertEquals("/_plugins/_ml/models/{model_id}", route.getPath()); + } + + @Test + public void testUpdateModelRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLUpdateModelAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLUpdateModelRequest.class); + verify(client, times(1)).execute(eq(MLUpdateModelAction.INSTANCE), argumentCaptor.capture(), any()); + MLUpdateModelInput updateModelInput = argumentCaptor.getValue().getUpdateModelInput(); + assertEquals("testModelName", updateModelInput.getName()); + assertEquals("This is test description", updateModelInput.getDescription()); + } + + @Test + public void testUpdateModelRequestWithEmptyContent() throws Exception { + exceptionRule.expect(OpenSearchParseException.class); + exceptionRule.expectMessage("Model update request has empty body"); + RestRequest request = getRestRequestWithEmptyContent(); + restMLUpdateModelAction.handleRequest(request, channel, client); + } + + @Test + public void testUpdateModelRequestWithNullModelId() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Request should contain model_id"); + RestRequest request = getRestRequestWithNullModelId(); + restMLUpdateModelAction.handleRequest(request, channel, client); + } + + @Test + public void testUpdateModelRequestWithNullField() throws Exception { + exceptionRule.expect(OpenSearchParseException.class); + exceptionRule.expectMessage("Can't get text on a VALUE_NULL"); + RestRequest request = getRestRequestWithNullField(); + restMLUpdateModelAction.handleRequest(request, channel, client); + } + + private RestRequest getRestRequest() { + RestRequest.Method method = RestRequest.Method.PUT; + final Map modelContent = Map.of("name", "testModelName", "description", "This is test description"); + String requestContent = new Gson().toJson(modelContent); + Map params = new HashMap<>(); + params.put("model_id", "test_modelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithEmptyContent() { + RestRequest.Method method = RestRequest.Method.PUT; + Map params = new HashMap<>(); + params.put("model_id", "test_modelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(""), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullModelId() { + RestRequest.Method method = RestRequest.Method.PUT; + final Map modelContent = Map.of("name", "testModelName", "description", "This is test description"); + String requestContent = new Gson().toJson(modelContent); + Map params = new HashMap<>(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullField() { + RestRequest.Method method = RestRequest.Method.PUT; + String requestContent = "{\"name\":\"testModelName\",\"description\":null}"; + Map params = new HashMap<>(); + params.put("model_id", "test_modelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 5f18d974a7..0d0c594458 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -212,7 +212,7 @@ public void setup() throws IOException { public void testExecuteTask_OnLocalNode() { setupMocks(true, false, false, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -220,10 +220,22 @@ public void testExecuteTask_OnLocalNode() { verify(mlTaskManager).remove(anyString()); } + public void testExecuteTask_OnLocalNode_RemoteModel() { + setupMocks(true, false, false, false); + + taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertTrue(argumentCaptor.getValue().getMessage().contains("Model not ready yet.")); + verify(mlTaskManager, never()).add(any(MLTask.class)); + verify(client, never()).get(any(), any()); + } + public void testExecuteTask_OnLocalNode_QueryInput() { setupMocks(true, false, false, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithQuery, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -234,7 +246,7 @@ public void testExecuteTask_OnLocalNode_QueryInput() { public void testExecuteTask_OnLocalNode_QueryInput_Failure() { setupMocks(true, true, false, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithQuery, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager, never()).add(any(MLTask.class)); @@ -245,7 +257,7 @@ public void testExecuteTask_NoPermission() { setupMocks(true, true, false, false); threadContext.stashContext(); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "test_user|test_role|test_tenant"); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlTaskManager).add(any(MLTask.class)); verify(mlTaskManager).remove(anyString()); verify(client).get(any(), any()); @@ -256,14 +268,14 @@ public void testExecuteTask_NoPermission() { public void testExecuteTask_OnRemoteNode() { setupMocks(false, false, false, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(transportService).sendRequest(eq(remoteNode), eq(MLPredictionTaskAction.NAME), eq(requestWithDataFrame), any()); } public void testExecuteTask_OnLocalNode_GetModelFail() { setupMocks(true, false, true, false); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -277,7 +289,7 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() { setupMocks(true, false, false, false); requestWithDataFrame = MLPredictionTaskRequest.builder().mlInput(mlInputWithDataFrame).build(); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); @@ -291,7 +303,7 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() { public void testExecuteTask_OnLocalNode_NullGetResponse() { setupMocks(true, false, false, true); - taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener); + taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); // verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); verify(mlTaskManager).add(any(MLTask.class)); diff --git a/release-notes/opensearch-ml-common.release-notes-2.11.0.0.md b/release-notes/opensearch-ml-common.release-notes-2.11.0.0.md index 9902d641b9..564a8df998 100644 --- a/release-notes/opensearch-ml-common.release-notes-2.11.0.0.md +++ b/release-notes/opensearch-ml-common.release-notes-2.11.0.0.md @@ -36,6 +36,7 @@ Compatible with OpenSearch 2.11.0 * fix no worker node exception for remote embedding model ([#1482](https://github.com/opensearch-project/ml-commons/pull/1482)) * fix for delete model group API throwing incorrect error when model index not created ([#1485](https://github.com/opensearch-project/ml-commons/pull/1485)) * fix no worker node error on multi-node cluster ([#1487](https://github.com/opensearch-project/ml-commons/pull/1487)) +* Fix prompt passing for Bedrock by passing a single string prompt for Bedrock models. ([#1490](https://github.com/opensearch-project/ml-commons/pull/1490)) ### Documentation From 5f076bca74d957364bcd3b3c40203f9457fe437a Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Tue, 30 Jan 2024 12:11:53 -0800 Subject: [PATCH 2/4] Allow model setting to transport from rest Signed-off-by: Owais Kazi --- .../action/register/TransportRegisterModelAction.java | 10 ++++++++++ .../opensearch/ml/rest/RestMLRegisterModelAction.java | 9 --------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index ca226f0251..64e7cc308c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTaskState.FAILED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; @@ -89,6 +90,7 @@ public class TransportRegisterModelAction extends HandledTransportAction trustedConnectorEndpointsRegex; ModelAccessControlHelper modelAccessControlHelper; + private volatile boolean isModelUrlAllowed; ConnectorAccessControlHelper connectorAccessControlHelper; MLModelGroupManager mlModelGroupManager; @@ -132,6 +134,9 @@ public TransportRegisterModelAction( trustedUrlRegex = ML_COMMONS_TRUSTED_URL_REGEX.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_TRUSTED_URL_REGEX, it -> trustedUrlRegex = it); + isModelUrlAllowed = ML_COMMONS_ALLOW_MODEL_URL.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ALLOW_MODEL_URL, it -> isModelUrlAllowed = it); + trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); clusterService .getClusterSettings() @@ -169,6 +174,11 @@ private void checkUserAccess( Boolean isModelNameAlreadyExisting ) { User user = RestActionUtils.getUserContext(client); + if (registerModelInput.getUrl() != null && !isModelUrlAllowed) { + throw new IllegalArgumentException( + "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models." + ); + } modelAccessControlHelper .validateModelGroupAccess(user, registerModelInput.getModelGroupId(), client, ActionListener.wrap(access -> { if (access) { diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java index 9e76a48c97..631462e773 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java @@ -7,7 +7,6 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_DEPLOY_MODEL; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; @@ -35,7 +34,6 @@ public class RestMLRegisterModelAction extends BaseRestHandler { private static final String ML_REGISTER_MODEL_ACTION = "ml_register_model_action"; - private volatile boolean isModelUrlAllowed; private final MLFeatureEnabledSetting mlFeatureEnabledSetting; /** @@ -51,8 +49,6 @@ public RestMLRegisterModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting * @param settings settings */ public RestMLRegisterModelAction(ClusterService clusterService, Settings settings, MLFeatureEnabledSetting mlFeatureEnabledSetting) { - isModelUrlAllowed = ML_COMMONS_ALLOW_MODEL_URL.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ALLOW_MODEL_URL, it -> isModelUrlAllowed = it); this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @@ -103,11 +99,6 @@ MLRegisterModelRequest getRequest(RestRequest request) throws IOException { if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); } - if (mlInput.getUrl() != null && !isModelUrlAllowed) { - throw new IllegalArgumentException( - "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models." - ); - } return new MLRegisterModelRequest(mlInput); } } From f3d703df81e55f99d7e10852dd6ad9ee67fa144a Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Tue, 30 Jan 2024 16:15:46 -0800 Subject: [PATCH 3/4] Added test Signed-off-by: Owais Kazi --- .../TransportRegisterModelAction.java | 10 ++-- .../TransportRegisterModelActionTests.java | 47 +++++++++++++++++++ .../rest/RestMLRegisterModelActionTests.java | 14 ------ 3 files changed, 52 insertions(+), 19 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 64e7cc308c..e13ea03173 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -147,6 +147,11 @@ public TransportRegisterModelAction( protected void doExecute(Task task, ActionRequest request, ActionListener listener) { MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request); MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput(); + if (registerModelInput.getUrl() != null && !isModelUrlAllowed) { + throw new IllegalArgumentException( + "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models." + ); + } registerModelInput.setIsHidden(RestActionUtils.isSuperAdminUser(clusterService, client)); if (StringUtils.isEmpty(registerModelInput.getModelGroupId())) { mlModelGroupManager.validateUniqueModelGroupName(registerModelInput.getModelName(), ActionListener.wrap(modelGroups -> { @@ -174,11 +179,6 @@ private void checkUserAccess( Boolean isModelNameAlreadyExisting ) { User user = RestActionUtils.getUserContext(client); - if (registerModelInput.getUrl() != null && !isModelUrlAllowed) { - throw new IllegalArgumentException( - "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models." - ); - } modelAccessControlHelper .validateModelGroupAccess(user, registerModelInput.getModelGroupId(), client, ActionListener.wrap(access -> { if (access) { diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index b40a278289..83ecd01069 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -14,6 +14,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; @@ -155,12 +156,14 @@ public void setup() throws IOException { settings = Settings .builder() .put(ML_COMMONS_TRUSTED_URL_REGEX.getKey(), trustedUrlRegex) + .put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), true) .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) .build(); threadContext = new ThreadContext(settings); ClusterSettings clusterSettings = clusterSetting( settings, ML_COMMONS_TRUSTED_URL_REGEX, + ML_COMMONS_ALLOW_MODEL_URL, ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); @@ -294,6 +297,50 @@ public void testDoExecute_invalidURL() { assertEquals("URL can't match trusted url regex", argumentCaptor.getValue().getMessage()); } + public void testRegisterModelUrlNotAllowed() throws Exception { + Settings settings = Settings + .builder() + .put(ML_COMMONS_TRUSTED_URL_REGEX.getKey(), trustedUrlRegex) + .put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false) + .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) + .build(); + ClusterSettings clusterSettings = clusterSetting( + settings, + ML_COMMONS_TRUSTED_URL_REGEX, + ML_COMMONS_ALLOW_MODEL_URL, + ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + transportRegisterModelAction = new TransportRegisterModelAction( + transportService, + actionFilters, + modelHelper, + mlIndicesHandler, + mlModelManager, + mlTaskManager, + clusterService, + settings, + threadPool, + client, + nodeFilter, + mlTaskDispatcher, + mlStats, + modelAccessControlHelper, + connectorAccessControlHelper, + mlModelGroupManager + ); + + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> transportRegisterModelAction.doExecute(task, prepareRequest("test url", "testModelGroupsID"), actionListener) + ); + assertEquals( + e.getMessage(), + "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use OpenSearch pre-trained models." + ); + } + public void testDoExecute_successWithLocalNodeNotEqualToClusterNode() { when(node1.getId()).thenReturn("NodeId1"); when(node2.getId()).thenReturn("NodeId2"); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java index 8655a4eb06..b3f3e3f956 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java @@ -147,20 +147,6 @@ public void testRegisterModelRequestRemoteInferenceDisabled() throws Exception { restMLRegisterModelAction.handleRequest(request, channel, client); } - public void testRegisterModelUrlNotAllowed() throws Exception { - settings = Settings.builder().put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false).build(); - ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_MODEL_URL); - when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - restMLRegisterModelAction = new RestMLRegisterModelAction(clusterService, settings, mlFeatureEnabledSetting); - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule - .expectMessage( - "To upload custom model user needs to enable allow_registering_model_via_url settings. Otherwise please use opensearch pre-trained models." - ); - RestRequest request = getRestRequest(); - restMLRegisterModelAction.handleRequest(request, channel, client); - } - public void testRegisterModelRequestWithNullUrlAndUrlNotAllowed() throws Exception { settings = Settings.builder().put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), false).build(); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ALLOW_MODEL_URL); From d1348d97946d3c7bfb44aa8df2ab22e8f6a5cf19 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 31 Jan 2024 16:06:18 -0800 Subject: [PATCH 4/4] Fixed integ test Signed-off-by: Owais Kazi --- .../opensearch/ml/action/models/SearchModelITTests.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java index d5c1347e26..5b2eae3f4a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java @@ -5,11 +5,14 @@ package org.opensearch.ml.action.models; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; + import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.settings.Settings; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.action.MLCommonsIntegTestCase; @@ -187,4 +190,9 @@ private void test_matchPhrase_search() { assertEquals(1, response.getHits().getTotalHits().value); } + @Override + protected Settings nodeSettings(int ordinal) { + return Settings.builder().put(super.nodeSettings(ordinal)).put(ML_COMMONS_ALLOW_MODEL_URL.getKey(), true).build(); + } + }