Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

register new versions to a model group based on the name provided #1452

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.util.List;
import java.util.regex.Pattern;

import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.ActionRequest;
Expand Down Expand Up @@ -136,17 +137,76 @@ public TransportRegisterModelAction(

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegisterModelResponse> listener) {
User user = RestActionUtils.getUserContext(client);
MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request);
MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput();
if (StringUtils.isEmpty(registerModelInput.getModelGroupId())) {
mlModelGroupManager.validateUniqueModelGroupName(registerModelInput.getModelName(), ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId();
registerModelInput.setModelGroupId(modelGroupIdOfTheNameProvided);
checkUserAccess(registerModelInput, listener, true);
} else {
doRegister(registerModelInput, listener);
}
}, e -> {
log.error("Failed to search model group index", e);
listener.onFailure(e);
}));
} else {
checkUserAccess(registerModelInput, listener, false);
}
}

private void checkUserAccess(
MLRegisterModelInput registerModelInput,
ActionListener<MLRegisterModelResponse> listener,
Boolean isModelNameAlreadyExisting
) {
User user = RestActionUtils.getUserContext(client);
modelAccessControlHelper
.validateModelGroupAccess(user, registerModelInput.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
log.error("You don't have permissions to perform this operation on this model.");
listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model."));
} else {
if (access) {
doRegister(registerModelInput, listener);
return;
}
// if the user does not have access, we need to check three more conditions before throwing exception.
// if we are checking the access based on the name provided in the input, we let user know the name is already used by a
// model group they do not have access to.
if (isModelNameAlreadyExisting) {
// This case handles when user is using the same pre-trained model already registered by another user on the cluster.
// The only way here is for the user to first create model group and use its ID in the request
if (registerModelInput.getUrl() == null
&& registerModelInput.getFunctionName() != FunctionName.REMOTE
&& registerModelInput.getConnectorId() == null) {
listener
.onFailure(
new IllegalArgumentException(
"Without a model group ID, the system will use the model name {"
+ registerModelInput.getModelName()
+ "} to create a new model group. However, this name is taken by another group with id {"
+ registerModelInput.getModelGroupId()
+ "} you can't access. To register this pre-trained model, create a new model group and use its ID in your request."
)
);
} else {
listener
.onFailure(
new IllegalArgumentException(
"The name {"
+ registerModelInput.getModelName()
+ "} you provided is unavailable because it is used by another model group with id {"
+ registerModelInput.getModelGroupId()
+ "} to which you do not have access. Please provide a different name."
)
);
}
return;
}
// if user does not have access to the model group ID provided in the input, we let user know they do not have access to the
// specified model group
listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model."));
}, listener::onFailure));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,32 +63,74 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegi
MLRegisterModelMetaRequest registerModelMetaRequest = MLRegisterModelMetaRequest.fromActionRequest(request);
MLRegisterModelMetaInput mlUploadInput = registerModelMetaRequest.getMlRegisterModelMetaInput();

User user = RestActionUtils.getUserContext(client);
if (StringUtils.isEmpty(mlUploadInput.getModelGroupId())) {
mlModelGroupManager.validateUniqueModelGroupName(mlUploadInput.getName(), ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId();
mlUploadInput.setModelGroupId(modelGroupIdOfTheNameProvided);
checkUserAccess(mlUploadInput, listener, true);
} else {
createModelGroup(mlUploadInput, listener);
}
}, e -> {
log.error("Failed to search model group index", e);
listener.onFailure(e);
}));
} else {
checkUserAccess(mlUploadInput, listener, false);
}
}

private void checkUserAccess(
MLRegisterModelMetaInput mlUploadInput,
ActionListener<MLRegisterModelMetaResponse> listener,
Boolean isModelNameAlreadyExisting
) {

User user = RestActionUtils.getUserContext(client);
modelAccessControlHelper.validateModelGroupAccess(user, mlUploadInput.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
if (access) {
createModelGroup(mlUploadInput, listener);
return;
}
if (isModelNameAlreadyExisting) {
listener
.onFailure(
new IllegalArgumentException(
"The name {"
+ mlUploadInput.getName()
+ "} you provided is unavailable because it is used by another model group with id {"
+ mlUploadInput.getModelGroupId()
+ "} to which you do not have access. Please provide a different name."
)
);
} else {
log.error("You don't have permissions to perform this operation on this model.");
listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model."));
} else {
if (StringUtils.isEmpty(mlUploadInput.getModelGroupId())) {
MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(mlUploadInput);
mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> {
mlUploadInput.setModelGroupId(modelGroupId);
registerModelMeta(mlUploadInput, listener);
}, e -> {
logException("Failed to create Model Group", e, log);
listener.onFailure(e);
}));
} else {
registerModelMeta(mlUploadInput, listener);
}
}
}, e -> {
logException("Failed to validate model access", e, log);
listener.onFailure(e);
}));
}

private void createModelGroup(MLRegisterModelMetaInput mlUploadInput, ActionListener<MLRegisterModelMetaResponse> listener) {
if (StringUtils.isEmpty(mlUploadInput.getModelGroupId())) {
MLRegisterModelGroupInput mlRegisterModelGroupInput = createRegisterModelGroupRequest(mlUploadInput);
mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, ActionListener.wrap(modelGroupId -> {
mlUploadInput.setModelGroupId(modelGroupId);
registerModelMeta(mlUploadInput, listener);
}, e -> {
logException("Failed to create Model Group", e, log);
listener.onFailure(e);
}));
} else {
registerModelMeta(mlUploadInput, listener);
}
}

private MLRegisterModelGroupInput createRegisterModelGroupRequest(MLRegisterModelMetaInput mlUploadInput) {
return MLRegisterModelGroupInput
.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
Expand All @@ -30,6 +32,7 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
Expand Down Expand Up @@ -61,6 +64,9 @@
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -144,7 +150,7 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase {
private ConnectorAccessControlHelper connectorAccessControlHelper;

@Before
public void setup() {
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
settings = Settings
.builder()
Expand Down Expand Up @@ -199,6 +205,13 @@ public void setup() {
return null;
}).when(mlTaskDispatcher).dispatch(any(), any());

SearchResponse searchResponse = createModelGroupSearchResponse(0);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());

when(clusterService.localNode()).thenReturn(node2);
when(node2.getId()).thenReturn("node2Id");

Expand Down Expand Up @@ -461,6 +474,97 @@ public void test_execute_registerRemoteModel_withInternalConnector_predictEndpoi
);
}

public void test_ModelNameAlreadyExists() throws IOException {
when(node1.getId()).thenReturn("NodeId1");
when(node2.getId()).thenReturn("NodeId2");
MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class);
doAnswer(invocation -> {
ActionListenerResponseHandler<MLForwardResponse> handler = invocation.getArgument(3);
handler.handleResponse(forwardResponse);
return null;
}).when(transportService).sendRequest(any(), any(), any(), any());
SearchResponse searchResponse = createModelGroupSearchResponse(1);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());

transportRegisterModelAction.doExecute(task, prepareRequest("http://test_url", null), actionListener);
ArgumentCaptor<MLRegisterModelResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class);
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException {
SearchResponse searchResponse = createModelGroupSearchResponse(1);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(3);
listener.onResponse(false);
return null;
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());

MLRegisterModelInput registerModelInput = MLRegisterModelInput
.builder()
.modelName("huggingface/sentence-transformers/all-MiniLM-L12-v2")
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.version("1")
.build();

transportRegisterModelAction.doExecute(task, new MLRegisterModelRequest(registerModelInput), actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"Without a model group ID, the system will use the model name {huggingface/sentence-transformers/all-MiniLM-L12-v2} to create a new model group. However, this name is taken by another group with id {model_group_ID} you can't access. To register this pre-trained model, create a new model group and use its ID in your request.",
argumentCaptor.getValue().getMessage()

);
}

public void test_FailureWhenSearchingModelGroupName() throws IOException {
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException("Runtime exception"));
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());

transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener);

ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Runtime exception", argumentCaptor.getValue().getMessage());
}

public void test_NoAccessWhenModelNameAlreadyExists() throws IOException {

SearchResponse searchResponse = createModelGroupSearchResponse(1);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(3);
listener.onResponse(false);
return null;
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());

transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener);

ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"The name {Test Model} you provided is unavailable because it is used by another model group with id {model_group_ID} to which you do not have access. Please provide a different name.",
argumentCaptor.getValue().getMessage()
);
}

private MLRegisterModelRequest prepareRequest(String url, String modelGroupID) {
MLRegisterModelInput registerModelInput = MLRegisterModelInput
.builder()
Expand All @@ -485,4 +589,22 @@ private MLRegisterModelRequest prepareRequest(String url, String modelGroupID) {
return new MLRegisterModelRequest(registerModelInput);
}

private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException {

SearchResponse searchResponse = mock(SearchResponse.class);
String modelContent = "{\n"
+ " \"created_time\": 1684981986069,\n"
+ " \"access\": \"public\",\n"
+ " \"latest_version\": 0,\n"
+ " \"last_updated_time\": 1684981986069,\n"
+ " \"_id\": \"model_group_ID\",\n"
+ " \"name\": \"Test Model\",\n"
+ " \"description\": \"This is an example description\"\n"
+ " }";
SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent));
SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN);
when(searchResponse.getHits()).thenReturn(hits);
return searchResponse;
}

}
Loading
Loading