diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index c778675ea8..6697461429 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -53,7 +53,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { public static final String ACCESS_MODE_FIELD = "access_mode"; public static final String BACKEND_ROLES_FIELD = "backend_roles"; public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles"; - + public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group"; private FunctionName functionName; private String modelName; private String modelGroupId; @@ -73,6 +73,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private List backendRoles; private Boolean addAllBackendRoles; private AccessMode accessMode; + private Boolean doesVersionCreateModelGroup; @Builder(toBuilder = true) public MLRegisterModelInput(FunctionName functionName, @@ -90,7 +91,8 @@ public MLRegisterModelInput(FunctionName functionName, String connectorId, List backendRoles, Boolean addAllBackendRoles, - AccessMode accessMode + AccessMode accessMode, + Boolean doesVersionCreateModelGroup ) { if (functionName == null) { this.functionName = FunctionName.TEXT_EMBEDDING; @@ -123,6 +125,7 @@ public MLRegisterModelInput(FunctionName functionName, this.backendRoles = backendRoles; this.addAllBackendRoles = addAllBackendRoles; this.accessMode = accessMode; + this.doesVersionCreateModelGroup = doesVersionCreateModelGroup; } @@ -157,6 +160,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException { if (in.readBoolean()) { this.accessMode = in.readEnum(AccessMode.class); } + this.doesVersionCreateModelGroup = in.readOptionalBoolean(); } @Override @@ -202,6 +206,7 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalBoolean(doesVersionCreateModelGroup); } @Override @@ -249,6 +254,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (accessMode != null) { builder.field(ACCESS_MODE_FIELD, accessMode); } + if (doesVersionCreateModelGroup != null) { + builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup); + } builder.endObject(); return builder; } @@ -267,6 +275,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName List backendRoles = new ArrayList<>(); Boolean addAllBackendRoles = null; AccessMode accessMode = null; + Boolean doesVersionCreateModelGroup = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -318,12 +327,15 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName case ACCESS_MODE_FIELD: accessMode = AccessMode.from(parser.text()); break; + case DOES_VERSION_CREATE_MODEL_GROUP: + doesVersionCreateModelGroup = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode); + return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup); } public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException { @@ -342,6 +354,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo List backendRoles = new ArrayList<>(); AccessMode accessMode = null; Boolean addAllBackendRoles = null; + Boolean doesVersionCreateModelGroup = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -400,11 +413,14 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo case ACCESS_MODE_FIELD: accessMode = AccessMode.from(parser.text()); break; + case DOES_VERSION_CREATE_MODEL_GROUP: + doesVersionCreateModelGroup = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode); + return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java index d23f00caf7..ecb03d9bb6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java @@ -46,6 +46,8 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional public static final String ACCESS_MODE = "access_mode"; //optional public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional + public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group"; + private FunctionName functionName; private String name; @@ -65,11 +67,13 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ private List backendRoles; private AccessMode accessMode; private Boolean isAddAllBackendRoles; + private Boolean doesVersionCreateModelGroup; @Builder(toBuilder = true) public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List backendRoles, AccessMode accessMode, - Boolean isAddAllBackendRoles) { + Boolean isAddAllBackendRoles, + Boolean doesVersionCreateModelGroup) { if (name == null) { throw new IllegalArgumentException("model name is null"); } @@ -103,6 +107,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m this.backendRoles = backendRoles; this.accessMode = accessMode; this.isAddAllBackendRoles = isAddAllBackendRoles; + this.doesVersionCreateModelGroup = doesVersionCreateModelGroup; } public MLRegisterModelMetaInput(StreamInput in) throws IOException{ @@ -128,6 +133,7 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{ accessMode = in.readEnum(AccessMode.class); } this.isAddAllBackendRoles = in.readOptionalBoolean(); + this.doesVersionCreateModelGroup = in.readOptionalBoolean(); } @Override @@ -171,6 +177,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalBoolean(isAddAllBackendRoles); + out.writeOptionalBoolean(doesVersionCreateModelGroup); } @Override @@ -206,6 +213,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (isAddAllBackendRoles != null) { builder.field(ADD_ALL_BACKEND_ROLES, isAddAllBackendRoles); } + if (doesVersionCreateModelGroup != null) { + builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup); + } builder.endObject(); return builder; } @@ -225,6 +235,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc List backendRoles = null; AccessMode accessMode = null; Boolean isAddAllBackendRoles = null; + Boolean doesVersionCreateModelGroup = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -277,12 +288,15 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc case ADD_ALL_BACKEND_ROLES: isAddAllBackendRoles = parser.booleanValue(); break; + case DOES_VERSION_CREATE_MODEL_GROUP: + doesVersionCreateModelGroup = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles); + return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java index c9ace159ee..61e57d4ac6 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java @@ -43,7 +43,7 @@ public void setup() { config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "model_group_id", "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null); + "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java index 0c3a432d94..d7039780f0 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java @@ -33,7 +33,7 @@ public void setUp() { config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id", "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null); + "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null); } @Test diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index a1d1d54258..c8955469ef 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -256,12 +256,14 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(registerModelInput); mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> { registerModelInput.setModelGroupId(modelGroupId); + registerModelInput.setDoesVersionCreateModelGroup(true); registerModel(registerModelInput, listener); }, e -> { logException("Failed to create Model Group", e, log); listener.onFailure(e); })); } else { + registerModelInput.setDoesVersionCreateModelGroup(false); registerModel(registerModelInput, listener); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java index da350ef039..a730de712f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java @@ -121,12 +121,14 @@ private void createModelGroup(MLRegisterModelMetaInput mlUploadInput, ActionList MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(mlUploadInput); mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> { mlUploadInput.setModelGroupId(modelGroupId); + mlUploadInput.setDoesVersionCreateModelGroup(true); registerModelMeta(mlUploadInput, listener); }, e -> { logException("Failed to create Model Group", e, log); listener.onFailure(e); })); } else { + mlUploadInput.setDoesVersionCreateModelGroup(false); registerModelMeta(mlUploadInput, listener); } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index e833910d58..94cbcf5364 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -146,15 +146,13 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us AccessMode modelAccessMode = input.getModelAccessMode(); Boolean isAddAllBackendRoles = input.getIsAddAllBackendRoles(); if (modelAccessMode == null) { - if (modelAccessMode == null) { - if (!CollectionUtils.isEmpty(input.getBackendRoles()) && Boolean.TRUE.equals(isAddAllBackendRoles)) { - throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time."); - } else if (Boolean.TRUE.equals(isAddAllBackendRoles) || !CollectionUtils.isEmpty(input.getBackendRoles())) { - input.setModelAccessMode(AccessMode.RESTRICTED); - modelAccessMode = AccessMode.RESTRICTED; - } else { - input.setModelAccessMode(AccessMode.PRIVATE); - } + if (!CollectionUtils.isEmpty(input.getBackendRoles()) && Boolean.TRUE.equals(isAddAllBackendRoles)) { + throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time."); + } else if (Boolean.TRUE.equals(isAddAllBackendRoles) || !CollectionUtils.isEmpty(input.getBackendRoles())) { + input.setModelAccessMode(AccessMode.RESTRICTED); + modelAccessMode = AccessMode.RESTRICTED; + } else { + input.setModelAccessMode(AccessMode.PRIVATE); } } if ((AccessMode.PUBLIC == modelAccessMode || AccessMode.PRIVATE == modelAccessMode) @@ -184,20 +182,29 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us } public void validateUniqueModelGroupName(String name, ActionListener listener) throws IllegalArgumentException { - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name)); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); - SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder); - - client.search(searchRequest, ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> { - if (e instanceof IndexNotFoundException) { - listener.onResponse(null); - } else { - log.error("Failed to search model group index", e); - listener.onFailure(e); - } - })); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name)); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); + SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder); + + client + .search( + searchRequest, + ActionListener.runBefore(ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> { + if (e instanceof IndexNotFoundException) { + listener.onResponse(null); + } else { + log.error("Failed to search model group index", e); + listener.onFailure(e); + } + }), () -> context.restore()) + ); + } catch (Exception e) { + log.error("Failed to search model group index", e); + listener.onFailure(e); + } } private void validateSecurityDisabledOrModelAccessControlDisabled(MLRegisterModelGroupInput input) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 759b0cec9f..0f8ba2ff4d 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -5,62 +5,10 @@ package org.opensearch.ml.model; -import static org.opensearch.common.xcontent.XContentType.JSON; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; -import static org.opensearch.ml.common.CommonValue.NOT_FOUND; -import static org.opensearch.ml.common.CommonValue.UNDEPLOYED; -import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; -import static org.opensearch.ml.common.MLTask.ERROR_FIELD; -import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; -import static org.opensearch.ml.common.MLTask.STATE_FIELD; -import static org.opensearch.ml.common.MLTaskState.COMPLETED; -import static org.opensearch.ml.common.MLTaskState.FAILED; -import static org.opensearch.ml.engine.ModelHelper.CHUNK_FILES; -import static org.opensearch.ml.engine.ModelHelper.CHUNK_SIZE; -import static org.opensearch.ml.engine.ModelHelper.MODEL_FILE_HASH; -import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES; -import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLIENT; -import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLUSTER_SERVICE; -import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SCRIPT_SERVICE; -import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.XCONTENT_REGISTRY; -import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.ML_ENGINE; -import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.MODEL_HELPER; -import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.MODEL_ZIP_FILE; -import static org.opensearch.ml.engine.utils.FileUtils.calculateFileHash; -import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly; -import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; -import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE; -import static org.opensearch.ml.stats.ActionName.REGISTER; -import static org.opensearch.ml.stats.MLActionLevelStat.ML_ACTION_REQUEST_COUNT; -import static org.opensearch.ml.utils.MLExceptionUtils.logException; -import static org.opensearch.ml.utils.MLNodeUtils.checkOpenCircuitBreaker; -import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; - -import java.io.File; -import java.nio.file.Path; -import java.security.PrivilegedActionException; -import java.time.Instant; -import java.util.Arrays; -import java.util.Base64; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.ConcurrentLinkedDeque; -import java.util.concurrent.Semaphore; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Supplier; - +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.io.Files; +import lombok.extern.log4j.Log4j2; import org.apache.logging.log4j.util.Strings; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.delete.DeleteRequest; @@ -123,11 +71,61 @@ import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.threadpool.ThreadPool; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import com.google.common.io.Files; +import java.io.File; +import java.nio.file.Path; +import java.security.PrivilegedActionException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Base64; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; -import lombok.extern.log4j.Log4j2; +import static org.opensearch.common.xcontent.XContentType.JSON; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.NOT_FOUND; +import static org.opensearch.ml.common.CommonValue.UNDEPLOYED; +import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; +import static org.opensearch.ml.common.MLTask.ERROR_FIELD; +import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; +import static org.opensearch.ml.common.MLTask.STATE_FIELD; +import static org.opensearch.ml.common.MLTaskState.COMPLETED; +import static org.opensearch.ml.common.MLTaskState.FAILED; +import static org.opensearch.ml.engine.ModelHelper.CHUNK_FILES; +import static org.opensearch.ml.engine.ModelHelper.CHUNK_SIZE; +import static org.opensearch.ml.engine.ModelHelper.MODEL_FILE_HASH; +import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLIENT; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLUSTER_SERVICE; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SCRIPT_SERVICE; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.XCONTENT_REGISTRY; +import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.ML_ENGINE; +import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.MODEL_HELPER; +import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.MODEL_ZIP_FILE; +import static org.opensearch.ml.engine.utils.FileUtils.calculateFileHash; +import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly; +import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE; +import static org.opensearch.ml.stats.ActionName.REGISTER; +import static org.opensearch.ml.stats.MLActionLevelStat.ML_ACTION_REQUEST_COUNT; +import static org.opensearch.ml.utils.MLExceptionUtils.logException; +import static org.opensearch.ml.utils.MLNodeUtils.checkOpenCircuitBreaker; +import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; /** * Manager class for ML models. It contains ML model related operations like register, deploy model etc. @@ -295,6 +293,11 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput log.debug("Index model meta doc successfully {}", modelName); wrappedListener.onResponse(response.getId()); }, e -> { + deleteOrUpdateModelGroup( + mlRegisterModelMetaInput.getModelGroupId(), + mlRegisterModelMetaInput.getDoesVersionCreateModelGroup(), + version + ); log.error("Failed to index model meta doc", e); wrappedListener.onFailure(e); })); @@ -327,10 +330,6 @@ public void registerMLRemoteModel( String modelGroupId = mlRegisterModelInput.getModelGroupId(); GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); - if (Strings.isBlank(modelGroupId)) { - indexRemoteModel(mlRegisterModelInput, mlTask, "1", listener); - } - client.get(getModelGroupRequest, ActionListener.wrap(getModelGroupResponse -> { if (getModelGroupResponse.isExists()) { Map modelGroupSourceMap = getModelGroupResponse.getSourceAsMap(); @@ -398,9 +397,6 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa String modelGroupId = registerModelInput.getModelGroupId(); GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); - if (Strings.isBlank(modelGroupId)) { - uploadModel(registerModelInput, mlTask, "1"); - } try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.get(getModelGroupRequest, ActionListener.runBefore(ActionListener.wrap(modelGroup -> { if (modelGroup.isExists()) { @@ -722,7 +718,8 @@ private void registerModel( modelId, modelSizeInBytes, chunkFiles, - hashValue + hashValue, + version ); } else { deleteFileQuietly(file); @@ -734,7 +731,7 @@ private void registerModel( handleException(functionName, taskId, e); deleteFileQuietly(file); // remove model doc as failed to upload model - deleteModel(modelId); + deleteModel(modelId, registerModelInput, version); semaphore.release(); deleteFileQuietly(mlEngine.getRegisterModelPath(modelId)); })); @@ -742,7 +739,7 @@ private void registerModel( }, e -> { log.error("Failed to index chunk file", e); deleteFileQuietly(mlEngine.getRegisterModelPath(modelId)); - deleteModel(modelId); + deleteModel(modelId, registerModelInput, version); handleException(functionName, taskId, e); }) ); @@ -783,7 +780,8 @@ private void updateModelRegisterStateAsDone( String modelId, Long modelSizeInBytes, List chunkFiles, - String hashValue + String hashValue, + String version ) { FunctionName functionName = registerModelInput.getFunctionName(); deleteFileQuietly(mlEngine.getRegisterModelPath(modelId)); @@ -809,7 +807,7 @@ private void updateModelRegisterStateAsDone( }, e -> { log.error("Failed to update model", e); handleException(functionName, taskId, e); - deleteModel(modelId); + deleteModel(modelId, registerModelInput, version); })); } @@ -822,7 +820,7 @@ private void deployModelAfterRegistering(MLRegisterModelInput registerModelInput client.execute(MLDeployModelAction.INSTANCE, request, listener); } - private void deleteModel(String modelId) { + private void deleteModel(String modelId, MLRegisterModelInput registerModelInput, String modelVersion) { DeleteRequest deleteRequest = new DeleteRequest(); deleteRequest.index(ML_MODEL_INDEX).id(modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.delete(deleteRequest); @@ -831,6 +829,38 @@ private void deleteModel(String modelId) { .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) .setAbortOnVersionConflict(false); client.execute(DeleteByQueryAction.INSTANCE, deleteChunksRequest); + deleteOrUpdateModelGroup(registerModelInput.getModelGroupId(), registerModelInput.getDoesVersionCreateModelGroup(), modelVersion); + } + + private void deleteOrUpdateModelGroup(String modelGroupID, Boolean doesVersionCreateModelGroup, String modelVersion) { + // This checks if model group is created when registering the version. If yes, model group is deleted since the version registration + // had failed. Else model group latest version is decremented by 1 + if (doesVersionCreateModelGroup) { + DeleteRequest deleteModelGroupRequest = new DeleteRequest(); + deleteModelGroupRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupID).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.delete(deleteModelGroupRequest); + } else { + updateLatestVersionInModelGroup( + modelGroupID, + Integer.parseInt(modelVersion) - 1, + ActionListener + .wrap(r -> log.debug("model group updated, response {}", r), e -> log.error("Failed to update model group", e)) + ); + } + } + + private void updateLatestVersionInModelGroup(String modelGroupID, Integer latestVersion, ActionListener listener) { + Map updatedFields = new HashMap<>(); + updatedFields.put(MLModelGroup.LATEST_VERSION_FIELD, latestVersion); + updatedFields.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); + UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_GROUP_INDEX, modelGroupID); + updateRequest.doc(updatedFields); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); + } catch (Exception e) { + listener.onFailure(e); + } } private void handleException(FunctionName functionName, String taskId, Exception e) {