diff --git a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java index 4fd0e3850a..070c13b204 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java @@ -40,7 +40,6 @@ public class MLModelGroup implements ToXContentObject { @Setter private String name; private String description; - @Setter private int latestVersion; private List backendRoles; private User owner; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java index eaa1474709..6bafe81692 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -18,6 +18,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; @@ -38,12 +39,17 @@ public class MLUpdateModelInputTest { private final String expectedInputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; + private final String expectedInputStrWithNullField = "{\"model_id\":\"test-model_id\",\"name\":null,\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; private final String expectedOutputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; private final String expectedInputStrWithIllegalField = "{\"model_id\":\"test-model_id\",\"description\":\"description\",\"model_version\":\"2\",\"name\":\"name\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\",\"illegal_field\":\"This field need to be skipped.\"}"; + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); @Before public void setUp() throws Exception { @@ -109,6 +115,18 @@ public void parse_Success() throws Exception { }); } + @Test + public void parse_WithNullFieldWithoutModel() throws Exception { + exceptionRule.expect(IllegalStateException.class); + testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { + try { + assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + @Test public void parse_WithIllegalFieldWithoutModel() throws Exception { testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java index b03e6028fa..7b953c341a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java @@ -10,6 +10,7 @@ import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; @@ -19,11 +20,11 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.connector.Connector; -import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse; @@ -79,7 +80,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { log.error("Failed to get connector index", e); - actionListener.onFailure(new IllegalArgumentException("Fail to find connector")); + actionListener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); } else { log.error("Failed to get ML connector " + connectorId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java index 86e4afe56f..066ca5f8a7 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java @@ -13,6 +13,7 @@ import java.util.Arrays; import java.util.List; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.search.SearchRequest; @@ -27,13 +28,13 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; import org.opensearch.ml.engine.MLEngine; @@ -136,10 +137,11 @@ private void updateUndeployedConnector( } listener .onFailure( - new MLValidationException( + new OpenSearchStatusException( searchHits.length + " models are still using this connector, please undeploy the models first: " - + Arrays.toString(modelIds.toArray(new String[0])) + + Arrays.toString(modelIds.toArray(new String[0])), + RestStatus.BAD_REQUEST ) ); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java index ea4116c365..e1583abb44 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -5,16 +5,23 @@ package org.opensearch.ml.action.models; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.FunctionName.REMOTE; import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING; import java.io.IOException; +import java.time.Instant; +import java.util.Map; +import java.util.Objects; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; @@ -24,11 +31,11 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; -import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.model.MLUpdateModelAction; @@ -83,35 +90,32 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); MLModelState mlModelState = mlModel.getModelState(); if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { modelAccessControlHelper .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { if (hasPermission) { - if (!mlModelState.equals(MLModelState.LOADED) - && !mlModelState.equals(MLModelState.LOADING) - && !mlModelState.equals(MLModelState.PARTIALLY_LOADED) - && !mlModelState.equals(MLModelState.DEPLOYED) - && !mlModelState.equals(MLModelState.DEPLOYING) - && !mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED)) { - updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, actionListener); + if (isModelDeployed(mlModelState)) { + updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, actionListener, context); } else { actionListener .onFailure( - new MLValidationException( + new OpenSearchStatusException( "ML Model " + modelId - + " is in deploying or deployed state, please undeploy the models first!" + + " is in deploying or deployed state, please undeploy the models first!", + RestStatus.FORBIDDEN ) ); } } else { actionListener .onFailure( - new MLValidationException( - "User doesn't have privilege to perform this operation on this model, model ID " + modelId + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model, model ID " + modelId, + RestStatus.FORBIDDEN ) ); } @@ -130,8 +134,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener actionListener - .onFailure(new MLResourceNotFoundException("Failed to find model to update with the provided model id: " + modelId)) - ), () -> context.restore())); + .onFailure( + new OpenSearchStatusException( + "Failed to find model to update with the provided model id: " + modelId, + RestStatus.NOT_FOUND + ) + ) + )); } catch (Exception e) { log.error("Failed to update ML model for " + modelId, e); actionListener.onFailure(e); @@ -143,22 +152,29 @@ private void updateRemoteOrTextEmbeddingModel( MLUpdateModelInput updateModelInput, MLModel mlModel, User user, - ActionListener actionListener + ActionListener actionListener, + ThreadContext.StoredContext context ) { - String newModelGroupId = Strings.hasLength(updateModelInput.getModelGroupId()) ? updateModelInput.getModelGroupId() : null; + String newModelGroupId = (Strings.hasLength(updateModelInput.getModelGroupId()) + && !Objects.equals(updateModelInput.getModelGroupId(), mlModel.getModelGroupId())) ? updateModelInput.getModelGroupId() : null; String relinkConnectorId = Strings.hasLength(updateModelInput.getConnectorId()) ? updateModelInput.getConnectorId() : null; if (mlModel.getAlgorithm() == TEXT_EMBEDDING) { if (relinkConnectorId == null) { - updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateModelInput, actionListener); + updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateModelInput, actionListener, context); } else { actionListener - .onFailure(new IllegalArgumentException("Trying to update the connector or connector_id field on a local model")); + .onFailure( + new OpenSearchStatusException( + "Trying to update the connector or connector_id field on a local model", + RestStatus.BAD_REQUEST + ) + ); } } else { // mlModel.getAlgorithm() == REMOTE if (relinkConnectorId == null) { - updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateModelInput, actionListener); + updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateModelInput, actionListener, context); } else { updateModelWithRelinkStandAloneConnector( modelId, @@ -167,7 +183,8 @@ private void updateRemoteOrTextEmbeddingModel( mlModel, user, updateModelInput, - actionListener + actionListener, + context ); } } @@ -180,18 +197,27 @@ private void updateModelWithRelinkStandAloneConnector( MLModel mlModel, User user, MLUpdateModelInput updateModelInput, - ActionListener actionListener + ActionListener actionListener, + ThreadContext.StoredContext context ) { if (Strings.hasLength(mlModel.getConnectorId())) { connectorAccessControlHelper .validateConnectorAccess(client, relinkConnectorId, ActionListener.wrap(hasRelinkConnectorPermission -> { if (hasRelinkConnectorPermission) { - updateModelWithRegisteringToAnotherModelGroup(modelId, newModelGroupId, user, updateModelInput, actionListener); + updateModelWithRegisteringToAnotherModelGroup( + modelId, + newModelGroupId, + user, + updateModelInput, + actionListener, + context + ); } else { actionListener .onFailure( - new MLValidationException( - "You don't have permission to update the connector, connector id: " + relinkConnectorId + new OpenSearchStatusException( + "You don't have permission to update the connector, connector id: " + relinkConnectorId, + RestStatus.FORBIDDEN ) ); } @@ -202,7 +228,10 @@ private void updateModelWithRelinkStandAloneConnector( } else { actionListener .onFailure( - new IllegalArgumentException("This remote does not have a connector_id field, maybe it uses an internal connector.") + new OpenSearchStatusException( + "This remote does not have a connector_id field, maybe it uses an internal connector.", + RestStatus.BAD_REQUEST + ) ); } } @@ -212,29 +241,40 @@ private void updateModelWithRegisteringToAnotherModelGroup( String newModelGroupId, User user, MLUpdateModelInput updateModelInput, - ActionListener actionListener + ActionListener actionListener, + ThreadContext.StoredContext context ) { UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId); if (newModelGroupId != null) { modelAccessControlHelper.validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasRelinkPermission -> { if (hasRelinkPermission) { - mlModelGroupManager.getModelGroup(newModelGroupId, ActionListener.wrap(newModelGroup -> { - updateRequestConstructor(modelId, updateRequest, updateModelInput, newModelGroup, actionListener); + mlModelGroupManager.getModelGroupResponse(newModelGroupId, ActionListener.wrap(newModelGroupResponse -> { + updateRequestConstructor( + modelId, + newModelGroupId, + updateRequest, + updateModelInput, + newModelGroupResponse, + actionListener, + context + ); }, exception -> actionListener .onFailure( - new MLResourceNotFoundException( + new OpenSearchStatusException( "Failed to find the model group with the provided model group id in the update model input, MODEL_GROUP_ID: " - + newModelGroupId + + newModelGroupId, + RestStatus.NOT_FOUND ) ) )); } else { actionListener .onFailure( - new MLValidationException( + new OpenSearchStatusException( "User Doesn't have privilege to re-link this model to the target model group due to no access to the target model group with model group ID " - + newModelGroupId + + newModelGroupId, + RestStatus.FORBIDDEN ) ); } @@ -243,7 +283,7 @@ private void updateModelWithRegisteringToAnotherModelGroup( actionListener.onFailure(exception); })); } else { - updateRequestConstructor(modelId, updateRequest, updateModelInput, actionListener); + updateRequestConstructor(modelId, updateRequest, updateModelInput, actionListener, context); } } @@ -251,12 +291,13 @@ private void updateRequestConstructor( String modelId, UpdateRequest updateRequest, MLUpdateModelInput updateModelInput, - ActionListener actionListener + ActionListener actionListener, + ThreadContext.StoredContext context ) { try { updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); updateRequest.docAsUpsert(true); - client.update(updateRequest, getUpdateResponseListener(modelId, actionListener)); + client.update(updateRequest, getUpdateResponseListener(modelId, actionListener, context)); } catch (IOException e) { log.error("Failed to build update request."); actionListener.onFailure(e); @@ -265,60 +306,94 @@ private void updateRequestConstructor( private void updateRequestConstructor( String modelId, + String newModelGroupId, UpdateRequest updateRequest, MLUpdateModelInput updateModelInput, - MLModelGroup newModelGroup, - ActionListener actionListener + GetResponse newModelGroupResponse, + ActionListener actionListener, + ThreadContext.StoredContext context ) { - String updatedVersion = incrementLatestVersion(newModelGroup); + Map newModelGroupSourceMap = newModelGroupResponse.getSourceAsMap(); + String updatedVersion = incrementLatestVersion(newModelGroupSourceMap); updateModelInput.setVersion(updatedVersion); + UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest( + newModelGroupSourceMap, + newModelGroupId, + newModelGroupResponse.getSeqNo(), + newModelGroupResponse.getPrimaryTerm(), + Integer.parseInt(updatedVersion) + ); try { updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); updateRequest.docAsUpsert(true); - client.update(updateRequest, getUpdateResponseListener(modelId, newModelGroup, updatedVersion, actionListener)); + client.update(updateModelGroupRequest, ActionListener.wrap(r -> { + client.update(updateRequest, getUpdateResponseListener(modelId, actionListener, context)); + }, e -> { + log + .error( + "Failed to register ML model with model ID {} to the new model group with model group ID {}", + modelId, + newModelGroupId + ); + actionListener.onFailure(e); + })); } catch (IOException e) { log.error("Failed to build update request."); actionListener.onFailure(e); } } - private ActionListener getUpdateResponseListener(String modelId, ActionListener actionListener) { - return ActionListener.wrap(updateResponse -> { - if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { - log.info("Model id:{} failed update", modelId); - actionListener.onResponse(updateResponse); - return; - } - log.info("Completed Update Model Request, model id:{} updated", modelId); - actionListener.onResponse(updateResponse); - }, exception -> { - log.error("Failed to update ML model: " + modelId, exception); - actionListener.onFailure(exception); - }); - } - private ActionListener getUpdateResponseListener( String modelId, - MLModelGroup newModelGroup, - String updatedVersion, - ActionListener actionListener + ActionListener actionListener, + ThreadContext.StoredContext context ) { - return ActionListener.wrap(updateResponse -> { + return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { log.info("Model id:{} failed update", modelId); actionListener.onResponse(updateResponse); return; } - log.info("Completed Update Model Request, model id:{} updated", modelId); - newModelGroup.setLatestVersion(Integer.parseInt(updatedVersion)); + log.info("Successfully update ML model with model ID {}", modelId); actionListener.onResponse(updateResponse); }, exception -> { log.error("Failed to update ML model: " + modelId, exception); actionListener.onFailure(exception); - }); + }), context::restore); + } + + private String incrementLatestVersion(Map modelGroupSourceMap) { + return Integer.toString((int) modelGroupSourceMap.get(MLModelGroup.LATEST_VERSION_FIELD) + 1); + } + + private UpdateRequest createUpdateModelGroupRequest( + Map modelGroupSourceMap, + String modelGroupId, + long seqNo, + long primaryTerm, + int updatedVersion + ) { + modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion); + modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); + UpdateRequest updateModelGroupRequest = new UpdateRequest(); + + updateModelGroupRequest + .index(ML_MODEL_GROUP_INDEX) + .id(modelGroupId) + .setIfSeqNo(seqNo) + .setIfPrimaryTerm(primaryTerm) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .doc(modelGroupSourceMap); + + return updateModelGroupRequest; } - private String incrementLatestVersion(MLModelGroup mlModelGroup) { - return Integer.toString(mlModelGroup.getLatestVersion() + 1); + private Boolean isModelDeployed(MLModelState mlModelState) { + return !mlModelState.equals(MLModelState.LOADED) + && !mlModelState.equals(MLModelState.LOADING) + && !mlModelState.equals(MLModelState.PARTIALLY_LOADED) + && !mlModelState.equals(MLModelState.DEPLOYED) + && !mlModelState.equals(MLModelState.DEPLOYING) + && !mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED); } } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java index b2c912b9d5..b1096e7e38 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java @@ -11,6 +11,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -19,6 +20,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.BoolQueryBuilder; @@ -30,7 +32,6 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.connector.AbstractConnector; import org.opensearch.ml.common.connector.Connector; -import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.search.builder.SearchSourceBuilder; @@ -100,11 +101,11 @@ public void getConnector(Client client, String connectorId, ActionListener { - log.error("Fail to get connector", e); - listener.onFailure(new IllegalStateException("Fail to get connector:" + connectorId)); + log.error("Failed to get connector", e); + listener.onFailure(new OpenSearchStatusException("Failed to get connector:" + connectorId, RestStatus.NOT_FOUND)); })); } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index 9169595602..0679f2d855 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -5,7 +5,6 @@ package org.opensearch.ml.model; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import java.time.Instant; @@ -13,6 +12,7 @@ import java.util.Iterator; import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -25,10 +25,8 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; -import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; @@ -38,7 +36,6 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; -import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.search.SearchHit; @@ -220,22 +217,12 @@ public void validateUniqueModelGroupName(String name, ActionListener listener) { + public void getModelGroupResponse(String modelGroupId, ActionListener listener) { GetRequest getRequest = new GetRequest(); getRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupId); client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { - try ( - XContentParser parser = MLNodeUtils - .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLModelGroup mlModelGroup = MLModelGroup.parse(parser); - listener.onResponse(mlModelGroup); - } catch (Exception e) { - log.error("Failed to parse ml model group.", e); - listener.onFailure(e); - } + listener.onResponse(r); } else { listener.onFailure(new MLResourceNotFoundException("Failed to find model group with ID: " + modelGroupId)); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index ea6dffbe46..03344e0eea 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -763,7 +763,6 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES, MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES, MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED, - MLCommonsSettings.ML_COMMONS_UPDATE_CONNECTOR_ENABLED, MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED, MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED ); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java index 3380d79dbc..6779238771 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java @@ -8,7 +8,6 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; -import static org.opensearch.ml.utils.MLExceptionUtils.UPDATE_CONNECTOR_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; @@ -58,9 +57,7 @@ private MLUpdateConnectorRequest getRequest(RestRequest request) throws IOExcept if (!mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); } - if (!mlFeatureEnabledSetting.isUpdateConnectorEnabled()) { - throw new IllegalStateException(UPDATE_CONNECTOR_DISABLED_ERR_MSG); - } + if (!request.hasContent()) { throw new IOException("Failed to update connector: Request body is empty"); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java index f653da5bb0..12e5e8c673 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java @@ -14,6 +14,7 @@ import java.util.List; import java.util.Locale; +import org.opensearch.OpenSearchParseException; import org.opensearch.client.node.NodeClient; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.transport.model.MLUpdateModelAction; @@ -54,17 +55,21 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client */ private MLUpdateModelRequest getRequest(RestRequest request) throws IOException { if (!request.hasContent()) { - throw new IOException("Model update request has empty body"); + throw new OpenSearchParseException("Model update request has empty body"); } String modelId = getParameterId(request, PARAMETER_MODEL_ID); XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLUpdateModelInput input = MLUpdateModelInput.parse(parser); - // Model ID can only be set here. Model version can only be set automatically. - input.setModelId(modelId); - input.setVersion(null); - return new MLUpdateModelRequest(input); + try { + MLUpdateModelInput input = MLUpdateModelInput.parse(parser); + // Model ID can only be set here. Model version can only be set automatically. + input.setModelId(modelId); + input.setVersion(null); + return new MLUpdateModelRequest(input); + } catch (IllegalStateException e) { + throw new OpenSearchParseException(e.getMessage()); + } } } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index a9d0f646ac..ee4962dadd 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -113,9 +113,6 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_REMOTE_INFERENCE_ENABLED = Setting .boolSetting("plugins.ml_commons.remote_inference.enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); - public static final Setting ML_COMMONS_UPDATE_CONNECTOR_ENABLED = Setting - .boolSetting("plugins.ml_commons.update_connector.enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); - public static final Setting ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED = Setting .boolSetting("plugins.ml_commons.model_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java index 8f231a061e..0a1a00ac4d 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java @@ -8,7 +8,6 @@ package org.opensearch.ml.settings; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_UPDATE_CONNECTOR_ENABLED; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -16,18 +15,13 @@ public class MLFeatureEnabledSetting { private volatile Boolean isRemoteInferenceEnabled; - private volatile Boolean isUpdateConnectorEnabled; public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) { isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings); - isUpdateConnectorEnabled = ML_COMMONS_UPDATE_CONNECTOR_ENABLED.get(settings); clusterService .getClusterSettings() .addSettingsUpdateConsumer(ML_COMMONS_REMOTE_INFERENCE_ENABLED, it -> isRemoteInferenceEnabled = it); - clusterService - .getClusterSettings() - .addSettingsUpdateConsumer(ML_COMMONS_UPDATE_CONNECTOR_ENABLED, it -> isUpdateConnectorEnabled = it); } /** @@ -38,8 +32,4 @@ public boolean isRemoteInferenceEnabled() { return isRemoteInferenceEnabled; } - public boolean isUpdateConnectorEnabled() { - return isUpdateConnectorEnabled; - } - } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java index 6f051615e3..da42d95382 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java @@ -20,8 +20,6 @@ public class MLExceptionUtils { public static final String NOT_SERIALIZABLE_EXCEPTION_WRAPPER = "NotSerializableExceptionWrapper: "; public static final String REMOTE_INFERENCE_DISABLED_ERR_MSG = "Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true."; - public static final String UPDATE_CONNECTOR_DISABLED_ERR_MSG = - "Update connector API is currently disabled. To enable it, update the setting \"plugins.ml_commons.update_connector.enabled\" to true."; public static String getRootCauseMessage(final Throwable throwable) { String message = ExceptionUtils.getRootCauseMessage(throwable); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java index a7fb34a4b5..c2cbf81cf1 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java @@ -150,7 +150,7 @@ public void testGetConnector_IndexNotFoundException() { getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Fail to find connector", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to find connector", argumentCaptor.getValue().getMessage()); } public void testGetConnector_RuntimeException() { diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java similarity index 99% rename from plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java rename to plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java index f2ded6db10..df630b90b0 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java @@ -62,7 +62,7 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -public class TransportUpdateConnectorActionTests extends OpenSearchTestCase { +public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { private UpdateConnectorTransportAction transportUpdateConnectorAction; diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index 11fcc75b10..85dfaa552a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -16,6 +16,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Arrays; import org.junit.Before; import org.junit.Rule; @@ -24,22 +25,30 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.get.GetResult; +import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.exception.MLResourceNotFoundException; -import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.model.MLUpdateModelInput; import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; @@ -80,9 +89,6 @@ public class UpdateModelTransportActionTests extends OpenSearchTestCase { @Mock MLModel mockModel; - @Mock - MLModelGroup mockModelGroup; - @Mock MLModelManager mlModelManager; @@ -114,6 +120,8 @@ public class UpdateModelTransportActionTests extends OpenSearchTestCase { MLModel mlModelWithNullFunctionName; + MLModel localModel; + ThreadContext threadContext; @Before @@ -161,7 +169,7 @@ public void setup() throws IOException { ) ); - MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); @@ -202,11 +210,24 @@ public void setup() throws IOException { return null; }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + MLModelGroup modelGroup = MLModelGroup + .builder() + .modelGroupId("updated_test_model_group_id") + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); + + GetResponse getResponse = prepareGetResponse(modelGroup); + doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(mockModelGroup); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); return null; - }).when(mlModelGroupManager).getModelGroup(eq("updated_test_model_group_id"), isA(ActionListener.class)); + }).when(mlModelGroupManager).getModelGroupResponse(eq("updated_test_model_group_id"), isA(ActionListener.class)); } @Test @@ -231,7 +252,7 @@ public void testUpdateModelStateLoadedException() { doReturn(MLModelState.LOADED).when(mockModel).getModelState(); transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( "ML Model mockId is in deploying or deployed state, please undeploy the models first!", @@ -255,7 +276,7 @@ public void testUpdateModelStateLoadingException() { doReturn(MLModelState.LOADING).when(mockModel).getModelState(); transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( "ML Model mockId is in deploying or deployed state, please undeploy the models first!", @@ -279,7 +300,7 @@ public void testUpdateModelStatePartiallyLoadedException() { doReturn(MLModelState.PARTIALLY_LOADED).when(mockModel).getModelState(); transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( "ML Model mockId is in deploying or deployed state, please undeploy the models first!", @@ -303,7 +324,7 @@ public void testUpdateModelStateDeployedException() { doReturn(MLModelState.DEPLOYED).when(mockModel).getModelState(); transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( "ML Model mockId is in deploying or deployed state, please undeploy the models first!", @@ -327,7 +348,7 @@ public void testUpdateModelStateDeployingException() { doReturn(MLModelState.DEPLOYING).when(mockModel).getModelState(); transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( "ML Model mockId is in deploying or deployed state, please undeploy the models first!", @@ -351,7 +372,7 @@ public void testUpdateModelStatePartiallyDeployedException() { doReturn(MLModelState.PARTIALLY_DEPLOYED).when(mockModel).getModelState(); transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( "ML Model mockId is in deploying or deployed state, please undeploy the models first!", @@ -549,7 +570,7 @@ public void testUpdateModelWithRegisterToNewModelGroupNotFound() { ActionListener listener = invocation.getArgument(1); listener.onFailure(new MLResourceNotFoundException("Model group not found with MODEL_GROUP_ID: updated_test_model_group_id")); return null; - }).when(mlModelGroupManager).getModelGroup(eq("updated_test_model_group_id"), isA(ActionListener.class)); + }).when(mlModelGroupManager).getModelGroupResponse(eq("updated_test_model_group_id"), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -658,11 +679,24 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("mockUpdateModelGroupId"), any(), isA(ActionListener.class)); + MLModelGroup modelGroup = MLModelGroup + .builder() + .modelGroupId("updated_test_model_group_id") + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); + + GetResponse getResponse = prepareGetResponse(modelGroup); + doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(mockModelGroup); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); return null; - }).when(mlModelGroupManager).getModelGroup(eq("mockUpdateModelGroupId"), isA(ActionListener.class)); + }).when(mlModelGroupManager).getModelGroupResponse(eq("mockUpdateModelGroupId"), isA(ActionListener.class)); doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any()); transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); @@ -741,6 +775,8 @@ public void testGetUpdateResponseListenerOtherException() { ); } + // TODO: Add UT to make sure that version incremented successfully. + private MLModel prepareMLModel(FunctionName functionName) throws IllegalArgumentException { MLModel mlModel; switch (functionName) { @@ -801,4 +837,11 @@ private MLModel prepareUnsupportedMLModel(FunctionName unsupportedCase) throws I throw new IllegalArgumentException("Please choose from FunctionName.REMOTE and FunctionName.KMEANS"); } } + + private GetResponse prepareGetResponse(MLModelGroup mlModelGroup) throws IOException { + XContentBuilder content = mlModelGroup.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + return new GetResponse(getResult); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java index 8c3db0c054..108ab84b16 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java @@ -22,6 +22,7 @@ import org.junit.Before; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -206,7 +207,7 @@ public void test_validateConnectorAccess_connectorNotFound_return_false() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); - verify(actionListener, times(1)).onFailure(any(MLResourceNotFoundException.class)); + verify(actionListener, times(1)).onFailure(any(OpenSearchStatusException.class)); } public void test_validateConnectorAccess_searchConnectorException_return_false() { @@ -221,7 +222,7 @@ public void test_validateConnectorAccess_searchConnectorException_return_false() threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); - verify(actionListener).onFailure(any(IllegalStateException.class)); + verify(actionListener).onFailure(any(OpenSearchStatusException.class)); } public void test_skipConnectorAccessControl_userIsNull_return_true() { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index 4d40b8586a..ce01b44026 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -70,7 +70,7 @@ public class MLModelGroupManagerTests extends OpenSearchTestCase { private ActionListener actionListener; @Mock - private ActionListener modelGroupListener; + private ActionListener modelGroupListener; @Mock private IndexResponse indexResponse; @@ -353,9 +353,8 @@ public void test_SuccessGetModelGroup() throws IOException { return null; }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); - mlModelGroupManager.getModelGroup("testModelGroupID", modelGroupListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelGroup.class); - verify(modelGroupListener).onResponse(argumentCaptor.capture()); + mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); + verify(modelGroupListener).onResponse(getResponse); } public void test_OtherExceptionGetModelGroup() throws IOException { @@ -368,7 +367,7 @@ public void test_OtherExceptionGetModelGroup() throws IOException { return null; }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); - mlModelGroupManager.getModelGroup("testModelGroupID", modelGroupListener); + mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(modelGroupListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -384,7 +383,7 @@ public void test_NotFoundGetModelGroup() throws IOException { return null; }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); - mlModelGroupManager.getModelGroup("testModelGroupID", modelGroupListener); + mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(modelGroupListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to find model group with ID: testModelGroupID", argumentCaptor.getValue().getMessage()); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java index 3bc5a5e940..9be0d518ae 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java @@ -67,7 +67,6 @@ public void setup() { client = spy(new NodeClient(Settings.EMPTY, threadPool)); when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); - when(mlFeatureEnabledSetting.isUpdateConnectorEnabled()).thenReturn(true); restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java index e4511df9dc..28687d1c9c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java @@ -12,7 +12,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -24,6 +23,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; @@ -111,7 +111,7 @@ public void testUpdateModelRequest() throws Exception { @Test public void testUpdateModelRequestWithEmptyContent() throws Exception { - exceptionRule.expect(IOException.class); + exceptionRule.expect(OpenSearchParseException.class); exceptionRule.expectMessage("Model update request has empty body"); RestRequest request = getRestRequestWithEmptyContent(); restMLUpdateModelAction.handleRequest(request, channel, client); @@ -125,6 +125,14 @@ public void testUpdateModelRequestWithNullModelId() throws Exception { restMLUpdateModelAction.handleRequest(request, channel, client); } + @Test + public void testUpdateModelRequestWithNullField() throws Exception { + exceptionRule.expect(OpenSearchParseException.class); + exceptionRule.expectMessage("Can't get text on a VALUE_NULL"); + RestRequest request = getRestRequestWithNullField(); + restMLUpdateModelAction.handleRequest(request, channel, client); + } + private RestRequest getRestRequest() { RestRequest.Method method = RestRequest.Method.PUT; final Map modelContent = Map.of("name", "testModelName", "description", "This is test description"); @@ -166,4 +174,18 @@ private RestRequest getRestRequestWithNullModelId() { .build(); return request; } + + private RestRequest getRestRequestWithNullField() { + RestRequest.Method method = RestRequest.Method.PUT; + String requestContent = "{\"name\":\"testModelName\",\"description\":null}"; + Map params = new HashMap<>(); + params.put("model_id", "test_modelId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } }