Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix model group auto-deletion when last version is deleted #1444

Merged
merged 3 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,19 @@
package org.opensearch.ml.action.models;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD;
import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
Expand All @@ -34,8 +30,6 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.DeleteByQueryAction;
Expand All @@ -48,7 +42,6 @@
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -110,6 +103,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
algorithmName = getResponse.getSource().get(ALGORITHM_FIELD).toString();
}
MLModel mlModel = MLModel.parse(parser, algorithmName);
MLModelState mlModelState = mlModel.getModelState();

modelAccessControlHelper
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
Expand All @@ -118,37 +112,20 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
.onFailure(
new MLValidationException("User doesn't have privilege to perform this operation on this model")
);
} else if (mlModelState.equals(MLModelState.LOADED)
|| mlModelState.equals(MLModelState.LOADING)
|| mlModelState.equals(MLModelState.PARTIALLY_LOADED)
|| mlModelState.equals(MLModelState.DEPLOYED)
|| mlModelState.equals(MLModelState.DEPLOYING)
|| mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED)) {
actionListener
.onFailure(
new Exception(
"Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete"
)
);
} else {
MLModelState mlModelState = mlModel.getModelState();
if (mlModelState.equals(MLModelState.LOADED)
|| mlModelState.equals(MLModelState.LOADING)
|| mlModelState.equals(MLModelState.PARTIALLY_LOADED)
|| mlModelState.equals(MLModelState.DEPLOYED)
|| mlModelState.equals(MLModelState.DEPLOYING)
|| mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED)) {
wrappedListener
.onFailure(
new Exception(
"Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete"
)
);
} else if (StringUtils.isNotEmpty(mlModel.getModelGroupId())) {
searchModel(mlModel.getModelGroupId(), ActionListener.wrap(response -> {
boolean isLastModelOfGroup = false;
if (response != null
&& response.getHits() != null
&& response.getHits().getTotalHits() != null
&& response.getHits().getTotalHits().value == 1) {
isLastModelOfGroup = true;
}
deleteModel(modelId, mlModel.getModelGroupId(), isLastModelOfGroup, wrappedListener);
}, e -> {
log.error("Failed to Search Model index " + modelId, e);
wrappedListener.onFailure(e);
}));
} else {
deleteModel(modelId, mlModel.getModelGroupId(), false, wrappedListener);
}
deleteModel(modelId, actionListener);
}
}, e -> {
log.error("Failed to validate Access for Model Id " + modelId, e);
Expand All @@ -168,18 +145,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
}
}

private void searchModel(String modelGroupId, ActionListener<SearchResponse> listener) {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(MLModel.MODEL_GROUP_ID_FIELD, modelGroupId));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX).source(searchSourceBuilder);
client.search(searchRequest, ActionListener.wrap(response -> { listener.onResponse(response); }, e -> {
log.error("Failed to search Model index", e);
listener.onFailure(e);
}));
}

@VisibleForTesting
void deleteModelChunks(String modelId, DeleteResponse deleteResponse, ActionListener<DeleteResponse> actionListener) {
DeleteByQueryRequest deleteModelsRequest = new DeleteByQueryRequest(ML_MODEL_INDEX);
Expand Down Expand Up @@ -218,19 +183,11 @@ private void returnFailure(BulkByScrollResponse response, String modelId, Action
actionListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR));
}

private void deleteModel(
String modelId,
String modelGroupId,
boolean isLastModelOfGroup,
ActionListener<DeleteResponse> actionListener
) {
private void deleteModel(String modelId, ActionListener<DeleteResponse> actionListener) {
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId);
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
if (isLastModelOfGroup) {
deleteModelGroup(modelGroupId);
}
deleteModelChunks(modelId, deleteResponse, actionListener);
}

Expand All @@ -244,19 +201,4 @@ public void onFailure(Exception e) {
}
});
}

private void deleteModelGroup(String modelGroupId) {
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_GROUP_INDEX, modelGroupId);
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.debug("Completed Delete Model Group for modelGroupId:{}", modelGroupId);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML Model Group with Id:{} " + modelGroupId, e);
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
package org.opensearch.ml.action.models;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand All @@ -22,7 +20,6 @@
import java.util.ArrayList;
import java.util.Arrays;

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
Expand All @@ -35,7 +32,6 @@
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -50,15 +46,11 @@
import org.opensearch.index.get.GetResult;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.ScrollableHitSource;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -184,115 +176,6 @@ public void testDeleteModel_Success_AlgorithmNotNull() throws IOException {
verify(actionListener).onResponse(deleteResponse);
}

public void test_Success_ModelGroupIDNotNull_LastModelOfGroup() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());

SearchResponse searchResponse = createModelGroupSearchResponse(1);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(client).search(any(), isA(ActionListener.class));

GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID");

doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
verify(actionListener).onResponse(deleteResponse);
}

public void test_Success_ModelGroupIDNotNull_NotLastModelOfGroup() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());

SearchResponse searchResponse = createModelGroupSearchResponse(2);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(client).search(any(), isA(ActionListener.class));

MLModel mlModel = MLModel
.builder()
.modelId("test_id")
.modelGroupId("modelGroupID")
.modelState(MLModelState.REGISTERED)
.algorithm(FunctionName.TEXT_EMBEDDING)
.build();
XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
BytesReference bytesReference = BytesReference.bytes(content);
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
GetResponse getResponse = new GetResponse(getResult);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
verify(actionListener).onResponse(deleteResponse);
}

public void test_Failure_FailedToSearchLastModel() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());

doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onFailure(new Exception("Failed to search Model index"));
return null;
}).when(client).search(any(), isA(ActionListener.class));

GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID");

doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Failed to search Model index", argumentCaptor.getValue().getMessage());
}

public void test_UserHasNoAccessException() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID");
doAnswer(invocation -> {
Expand Down Expand Up @@ -517,20 +400,4 @@ public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID
GetResponse getResponse = new GetResponse(getResult);
return getResponse;
}

private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException {
SearchResponse searchResponse = mock(SearchResponse.class);
String modelContent = "{\n"
+ " \"created_time\": 1684981986069,\n"
+ " \"access\": \"public\",\n"
+ " \"latest_version\": 0,\n"
+ " \"last_updated_time\": 1684981986069,\n"
+ " \"name\": \"model_group_IT\",\n"
+ " \"description\": \"This is an example description\"\n"
+ " }";
SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent));
SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN);
when(searchResponse.getHits()).thenReturn(hits);
return searchResponse;
}
}