From 52fb87ec9bdbeca8b8e9bb0580949a96399930ab Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Fri, 6 Oct 2023 16:24:10 -0700 Subject: [PATCH] if model version fails to register, update model group accordingly (#1463) * if model version fails to register, update model group accordingly Signed-off-by: Bhavana Ramaram (cherry picked from commit e2d27785a3de1de00fa132695162f3112fa685a5) --- .../register/MLRegisterModelInput.java | 24 ++++++-- .../MLRegisterModelMetaInput.java | 18 +++++- .../MLRegisterModelMetaInputTest.java | 2 +- .../MLRegisterModelMetaRequestTest.java | 2 +- .../TransportRegisterModelAction.java | 2 + .../TransportRegisterModelMetaAction.java | 2 + .../ml/model/MLModelGroupManager.java | 53 +++++++++-------- .../opensearch/ml/model/MLModelManager.java | 58 ++++++++++++++----- 8 files changed, 117 insertions(+), 44 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index c778675ea8..6697461429 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -53,7 +53,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { public static final String ACCESS_MODE_FIELD = "access_mode"; public static final String BACKEND_ROLES_FIELD = "backend_roles"; public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles"; - + public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group"; private FunctionName functionName; private String modelName; private String modelGroupId; @@ -73,6 +73,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private List backendRoles; private Boolean addAllBackendRoles; private AccessMode accessMode; + private Boolean doesVersionCreateModelGroup; @Builder(toBuilder = true) public MLRegisterModelInput(FunctionName functionName, @@ -90,7 +91,8 @@ public MLRegisterModelInput(FunctionName functionName, String connectorId, List backendRoles, Boolean addAllBackendRoles, - AccessMode accessMode + AccessMode accessMode, + Boolean doesVersionCreateModelGroup ) { if (functionName == null) { this.functionName = FunctionName.TEXT_EMBEDDING; @@ -123,6 +125,7 @@ public MLRegisterModelInput(FunctionName functionName, this.backendRoles = backendRoles; this.addAllBackendRoles = addAllBackendRoles; this.accessMode = accessMode; + this.doesVersionCreateModelGroup = doesVersionCreateModelGroup; } @@ -157,6 +160,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException { if (in.readBoolean()) { this.accessMode = in.readEnum(AccessMode.class); } + this.doesVersionCreateModelGroup = in.readOptionalBoolean(); } @Override @@ -202,6 +206,7 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalBoolean(doesVersionCreateModelGroup); } @Override @@ -249,6 +254,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (accessMode != null) { builder.field(ACCESS_MODE_FIELD, accessMode); } + if (doesVersionCreateModelGroup != null) { + builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup); + } builder.endObject(); return builder; } @@ -267,6 +275,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName List backendRoles = new ArrayList<>(); Boolean addAllBackendRoles = null; AccessMode accessMode = null; + Boolean doesVersionCreateModelGroup = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -318,12 +327,15 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName case ACCESS_MODE_FIELD: accessMode = AccessMode.from(parser.text()); break; + case DOES_VERSION_CREATE_MODEL_GROUP: + doesVersionCreateModelGroup = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode); + return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup); } public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException { @@ -342,6 +354,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo List backendRoles = new ArrayList<>(); AccessMode accessMode = null; Boolean addAllBackendRoles = null; + Boolean doesVersionCreateModelGroup = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -400,11 +413,14 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo case ACCESS_MODE_FIELD: accessMode = AccessMode.from(parser.text()); break; + case DOES_VERSION_CREATE_MODEL_GROUP: + doesVersionCreateModelGroup = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode); + return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java index d23f00caf7..ecb03d9bb6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java @@ -46,6 +46,8 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional public static final String ACCESS_MODE = "access_mode"; //optional public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional + public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group"; + private FunctionName functionName; private String name; @@ -65,11 +67,13 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ private List backendRoles; private AccessMode accessMode; private Boolean isAddAllBackendRoles; + private Boolean doesVersionCreateModelGroup; @Builder(toBuilder = true) public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List backendRoles, AccessMode accessMode, - Boolean isAddAllBackendRoles) { + Boolean isAddAllBackendRoles, + Boolean doesVersionCreateModelGroup) { if (name == null) { throw new IllegalArgumentException("model name is null"); } @@ -103,6 +107,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m this.backendRoles = backendRoles; this.accessMode = accessMode; this.isAddAllBackendRoles = isAddAllBackendRoles; + this.doesVersionCreateModelGroup = doesVersionCreateModelGroup; } public MLRegisterModelMetaInput(StreamInput in) throws IOException{ @@ -128,6 +133,7 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{ accessMode = in.readEnum(AccessMode.class); } this.isAddAllBackendRoles = in.readOptionalBoolean(); + this.doesVersionCreateModelGroup = in.readOptionalBoolean(); } @Override @@ -171,6 +177,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalBoolean(isAddAllBackendRoles); + out.writeOptionalBoolean(doesVersionCreateModelGroup); } @Override @@ -206,6 +213,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (isAddAllBackendRoles != null) { builder.field(ADD_ALL_BACKEND_ROLES, isAddAllBackendRoles); } + if (doesVersionCreateModelGroup != null) { + builder.field(DOES_VERSION_CREATE_MODEL_GROUP, doesVersionCreateModelGroup); + } builder.endObject(); return builder; } @@ -225,6 +235,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc List backendRoles = null; AccessMode accessMode = null; Boolean isAddAllBackendRoles = null; + Boolean doesVersionCreateModelGroup = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -277,12 +288,15 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc case ADD_ALL_BACKEND_ROLES: isAddAllBackendRoles = parser.booleanValue(); break; + case DOES_VERSION_CREATE_MODEL_GROUP: + doesVersionCreateModelGroup = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles); + return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java index c9ace159ee..61e57d4ac6 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java @@ -43,7 +43,7 @@ public void setup() { config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "model_group_id", "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null); + "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java index 0c3a432d94..d7039780f0 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java @@ -33,7 +33,7 @@ public void setUp() { config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id", "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null); + "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null); } @Test diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index a1d1d54258..c8955469ef 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -256,12 +256,14 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(registerModelInput); mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> { registerModelInput.setModelGroupId(modelGroupId); + registerModelInput.setDoesVersionCreateModelGroup(true); registerModel(registerModelInput, listener); }, e -> { logException("Failed to create Model Group", e, log); listener.onFailure(e); })); } else { + registerModelInput.setDoesVersionCreateModelGroup(false); registerModel(registerModelInput, listener); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java index da350ef039..a730de712f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java @@ -121,12 +121,14 @@ private void createModelGroup(MLRegisterModelMetaInput mlUploadInput, ActionList MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(mlUploadInput); mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> { mlUploadInput.setModelGroupId(modelGroupId); + mlUploadInput.setDoesVersionCreateModelGroup(true); registerModelMeta(mlUploadInput, listener); }, e -> { logException("Failed to create Model Group", e, log); listener.onFailure(e); })); } else { + mlUploadInput.setDoesVersionCreateModelGroup(false); registerModelMeta(mlUploadInput, listener); } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index e833910d58..94cbcf5364 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -146,15 +146,13 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us AccessMode modelAccessMode = input.getModelAccessMode(); Boolean isAddAllBackendRoles = input.getIsAddAllBackendRoles(); if (modelAccessMode == null) { - if (modelAccessMode == null) { - if (!CollectionUtils.isEmpty(input.getBackendRoles()) && Boolean.TRUE.equals(isAddAllBackendRoles)) { - throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time."); - } else if (Boolean.TRUE.equals(isAddAllBackendRoles) || !CollectionUtils.isEmpty(input.getBackendRoles())) { - input.setModelAccessMode(AccessMode.RESTRICTED); - modelAccessMode = AccessMode.RESTRICTED; - } else { - input.setModelAccessMode(AccessMode.PRIVATE); - } + if (!CollectionUtils.isEmpty(input.getBackendRoles()) && Boolean.TRUE.equals(isAddAllBackendRoles)) { + throw new IllegalArgumentException("You cannot specify backend roles and add all backend roles at the same time."); + } else if (Boolean.TRUE.equals(isAddAllBackendRoles) || !CollectionUtils.isEmpty(input.getBackendRoles())) { + input.setModelAccessMode(AccessMode.RESTRICTED); + modelAccessMode = AccessMode.RESTRICTED; + } else { + input.setModelAccessMode(AccessMode.PRIVATE); } } if ((AccessMode.PUBLIC == modelAccessMode || AccessMode.PRIVATE == modelAccessMode) @@ -184,20 +182,29 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us } public void validateUniqueModelGroupName(String name, ActionListener listener) throws IllegalArgumentException { - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name)); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); - SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder); - - client.search(searchRequest, ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> { - if (e instanceof IndexNotFoundException) { - listener.onResponse(null); - } else { - log.error("Failed to search model group index", e); - listener.onFailure(e); - } - })); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name)); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); + SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder); + + client + .search( + searchRequest, + ActionListener.runBefore(ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> { + if (e instanceof IndexNotFoundException) { + listener.onResponse(null); + } else { + log.error("Failed to search model group index", e); + listener.onFailure(e); + } + }), () -> context.restore()) + ); + } catch (Exception e) { + log.error("Failed to search model group index", e); + listener.onFailure(e); + } } private void validateSecurityDisabledOrModelAccessControlDisabled(MLRegisterModelGroupInput input) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 759b0cec9f..79c3f625df 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -295,6 +295,11 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput log.debug("Index model meta doc successfully {}", modelName); wrappedListener.onResponse(response.getId()); }, e -> { + deleteOrUpdateModelGroup( + mlRegisterModelMetaInput.getModelGroupId(), + mlRegisterModelMetaInput.getDoesVersionCreateModelGroup(), + version + ); log.error("Failed to index model meta doc", e); wrappedListener.onFailure(e); })); @@ -327,10 +332,6 @@ public void registerMLRemoteModel( String modelGroupId = mlRegisterModelInput.getModelGroupId(); GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); - if (Strings.isBlank(modelGroupId)) { - indexRemoteModel(mlRegisterModelInput, mlTask, "1", listener); - } - client.get(getModelGroupRequest, ActionListener.wrap(getModelGroupResponse -> { if (getModelGroupResponse.isExists()) { Map modelGroupSourceMap = getModelGroupResponse.getSourceAsMap(); @@ -398,9 +399,6 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa String modelGroupId = registerModelInput.getModelGroupId(); GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); - if (Strings.isBlank(modelGroupId)) { - uploadModel(registerModelInput, mlTask, "1"); - } try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.get(getModelGroupRequest, ActionListener.runBefore(ActionListener.wrap(modelGroup -> { if (modelGroup.isExists()) { @@ -722,7 +720,8 @@ private void registerModel( modelId, modelSizeInBytes, chunkFiles, - hashValue + hashValue, + version ); } else { deleteFileQuietly(file); @@ -734,7 +733,7 @@ private void registerModel( handleException(functionName, taskId, e); deleteFileQuietly(file); // remove model doc as failed to upload model - deleteModel(modelId); + deleteModel(modelId, registerModelInput, version); semaphore.release(); deleteFileQuietly(mlEngine.getRegisterModelPath(modelId)); })); @@ -742,7 +741,7 @@ private void registerModel( }, e -> { log.error("Failed to index chunk file", e); deleteFileQuietly(mlEngine.getRegisterModelPath(modelId)); - deleteModel(modelId); + deleteModel(modelId, registerModelInput, version); handleException(functionName, taskId, e); }) ); @@ -783,7 +782,8 @@ private void updateModelRegisterStateAsDone( String modelId, Long modelSizeInBytes, List chunkFiles, - String hashValue + String hashValue, + String version ) { FunctionName functionName = registerModelInput.getFunctionName(); deleteFileQuietly(mlEngine.getRegisterModelPath(modelId)); @@ -809,7 +809,7 @@ private void updateModelRegisterStateAsDone( }, e -> { log.error("Failed to update model", e); handleException(functionName, taskId, e); - deleteModel(modelId); + deleteModel(modelId, registerModelInput, version); })); } @@ -822,7 +822,7 @@ private void deployModelAfterRegistering(MLRegisterModelInput registerModelInput client.execute(MLDeployModelAction.INSTANCE, request, listener); } - private void deleteModel(String modelId) { + private void deleteModel(String modelId, MLRegisterModelInput registerModelInput, String modelVersion) { DeleteRequest deleteRequest = new DeleteRequest(); deleteRequest.index(ML_MODEL_INDEX).id(modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.delete(deleteRequest); @@ -831,6 +831,38 @@ private void deleteModel(String modelId) { .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) .setAbortOnVersionConflict(false); client.execute(DeleteByQueryAction.INSTANCE, deleteChunksRequest); + deleteOrUpdateModelGroup(registerModelInput.getModelGroupId(), registerModelInput.getDoesVersionCreateModelGroup(), modelVersion); + } + + private void deleteOrUpdateModelGroup(String modelGroupID, Boolean doesVersionCreateModelGroup, String modelVersion) { + // This checks if model group is created when registering the version. If yes, model group is deleted since the version registration + // had failed. Else model group latest version is decremented by 1 + if (doesVersionCreateModelGroup) { + DeleteRequest deleteModelGroupRequest = new DeleteRequest(); + deleteModelGroupRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupID).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.delete(deleteModelGroupRequest); + } else { + updateLatestVersionInModelGroup( + modelGroupID, + Integer.parseInt(modelVersion) - 1, + ActionListener + .wrap(r -> log.debug("model group updated, response {}", r), e -> log.error("Failed to update model group", e)) + ); + } + } + + private void updateLatestVersionInModelGroup(String modelGroupID, Integer latestVersion, ActionListener listener) { + Map updatedFields = new HashMap<>(); + updatedFields.put(MLModelGroup.LATEST_VERSION_FIELD, latestVersion); + updatedFields.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); + UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_GROUP_INDEX, modelGroupID); + updateRequest.doc(updatedFields); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); + } catch (Exception e) { + listener.onFailure(e); + } } private void handleException(FunctionName functionName, String taskId, Exception e) {