Skip to content

Commit

Permalink
modify error message when model group not unique is provided
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Jul 11, 2023
1 parent e3cb2e3 commit 80df7e5
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -43,6 +44,7 @@
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.rest.RestStatus;
import org.opensearch.search.SearchHit;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -150,23 +152,30 @@ private void updateModelGroup(
source.put(MLModelGroup.DESCRIPTION_FIELD, updateModelGroupInput.getDescription());
}
if (StringUtils.isNotBlank(updateModelGroupInput.getName()) && !updateModelGroupInput.getName().equals(modelGroupName)) {
mlModelGroupManager
.validateUniqueModelGroupName(updateModelGroupInput.getName(), ActionListener.wrap(isModelGroupNameUnique -> {
if (Boolean.FALSE.equals(isModelGroupNameUnique)) {
mlModelGroupManager.validateUniqueModelGroupName(updateModelGroupInput.getName(), ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
Iterator<SearchHit> iterator = modelGroups.getHits().iterator();
while (iterator.hasNext()) {
String id = iterator.next().getId();
listener
.onFailure(
new IllegalArgumentException(
"The name you provided is already being used by another model group. Please provide a different name."
"The name you provided is already being used by another model with ID: "
+ id
+ ". Please provide a different name"
)
);
} else {
source.put(MLModelGroup.MODEL_GROUP_NAME_FIELD, updateModelGroupInput.getName());
updateModelGroup(modelGroupId, source, listener);
}
}, e -> {
log.error("Failed to search model group index", e);
listener.onFailure(e);
}));
} else {
source.put(MLModelGroup.MODEL_GROUP_NAME_FIELD, updateModelGroupInput.getName());
updateModelGroup(modelGroupId, source, listener);
}
}, e -> {
log.error("Failed to search model group index", e);
listener.onFailure(e);
}));
} else {
updateModelGroup(modelGroupId, source, listener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@

import java.time.Instant;
import java.util.HashSet;
import java.util.Iterator;

import org.opensearch.action.ActionListener;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -32,6 +34,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -62,11 +65,22 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str
String modelName = input.getName();
User user = RestActionUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
validateUniqueModelGroupName(input.getName(), ActionListener.wrap(isUniqueModelGroupName -> {
if (Boolean.FALSE.equals(isUniqueModelGroupName)) {
throw new IllegalArgumentException(
"The name you provided is already being used by another model group. Please provide a different name"
);
validateUniqueModelGroupName(input.getName(), ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
Iterator<SearchHit> iterator = modelGroups.getHits().iterator();
while (iterator.hasNext()) {
String id = iterator.next().getId();
listener
.onFailure(
new IllegalArgumentException(
"The name you provided is already being used by another model with ID: "
+ id
+ ". Please provide a different name"
)
);
}
} else {
MLModelGroup.MLModelGroupBuilder builder = MLModelGroup.builder();
MLModelGroup mlModelGroup;
Expand Down Expand Up @@ -170,21 +184,16 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us
}
}

public void validateUniqueModelGroupName(String name, ActionListener<Boolean> listener) throws IllegalArgumentException {
public void validateUniqueModelGroupName(String name, ActionListener<SearchResponse> listener) throws IllegalArgumentException {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder);

client.search(searchRequest, ActionListener.wrap(modelGroups -> {
listener
.onResponse(
modelGroups == null || modelGroups.getHits().getTotalHits() == null || modelGroups.getHits().getTotalHits().value == 0
);
}, e -> {
client.search(searchRequest, ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> {
if (e instanceof IndexNotFoundException) {
listener.onResponse(true);
listener.onResponse(null);
} else {
log.error("Failed to search model group index", e);
listener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
package org.opensearch.ml.action.model_group;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
Expand All @@ -22,6 +25,7 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
Expand All @@ -45,6 +49,9 @@
import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupResponse;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelGroupManager;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -134,9 +141,10 @@ public void setup() throws IOException {
return null;
}).when(client).get(any(), any());

SearchResponse searchResponse = createModelGroupSearchResponse(0);
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(1);
listener.onResponse(true);
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());

Expand Down Expand Up @@ -385,13 +393,14 @@ public void test_SuccessSecurityDisabledCluster() {
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void test_ModelGroupNameNotUnique() {
public void test_ModelGroupNameNotUnique() throws IOException {

SearchResponse searchResponse = createModelGroupSearchResponse(1);
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(1);
listener.onResponse(false);
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());
}).when(client).search(any(), isA(ActionListener.class));

when(modelAccessControlHelper.isOwner(any(), any())).thenReturn(true);
when(modelAccessControlHelper.isOwnerStillHasPermission(any(), any())).thenReturn(true);
Expand All @@ -402,7 +411,7 @@ public void test_ModelGroupNameNotUnique() {
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"The name you provided is already being used by another model group. Please provide a different name.",
"The name you provided is already being used by another model with ID: model_group_ID. Please provide a different name",
argumentCaptor.getValue().getMessage()
);
}
Expand Down Expand Up @@ -432,4 +441,21 @@ private MLUpdateModelGroupRequest prepareRequest(List<String> backendRoles, Acce
return new MLUpdateModelGroupRequest(UpdateModelGroupInput);
}

private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException {
SearchResponse searchResponse = mock(SearchResponse.class);
String modelContent = "{\n"
+ " \"created_time\": 1684981986069,\n"
+ " \"access\": \"public\",\n"
+ " \"latest_version\": 0,\n"
+ " \"last_updated_time\": 1684981986069,\n"
+ " \"_id\": \"model_group_ID\",\n"
+ " \"name\": \"model_group_IT\",\n"
+ " \"description\": \"This is an example description\"\n"
+ " }";
SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent));
SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN);
when(searchResponse.getHits()).thenReturn(hits);
return searchResponse;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import java.util.List;

import org.junit.Before;
import org.junit.Ignore;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
Expand Down Expand Up @@ -108,7 +107,6 @@ public void test_UndefinedOwner() throws IOException {
assertTrue(argumentCaptor.getValue());
}

@Ignore
public void test_ExceptionEmptyBackendRoles() throws IOException {
String owner = "owner|IT,HR|myTenant";
User user = User.parse("owner|IT,HR|myTenant");
Expand All @@ -119,7 +117,6 @@ public void test_ExceptionEmptyBackendRoles() throws IOException {
assertEquals("Backend roles shouldn't be null", argumentCaptor.getValue().getMessage());
}

@Ignore
public void test_MatchingBackendRoles() throws IOException {
String owner = "owner|IT,HR|myTenant";
List<String> backendRoles = Arrays.asList("IT", "HR");
Expand All @@ -131,7 +128,6 @@ public void test_MatchingBackendRoles() throws IOException {
assertTrue(argumentCaptor.getValue());
}

@Ignore
public void test_PublicModelGroup() throws IOException {
String owner = "owner|IT,HR|myTenant";
List<String> backendRoles = Arrays.asList("IT", "HR");
Expand All @@ -143,7 +139,6 @@ public void test_PublicModelGroup() throws IOException {
assertTrue(argumentCaptor.getValue());
}

@Ignore
public void test_PrivateModelGroupWithSameOwner() throws IOException {
String owner = "owner|IT,HR|myTenant";
List<String> backendRoles = Arrays.asList("IT", "HR");
Expand All @@ -155,7 +150,6 @@ public void test_PrivateModelGroupWithSameOwner() throws IOException {
assertTrue(argumentCaptor.getValue());
}

@Ignore
public void test_PrivateModelGroupWithDifferentOwner() throws IOException {
String owner = "owner|IT,HR|myTenant";
List<String> backendRoles = Arrays.asList("IT", "HR");
Expand Down
Loading

0 comments on commit 80df7e5

Please sign in to comment.