From e0c6276148d10f2a233f1081fa36baeb4611070e Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Thu, 7 Sep 2023 11:21:53 -0700 Subject: [PATCH] check connector usage in deployed models before updating connector Signed-off-by: Xun Zhang --- .../UpdateConnectorTransportAction.java | 55 ++++++- .../TransportUpdateConnectorActionTests.java | 136 +++++++++++++++++- 2 files changed, 187 insertions(+), 4 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java index 21df3d31ba..d8a1d88a01 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java @@ -6,9 +6,11 @@ package org.opensearch.ml.action.connector; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.search.SearchRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.update.UpdateRequest; @@ -17,9 +19,17 @@ import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -33,17 +43,20 @@ public class UpdateConnectorTransportAction extends HandledTransportAction { if (Boolean.TRUE.equals(hasPermission)) { - client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); + updateUndeployedConnector(connectorId, updateRequest, listener, context); } else { listener .onFailure( @@ -74,6 +87,44 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener, + ThreadContext.StoredContext context + ) { + SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); + boolQueryBuilder.must(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId)); + boolQueryBuilder.must(QueryBuilders.idsQuery().addIds(mlModelManager.getAllModelIds())); + sourceBuilder.query(boolQueryBuilder); + searchRequest.source(sourceBuilder); + + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + SearchHit[] searchHits = searchResponse.getHits().getHits(); + if (searchHits.length == 0) { + client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); + } else { + log.error(searchHits.length + " models are still using this connector, please undeploy the models first!"); + listener + .onFailure( + new MLValidationException( + searchHits.length + " models are still using this connector, please undeploy the models first!" + ) + ); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); + return; + } + log.error("Failed to update ML connector: " + connectorId, e); + listener.onFailure(e); + + })); + } + private ActionListener getUpdateResponseListener( String connectorId, ActionListener actionListener, diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java index 7024666715..fc6020474a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java @@ -15,14 +15,20 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; +import java.io.IOException; import java.util.List; import java.util.Map; +import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; @@ -34,8 +40,14 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -77,17 +89,22 @@ public class TransportUpdateConnectorActionTests extends OpenSearchTestCase { @Mock ActionListener actionListener; + @Mock + MLModelManager mlModelManager; + ThreadContext threadContext; private Settings settings; private ShardId shardId; + private SearchResponse searchResponse; + private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); @Before - public void setup() { + public void setup() throws IOException { MockitoAnnotations.openMocks(this); settings = Settings .builder() @@ -109,12 +126,28 @@ public void setup() { when(updateRequest.getConnectorId()).thenReturn(connector_id); when(updateRequest.getUpdateContent()).thenReturn(updateContent); + SearchHits hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections(hits, InternalAggregations.EMPTY, null, false, false, null, 1); + searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + transportUpdateConnectorAction = new UpdateConnectorTransportAction( transportService, actionFilters, client, - connectorAccessControlHelper + connectorAccessControlHelper, + mlModelManager ); + + when(mlModelManager.getAllModelIds()).thenReturn(new String[] {}); shardId = new ShardId(new Index("indexName", "uuid"), 1); updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); } @@ -126,6 +159,12 @@ public void test_execute_connectorAccessControl_success() { return null; }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(updateResponse); @@ -182,6 +221,13 @@ public void test_execute_UpdateWrongStatus() { listener.onResponse(true); return null; }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -200,6 +246,12 @@ public void test_execute_UpdateException() { return null; }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("update document failure")); @@ -211,4 +263,84 @@ public void test_execute_UpdateException() { verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("update document failure", argumentCaptor.getValue().getMessage()); } + + public void test_execute_SearchResponseNotEmpty() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(noneEmptySearchResponse()); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("1 models are still using this connector, please undeploy the models first!", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_SearchResponseError() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("Error in Search Request")); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Error in Search Request", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_SearchIndexNotFoundError() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new IndexNotFoundException("Index not found!")); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + private SearchResponse noneEmptySearchResponse() throws IOException { + String modelContent = "{\"name\":\"Remote_Model\",\"algorithm\":\"Remote\",\"version\":1,\"connector_id\":\"test_id\"}"; + SearchHit model = SearchHit.fromXContent(TestHelper.parser(modelContent)); + SearchHits hits = new SearchHits(new SearchHit[] { model }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections(hits, InternalAggregations.EMPTY, null, false, false, null, 1); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + + return searchResponse; + } }