diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index 039d9cb4f8..2079b93917 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -11,6 +11,7 @@ import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.Map; import org.apache.commons.lang3.StringUtils; @@ -43,6 +44,7 @@ import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.rest.RestStatus; +import org.opensearch.search.SearchHit; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -150,23 +152,30 @@ private void updateModelGroup( source.put(MLModelGroup.DESCRIPTION_FIELD, updateModelGroupInput.getDescription()); } if (StringUtils.isNotBlank(updateModelGroupInput.getName()) && !updateModelGroupInput.getName().equals(modelGroupName)) { - mlModelGroupManager - .validateUniqueModelGroupName(updateModelGroupInput.getName(), ActionListener.wrap(isModelGroupNameUnique -> { - if (Boolean.FALSE.equals(isModelGroupNameUnique)) { + mlModelGroupManager.validateUniqueModelGroupName(updateModelGroupInput.getName(), ActionListener.wrap(modelGroups -> { + if (modelGroups != null + && modelGroups.getHits().getTotalHits() != null + && modelGroups.getHits().getTotalHits().value != 0) { + Iterator iterator = modelGroups.getHits().iterator(); + while (iterator.hasNext()) { + String id = iterator.next().getId(); listener .onFailure( new IllegalArgumentException( - "The name you provided is already being used by another model group. Please provide a different name." + "The name you provided is already being used by another model with ID: " + + id + + ". Please provide a different name" ) ); - } else { - source.put(MLModelGroup.MODEL_GROUP_NAME_FIELD, updateModelGroupInput.getName()); - updateModelGroup(modelGroupId, source, listener); } - }, e -> { - log.error("Failed to search model group index", e); - listener.onFailure(e); - })); + } else { + source.put(MLModelGroup.MODEL_GROUP_NAME_FIELD, updateModelGroupInput.getName()); + updateModelGroup(modelGroupId, source, listener); + } + }, e -> { + log.error("Failed to search model group index", e); + listener.onFailure(e); + })); } else { updateModelGroup(modelGroupId, source, listener); } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index a55ef27ead..9bd80f9fd2 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -9,10 +9,12 @@ import java.time.Instant; import java.util.HashSet; +import java.util.Iterator; import org.opensearch.action.ActionListener; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -32,6 +34,7 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import lombok.extern.log4j.Log4j2; @@ -62,11 +65,22 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener { - if (Boolean.FALSE.equals(isUniqueModelGroupName)) { - throw new IllegalArgumentException( - "The name you provided is already being used by another model group. Please provide a different name" - ); + validateUniqueModelGroupName(input.getName(), ActionListener.wrap(modelGroups -> { + if (modelGroups != null + && modelGroups.getHits().getTotalHits() != null + && modelGroups.getHits().getTotalHits().value != 0) { + Iterator iterator = modelGroups.getHits().iterator(); + while (iterator.hasNext()) { + String id = iterator.next().getId(); + listener + .onFailure( + new IllegalArgumentException( + "The name you provided is already being used by another model with ID: " + + id + + ". Please provide a different name" + ) + ); + } } else { MLModelGroup.MLModelGroupBuilder builder = MLModelGroup.builder(); MLModelGroup mlModelGroup; @@ -170,21 +184,16 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us } } - public void validateUniqueModelGroupName(String name, ActionListener listener) throws IllegalArgumentException { + public void validateUniqueModelGroupName(String name, ActionListener listener) throws IllegalArgumentException { BoolQueryBuilder query = new BoolQueryBuilder(); query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name)); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder); - client.search(searchRequest, ActionListener.wrap(modelGroups -> { - listener - .onResponse( - modelGroups == null || modelGroups.getHits().getTotalHits() == null || modelGroups.getHits().getTotalHits().value == 0 - ); - }, e -> { + client.search(searchRequest, ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> { if (e instanceof IndexNotFoundException) { - listener.onResponse(true); + listener.onResponse(null); } else { log.error("Failed to search model group index", e); listener.onFailure(e); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java index 056d1e4337..5f70123936 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java @@ -6,7 +6,9 @@ package org.opensearch.ml.action.model_group; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -14,6 +16,7 @@ import java.util.Arrays; import java.util.List; +import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; @@ -22,6 +25,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; @@ -45,6 +49,9 @@ import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -134,9 +141,10 @@ public void setup() throws IOException { return null; }).when(client).get(any(), any()); + SearchResponse searchResponse = createModelGroupSearchResponse(0); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(true); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); return null; }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); @@ -385,13 +393,14 @@ public void test_SuccessSecurityDisabledCluster() { verify(actionListener).onResponse(argumentCaptor.capture()); } - public void test_ModelGroupNameNotUnique() { + public void test_ModelGroupNameNotUnique() throws IOException { + SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(false); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); return null; - }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.isOwner(any(), any())).thenReturn(true); when(modelAccessControlHelper.isOwnerStillHasPermission(any(), any())).thenReturn(true); @@ -402,7 +411,7 @@ public void test_ModelGroupNameNotUnique() { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( - "The name you provided is already being used by another model group. Please provide a different name.", + "The name you provided is already being used by another model with ID: model_group_ID. Please provide a different name", argumentCaptor.getValue().getMessage() ); } @@ -432,4 +441,21 @@ private MLUpdateModelGroupRequest prepareRequest(List backendRoles, Acce return new MLUpdateModelGroupRequest(UpdateModelGroupInput); } + private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException { + SearchResponse searchResponse = mock(SearchResponse.class); + String modelContent = "{\n" + + " \"created_time\": 1684981986069,\n" + + " \"access\": \"public\",\n" + + " \"latest_version\": 0,\n" + + " \"last_updated_time\": 1684981986069,\n" + + " \"_id\": \"model_group_ID\",\n" + + " \"name\": \"model_group_IT\",\n" + + " \"description\": \"This is an example description\"\n" + + " }"; + SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); + SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN); + when(searchResponse.getHits()).thenReturn(hits); + return searchResponse; + } + } diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java index bf73addf04..b5300ecdb4 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java @@ -17,7 +17,6 @@ import java.util.List; import org.junit.Before; -import org.junit.Ignore; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -108,7 +107,6 @@ public void test_UndefinedOwner() throws IOException { assertTrue(argumentCaptor.getValue()); } - @Ignore public void test_ExceptionEmptyBackendRoles() throws IOException { String owner = "owner|IT,HR|myTenant"; User user = User.parse("owner|IT,HR|myTenant"); @@ -119,7 +117,6 @@ public void test_ExceptionEmptyBackendRoles() throws IOException { assertEquals("Backend roles shouldn't be null", argumentCaptor.getValue().getMessage()); } - @Ignore public void test_MatchingBackendRoles() throws IOException { String owner = "owner|IT,HR|myTenant"; List backendRoles = Arrays.asList("IT", "HR"); @@ -131,7 +128,6 @@ public void test_MatchingBackendRoles() throws IOException { assertTrue(argumentCaptor.getValue()); } - @Ignore public void test_PublicModelGroup() throws IOException { String owner = "owner|IT,HR|myTenant"; List backendRoles = Arrays.asList("IT", "HR"); @@ -143,7 +139,6 @@ public void test_PublicModelGroup() throws IOException { assertTrue(argumentCaptor.getValue()); } - @Ignore public void test_PrivateModelGroupWithSameOwner() throws IOException { String owner = "owner|IT,HR|myTenant"; List backendRoles = Arrays.asList("IT", "HR"); @@ -155,7 +150,6 @@ public void test_PrivateModelGroupWithSameOwner() throws IOException { assertTrue(argumentCaptor.getValue()); } - @Ignore public void test_PrivateModelGroupWithDifferentOwner() throws IOException { String owner = "owner|IT,HR|myTenant"; List backendRoles = Arrays.asList("IT", "HR"); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index 54a78fc921..e7b195d269 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -6,15 +6,18 @@ package org.opensearch.ml.model; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.io.IOException; import java.util.Arrays; import java.util.List; +import org.apache.lucene.search.TotalHits; import org.junit.Before; -import org.junit.Ignore; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; @@ -22,6 +25,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListener; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -33,12 +37,14 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -@Ignore public class MLModelGroupManagerTests extends OpenSearchTestCase { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -81,7 +87,7 @@ public class MLModelGroupManagerTests extends OpenSearchTestCase { private final List backendRoles = Arrays.asList("IT", "HR"); @Before - public void setup() { + public void setup() throws IOException { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -102,6 +108,13 @@ public void setup() { return null; }).when(mlIndicesHandler).initModelGroupIndexIfAbsent(any()); + SearchResponse searchResponse = createModelGroupSearchResponse(0); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), isA(ActionListener.class)); + when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); } @@ -116,6 +129,26 @@ public void test_SuccessAddAllBackendRolesTrue() { verify(actionListener).onResponse(argumentCaptor.capture()); } + public void test_ModelGroupNameNotUnique() throws IOException { + SearchResponse searchResponse = createModelGroupSearchResponse(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), isA(ActionListener.class)); + + MLRegisterModelGroupInput mlRegisterModelGroupInput = prepareRequest(null, null, true); + + mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "The name you provided is already being used by another model with ID: model_group_ID. Please provide a different name", + argumentCaptor.getValue().getMessage() + ); + + } + public void test_SuccessPublic() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -125,17 +158,13 @@ public void test_SuccessPublic() { verify(actionListener).onResponse(argumentCaptor.capture()); } - public void test_ExceptionAllAccessFieldsNull() { + public void test_DefaultPrivateModelGroup() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); MLRegisterModelGroupInput mlRegisterModelGroupInput = prepareRequest(null, null, null); mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "You must specify at least one backend role or make the model group public/private for registering it.", - argumentCaptor.getValue().getMessage() - ); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(String.class); + verify(actionListener).onResponse(argumentCaptor.capture()); } public void test_ModelAccessModeNullAddAllBackendRolesTrue() { @@ -158,6 +187,16 @@ public void test_BackendRolesProvidedWithPublic() { assertEquals("You can specify backend roles only for a model group with the restricted access mode.", argumentCaptor.getValue().getMessage()); } + public void test_ProvidedBothBackendRolesAndAddAllBackendRolesWithNoAccessMode() { + when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); + + MLRegisterModelGroupInput mlRegisterModelGroupInput = prepareRequest(backendRoles, null, true); + mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("You cannot specify backend roles and add all backend roles at the same time.", argumentCaptor.getValue().getMessage()); + } + public void test_BackendRolesProvidedWithPrivate() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -307,4 +346,21 @@ private MLRegisterModelGroupInput prepareRequest(List backendRoles, Acce .build(); } + private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException { + SearchResponse searchResponse = mock(SearchResponse.class); + String modelContent = "{\n" + + " \"created_time\": 1684981986069,\n" + + " \"access\": \"public\",\n" + + " \"latest_version\": 0,\n" + + " \"last_updated_time\": 1684981986069,\n" + + " \"_id\": \"model_group_ID\",\n" + + " \"name\": \"model_group_IT\",\n" + + " \"description\": \"This is an example description\"\n" + + " }"; + SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); + SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN); + when(searchResponse.getHits()).thenReturn(hits); + return searchResponse; + } + }