From 78af78a685b23439276ce85c0deaa2e2a0eb014d Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Fri, 3 Nov 2023 18:25:40 +0000 Subject: [PATCH 1/5] Manual Backport of #1388 Signed-off-by: Joshua Palis --- .../ml/client/MachineLearningClient.java | 41 ++++++++++ .../ml/client/MachineLearningNodeClient.java | 23 ++++++ .../ml/client/MachineLearningClientTest.java | 50 ++++++++++++ .../client/MachineLearningNodeClientTest.java | 80 +++++++++++++++++++ 4 files changed, 194 insertions(+) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index f13a3e4f7b..39de4510d7 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -17,6 +17,9 @@ import org.opensearch.ml.common.ToolMetadata; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import java.util.List; import java.util.Map; @@ -229,6 +232,44 @@ default ActionFuture searchTask(SearchRequest searchRequest) { */ void searchTask(SearchRequest searchRequest, ActionListener listener); + /** + * Register model + * For additional info on register, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#registering-a-model + * @param mlInput ML input + */ + default ActionFuture register(MLRegisterModelInput mlInput) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + register(mlInput, actionFuture); + return actionFuture; + } + + /** + * Register model + * For additional info on register, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#registering-a-model + * @param mlInput ML input + * @param listener a listener to be notified of the result + */ + void register(MLRegisterModelInput mlInput, ActionListener listener); + + /** + * Deploy model + * For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model + * @param modelId the model id + */ + default ActionFuture deploy(String modelId) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + deploy(modelId, actionFuture); + return actionFuture; + } + + /** + * Deploy model + * For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model + * @param modelId the model id + * @param listener a listener to be notified of the result + */ + void deploy(String modelId, ActionListener listener); + /** * Get a list of ToolMetadata and return ActionFuture. * For more info on list tools, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#list-tools diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index d594ef03d2..69147ee90a 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -22,6 +22,9 @@ import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; +import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; +import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.common.transport.model.MLModelGetAction; @@ -30,6 +33,10 @@ import org.opensearch.ml.common.transport.model.MLModelSearchAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.transport.register.MLRegisterModelAction; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.common.transport.task.MLTaskDeleteAction; import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest; import org.opensearch.ml.common.transport.task.MLTaskGetAction; @@ -197,6 +204,22 @@ public void searchTask(SearchRequest searchRequest, ActionListener listener) { + MLRegisterModelRequest registerRequest = new MLRegisterModelRequest(mlInput); + client.execute(MLRegisterModelAction.INSTANCE, registerRequest, ActionListener.wrap(listener::onResponse, e -> { + listener.onFailure(e); + })); + } + + @Override + public void deploy(String modelId, ActionListener listener) { + MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, false); + client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, ActionListener.wrap(listener::onResponse, e -> { + listener.onFailure(e); + })); + } + @Override public void listTools(ActionListener> listener) { MLToolsListRequest mlToolsListRequest = MLToolsListRequest.builder().build(); diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index cc640ac22d..0facc316cc 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -21,10 +21,16 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.output.MLTrainingOutput; +import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import java.util.HashMap; import java.util.List; @@ -60,6 +66,13 @@ public class MachineLearningClientTest { @Mock SearchResponse searchResponse; + @Mock + MLRegisterModelResponse registerModelResponse; + + @Mock + MLDeployModelResponse deployModelResponse; + + private String modekId = "test_model_id"; private MLModel mlModel; private MLTask mlTask; @@ -135,6 +148,16 @@ public void searchTask(SearchRequest searchRequest, ActionListener listener) { + listener.onResponse(registerModelResponse); + } + + @Override + public void deploy(String modelId, ActionListener listener) { + listener.onResponse(deployModelResponse); + } + @Override public void listTools(ActionListener> listener) { listener.onResponse(null); @@ -263,4 +286,31 @@ public void deleteTask() { public void searchTask() { assertEquals(searchResponse, machineLearningClient.searchTask(new SearchRequest()).actionGet()); } + + @Test + public void register() { + MLModelConfig config = TextEmbeddingModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + MLRegisterModelInput mlInput = MLRegisterModelInput.builder() + .functionName(FunctionName.KMEANS) + .modelName("testModelName") + .version("testModelVersion") + .modelGroupId("modelGroupId") + .url("url") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[]{"modelNodeIds" }) + .build(); + assertEquals(registerModelResponse, machineLearningClient.register(mlInput).actionGet()); + } + + @Test + public void deploy() { + assertEquals(deployModelResponse, machineLearningClient.deploy("modelId").actionGet()); + } } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index 2f6f11998d..fc26c27ece 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -31,15 +31,22 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; +import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; +import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.common.transport.model.MLModelGetAction; @@ -48,6 +55,10 @@ import org.opensearch.ml.common.transport.model.MLModelSearchAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.transport.register.MLRegisterModelAction; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.ml.common.transport.task.MLTaskDeleteAction; import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest; import org.opensearch.ml.common.transport.task.MLTaskGetAction; @@ -120,6 +131,13 @@ public class MachineLearningNodeClientTest { @Mock ActionListener searchTaskActionListener; + @Mock + ActionListener RegisterModelActionListener; + + @Mock + ActionListener DeployModelActionListener; + + @InjectMocks MachineLearningNodeClient machineLearningNodeClient; @@ -565,6 +583,68 @@ public void searchTask() { assertEquals(modelId, source.get(MLTask.MODEL_ID_FIELD)); } + @Test + public void register() { + String taskId = "taskId"; + String status = MLTaskState.CREATED.name(); + FunctionName functionName = FunctionName.KMEANS; + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLRegisterModelAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); + MLModelConfig config = TextEmbeddingModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + MLRegisterModelInput mlInput = MLRegisterModelInput.builder() + .functionName(functionName) + .modelName("testModelName") + .version("testModelVersion") + .modelGroupId("modelGroupId") + .url("url") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[]{"modelNodeIds" }) + .build(); + machineLearningNodeClient.register(mlInput, RegisterModelActionListener); + + verify(client).execute(eq(MLRegisterModelAction.INSTANCE), isA(MLRegisterModelRequest.class), any()); + verify(RegisterModelActionListener).onResponse(argumentCaptor.capture()); + assertEquals(taskId, (argumentCaptor.getValue()).getTaskId()); + assertEquals(status, (argumentCaptor.getValue()).getStatus()); + } + + @Test + public void deploy() { + String taskId = "taskId"; + String status = MLTaskState.CREATED.name(); + MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; + String modelId = "modelId"; + FunctionName functionName = FunctionName.KMEANS; + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLDeployModelResponse.class); + machineLearningNodeClient.deploy(modelId, DeployModelActionListener); + + verify(client).execute(eq(MLDeployModelAction.INSTANCE), isA(MLDeployModelRequest.class), any()); + verify(DeployModelActionListener).onResponse(argumentCaptor.capture()); + assertEquals(taskId, (argumentCaptor.getValue()).getTaskId()); + assertEquals(status, (argumentCaptor.getValue()).getStatus()); + } + + private SearchResponse createSearchResponse(ToXContentObject o) throws IOException { XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); From 5cbec16d380b4dc14026ba6345beff8004cac078 Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Fri, 3 Nov 2023 18:39:16 +0000 Subject: [PATCH 2/5] Manual backport of #1437 Signed-off-by: Joshua Palis --- .../ml/client/MachineLearningClient.java | 15 +++++ .../ml/client/MachineLearningNodeClient.java | 10 +++ .../ml/client/MachineLearningClientTest.java | 33 ++++++++++ .../client/MachineLearningNodeClientTest.java | 64 +++++++++++++++++-- 4 files changed, 115 insertions(+), 7 deletions(-) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index 39de4510d7..e7d01c6b14 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -17,6 +17,8 @@ import org.opensearch.ml.common.ToolMetadata; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; @@ -270,6 +272,19 @@ default ActionFuture deploy(String modelId) { */ void deploy(String modelId, ActionListener listener); + /** + * Create connector for remote model + * @param mlCreateConnectorInput Create Connector Input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/extensibility/connectors/ + * @return the result future + */ + default ActionFuture createConnector(MLCreateConnectorInput mlCreateConnectorInput) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + createConnector(mlCreateConnectorInput, actionFuture); + return actionFuture; + } + + void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener listener); + /** * Get a list of ToolMetadata and return ActionFuture. * For more info on list tools, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#list-tools diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 69147ee90a..949684c078 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -22,6 +22,10 @@ import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -220,6 +224,12 @@ public void deploy(String modelId, ActionListener listene })); } + @Override + public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener listener) { + MLCreateConnectorRequest createConnectorRequest = new MLCreateConnectorRequest(mlCreateConnectorInput); + client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, listener); + } + @Override public void listTools(ActionListener> listener) { MLToolsListRequest mlToolsListRequest = MLToolsListRequest.builder().build(); diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 0facc316cc..34039e46f6 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -19,6 +19,7 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.model.MLModelConfig; @@ -28,6 +29,8 @@ import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.output.MLTrainingOutput; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; @@ -72,6 +75,9 @@ public class MachineLearningClientTest { @Mock MLDeployModelResponse deployModelResponse; + @Mock + MLCreateConnectorResponse createConnectorResponse; + private String modekId = "test_model_id"; private MLModel mlModel; @@ -158,6 +164,11 @@ public void deploy(String modelId, ActionListener listene listener.onResponse(deployModelResponse); } + @Override + public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener listener) { + listener.onResponse(createConnectorResponse); + } + @Override public void listTools(ActionListener> listener) { listener.onResponse(null); @@ -313,4 +324,26 @@ public void register() { public void deploy() { assertEquals(deployModelResponse, machineLearningClient.deploy("modelId").actionGet()); } + + @Test + public void createConnector() { + Map params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7")); + Map credentials = Map.ofEntries(Map.entry("key1", "key1"), Map.entry("key2", "key2")); + + MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput.builder() + .name("test") + .description("description") + .version("testModelVersion") + .protocol("testProtocol") + .parameters(params) + .credential(credentials) + .actions(null) + .backendRoles(null) + .addAllBackendRoles(false) + .access(AccessMode.from("private")) + .dryRun(false) + .build(); + + assertEquals(createConnectorResponse, machineLearningClient.createConnector(mlCreateConnectorInput).actionGet()); + } } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index fc26c27ece..8ae7183f41 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -27,6 +27,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; @@ -44,6 +45,10 @@ import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -76,8 +81,10 @@ import org.opensearch.search.suggest.Suggest; import java.io.IOException; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.junit.Assert.assertEquals; @@ -132,11 +139,13 @@ public class MachineLearningNodeClientTest { ActionListener searchTaskActionListener; @Mock - ActionListener RegisterModelActionListener; + ActionListener registerModelActionListener; @Mock - ActionListener DeployModelActionListener; + ActionListener deployModelActionListener; + @Mock + ActionListener createConnectorActionListener; @InjectMocks MachineLearningNodeClient machineLearningNodeClient; @@ -613,10 +622,10 @@ public void register() { .deployModel(true) .modelNodeIds(new String[]{"modelNodeIds" }) .build(); - machineLearningNodeClient.register(mlInput, RegisterModelActionListener); + machineLearningNodeClient.register(mlInput, registerModelActionListener); verify(client).execute(eq(MLRegisterModelAction.INSTANCE), isA(MLRegisterModelRequest.class), any()); - verify(RegisterModelActionListener).onResponse(argumentCaptor.capture()); + verify(registerModelActionListener).onResponse(argumentCaptor.capture()); assertEquals(taskId, (argumentCaptor.getValue()).getTaskId()); assertEquals(status, (argumentCaptor.getValue()).getStatus()); } @@ -627,7 +636,6 @@ public void deploy() { String status = MLTaskState.CREATED.name(); MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; String modelId = "modelId"; - FunctionName functionName = FunctionName.KMEANS; doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status); @@ -636,14 +644,56 @@ public void deploy() { }).when(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLDeployModelResponse.class); - machineLearningNodeClient.deploy(modelId, DeployModelActionListener); + machineLearningNodeClient.deploy(modelId, deployModelActionListener); verify(client).execute(eq(MLDeployModelAction.INSTANCE), isA(MLDeployModelRequest.class), any()); - verify(DeployModelActionListener).onResponse(argumentCaptor.capture()); + verify(deployModelActionListener).onResponse(argumentCaptor.capture()); assertEquals(taskId, (argumentCaptor.getValue()).getTaskId()); assertEquals(status, (argumentCaptor.getValue()).getStatus()); } + @Test + public void createConnector() { + + + String connectorId = "connectorId"; + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLCreateConnectorResponse output = new MLCreateConnectorResponse(connectorId); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLCreateConnectorAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLCreateConnectorResponse.class); + + Map params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7")); + Map credentials = Map.ofEntries(Map.entry("key1", "value1"), Map.entry("key2", "value2")); + List backendRoles = Arrays.asList("IT", "HR"); + + MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput.builder() + .name("test") + .description("description") + .version("testModelVersion") + .protocol("testProtocol") + .parameters(params) + .credential(credentials) + .actions(null) + .backendRoles(backendRoles) + .addAllBackendRoles(false) + .access(AccessMode.from("private")) + .dryRun(false) + .build(); + + machineLearningNodeClient.createConnector(mlCreateConnectorInput, createConnectorActionListener); + + verify(client).execute(eq(MLCreateConnectorAction.INSTANCE), isA(MLCreateConnectorRequest.class), any()); + verify(createConnectorActionListener).onResponse(argumentCaptor.capture()); + assertEquals(connectorId, (argumentCaptor.getValue()).getConnectorId()); + + } + + private SearchResponse createSearchResponse(ToXContentObject o) throws IOException { XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); From 9885fe79b25e1cd912cadd86f6bddc5f27be29ea Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Fri, 3 Nov 2023 19:20:12 +0000 Subject: [PATCH 3/5] Manual backport of 1493 Signed-off-by: Joshua Palis --- client/build.gradle | 11 + .../ml/client/MachineLearningClient.java | 33 +- .../ml/client/MachineLearningNodeClient.java | 144 ++++--- .../opensearch/ml/client/package-info.java | 2 +- .../ml/client/MachineLearningClientTest.java | 217 ++++++----- .../client/MachineLearningNodeClientTest.java | 363 +++++++++--------- .../opensearch/ml/common/MLModelGroup.java | 6 +- .../MLRegisterModelGroupInput.java | 6 +- .../ml/common/MLModelGroupTest.java | 4 +- 9 files changed, 414 insertions(+), 372 deletions(-) diff --git a/client/build.gradle b/client/build.gradle index 57d4669f60..5f6bca9014 100644 --- a/client/build.gradle +++ b/client/build.gradle @@ -9,6 +9,7 @@ plugins { id 'jacoco' id 'com.github.johnrengelman.shadow' id 'maven-publish' + id 'com.diffplug.spotless' version '6.18.0' id 'signing' } @@ -21,6 +22,16 @@ dependencies { } +spotless { + java { + removeUnusedImports() + importOrder 'java', 'javax', 'org', 'com' + + eclipse().configFile rootProject.file('.eclipseformat.xml') + } +} + + jacocoTestReport { reports { xml.getRequired().set(true) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index e7d01c6b14..4d81448362 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -5,13 +5,15 @@ package org.opensearch.ml.client; +import java.util.List; +import java.util.Map; -import org.opensearch.common.action.ActionFuture; -import org.opensearch.core.action.ActionListener; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.ToolMetadata; @@ -20,12 +22,11 @@ import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.util.List; -import java.util.Map; - /** * A client to provide interfaces for machine learning jobs. This will be used by other plugins. */ @@ -86,7 +87,6 @@ default ActionFuture train(MLInput mlInput, boolean asyncTask) { return actionFuture; } - /** * Do the training machine learning job. The training job will be always async process. The job id will be returned in this method. * For more info on train model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#train-model @@ -207,7 +207,6 @@ default ActionFuture searchModel(SearchRequest searchRequest) { return actionFuture; } - /** * For more info on search model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#search-model * @param searchRequest searchRequest to search the ML Model @@ -215,7 +214,6 @@ default ActionFuture searchModel(SearchRequest searchRequest) { */ void searchModel(SearchRequest searchRequest, ActionListener listener); - /** * For more info on search task, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#search-task * @param searchRequest searchRequest to search the ML Task @@ -285,6 +283,25 @@ default ActionFuture createConnector(MLCreateConnecto void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener listener); + /** + * Register model group + * For additional info on model group, refer: https://opensearch.org/docs/latest/ml-commons-plugin/model-access-control#registering-a-model-group + * @param mlRegisterModelGroupInput model group input + */ + default ActionFuture registerModelGroup(MLRegisterModelGroupInput mlRegisterModelGroupInput) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + registerModelGroup(mlRegisterModelGroupInput, actionFuture); + return actionFuture; + } + + /** + * Register model group + * For additional info on model group, refer: https://opensearch.org/docs/latest/ml-commons-plugin/model-access-control#registering-a-model-group + * @param mlRegisterModelGroupInput model group input + * @param listener a listener to be notified of the result + */ + void registerModelGroup(MLRegisterModelGroupInput mlRegisterModelGroupInput, ActionListener listener); + /** * Get a list of ToolMetadata and return ActionFuture. * For more info on list tools, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#list-tools diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 949684c078..6828c111a2 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -5,15 +5,25 @@ package org.opensearch.ml.client; -import lombok.AccessLevel; -import lombok.RequiredArgsConstructor; -import lombok.experimental.FieldDefaults; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.action.ActionResponse; +import static org.opensearch.ml.common.input.Constants.ASYNC; +import static org.opensearch.ml.common.input.Constants.MODELID; +import static org.opensearch.ml.common.input.Constants.PREDICT; +import static org.opensearch.ml.common.input.Constants.TRAIN; +import static org.opensearch.ml.common.input.Constants.TRAINANDPREDICT; +import static org.opensearch.ml.common.input.InputHelper.convertArgumentToMLParameter; +import static org.opensearch.ml.common.input.InputHelper.getAction; +import static org.opensearch.ml.common.input.InputHelper.getFunctionName; + +import java.util.List; +import java.util.Map; +import java.util.function.Function; + import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; @@ -35,6 +45,10 @@ import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.ml.common.transport.model.MLModelGetResponse; import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.transport.register.MLRegisterModelAction; @@ -57,18 +71,9 @@ import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction; -import java.util.List; -import java.util.Map; -import java.util.function.Function; - -import static org.opensearch.ml.common.input.Constants.ASYNC; -import static org.opensearch.ml.common.input.Constants.MODELID; -import static org.opensearch.ml.common.input.Constants.PREDICT; -import static org.opensearch.ml.common.input.Constants.TRAIN; -import static org.opensearch.ml.common.input.Constants.TRAINANDPREDICT; -import static org.opensearch.ml.common.input.InputHelper.convertArgumentToMLParameter; -import static org.opensearch.ml.common.input.InputHelper.getAction; -import static org.opensearch.ml.common.input.InputHelper.getFunctionName; +import lombok.AccessLevel; +import lombok.RequiredArgsConstructor; +import lombok.experimental.FieldDefaults; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @RequiredArgsConstructor @@ -80,21 +85,19 @@ public class MachineLearningNodeClient implements MachineLearningClient { public void predict(String modelId, MLInput mlInput, ActionListener listener) { validateMLInput(mlInput, true); - MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder() - .mlInput(mlInput) - .modelId(modelId) - .dispatchTask(true) - .build(); + MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest + .builder() + .mlInput(mlInput) + .modelId(modelId) + .dispatchTask(true) + .build(); client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, getMlPredictionTaskResponseActionListener(listener)); } @Override public void trainAndPredict(MLInput mlInput, ActionListener listener) { validateMLInput(mlInput, true); - MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .mlInput(mlInput) - .dispatchTask(true) - .build(); + MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().mlInput(mlInput).dispatchTask(true).build(); client.execute(MLTrainAndPredictionTaskAction.INSTANCE, request, getMlPredictionTaskResponseActionListener(listener)); } @@ -102,11 +105,12 @@ public void trainAndPredict(MLInput mlInput, ActionListener listener) @Override public void train(MLInput mlInput, boolean asyncTask, ActionListener listener) { validateMLInput(mlInput, true); - MLTrainingTaskRequest trainingTaskRequest = MLTrainingTaskRequest.builder() - .mlInput(mlInput) - .async(asyncTask) - .dispatchTask(true) - .build(); + MLTrainingTaskRequest trainingTaskRequest = MLTrainingTaskRequest + .builder() + .mlInput(mlInput) + .async(asyncTask) + .dispatchTask(true) + .build(); client.execute(MLTrainingTaskAction.INSTANCE, trainingTaskRequest, getMlPredictionTaskResponseActionListener(listener)); } @@ -136,15 +140,13 @@ public void run(MLInput mlInput, Map args, ActionListener listener) { - MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder() - .modelId(modelId) - .build(); + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build(); client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, getMlGetModelResponseActionListener(listener)); } @@ -162,9 +164,7 @@ private ActionListener getMlGetModelResponseActionListener(A @Override public void deleteModel(String modelId, ActionListener listener) { - MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder() - .modelId(modelId) - .build(); + MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build(); client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, ActionListener.wrap(deleteResponse -> { listener.onResponse(deleteResponse); @@ -173,17 +173,26 @@ public void deleteModel(String modelId, ActionListener listener) @Override public void searchModel(SearchRequest searchRequest, ActionListener listener) { - client.execute(MLModelSearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> { - listener.onResponse(searchResponse); - }, listener::onFailure)); + client + .execute( + MLModelSearchAction.INSTANCE, + searchRequest, + ActionListener.wrap(searchResponse -> { listener.onResponse(searchResponse); }, listener::onFailure) + ); } + @Override + public void registerModelGroup( + MLRegisterModelGroupInput mlRegisterModelGroupInput, + ActionListener listener + ) { + MLRegisterModelGroupRequest mlRegisterModelGroupRequest = new MLRegisterModelGroupRequest(mlRegisterModelGroupInput); + client.execute(MLRegisterModelGroupAction.INSTANCE, mlRegisterModelGroupRequest, listener); + } @Override public void getTask(String taskId, ActionListener listener) { - MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder() - .taskId(taskId) - .build(); + MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build(); client.execute(MLTaskGetAction.INSTANCE, mlTaskGetRequest, ActionListener.wrap(response -> { listener.onResponse(MLTaskGetResponse.fromActionResponse(response).getMlTask()); @@ -192,9 +201,7 @@ public void getTask(String taskId, ActionListener listener) { @Override public void deleteTask(String taskId, ActionListener listener) { - MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder() - .taskId(taskId) - .build(); + MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.builder().taskId(taskId).build(); client.execute(MLTaskDeleteAction.INSTANCE, mlTaskDeleteRequest, ActionListener.wrap(deleteResponse -> { listener.onResponse(deleteResponse); @@ -203,25 +210,34 @@ public void deleteTask(String taskId, ActionListener listener) { @Override public void searchTask(SearchRequest searchRequest, ActionListener listener) { - client.execute(MLTaskSearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> { - listener.onResponse(searchResponse); - }, listener::onFailure)); + client + .execute( + MLTaskSearchAction.INSTANCE, + searchRequest, + ActionListener.wrap(searchResponse -> { listener.onResponse(searchResponse); }, listener::onFailure) + ); } @Override public void register(MLRegisterModelInput mlInput, ActionListener listener) { MLRegisterModelRequest registerRequest = new MLRegisterModelRequest(mlInput); - client.execute(MLRegisterModelAction.INSTANCE, registerRequest, ActionListener.wrap(listener::onResponse, e -> { - listener.onFailure(e); - })); + client + .execute( + MLRegisterModelAction.INSTANCE, + registerRequest, + ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); }) + ); } @Override public void deploy(String modelId, ActionListener listener) { MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, false); - client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, ActionListener.wrap(listener::onResponse, e -> { - listener.onFailure(e); - })); + client + .execute( + MLDeployModelAction.INSTANCE, + deployModelRequest, + ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); }) + ); } @Override @@ -277,12 +293,14 @@ private ActionListener getMlPredictionTaskResponseActionListener return actionListener; } - private ActionListener wrapActionListener(final ActionListener listener, final Function recreate) { - ActionListener actionListener = ActionListener.wrap(r-> { - listener.onResponse(recreate.apply(r));; - }, e->{ - listener.onFailure(e); - }); + private ActionListener wrapActionListener( + final ActionListener listener, + final Function recreate + ) { + ActionListener actionListener = ActionListener.wrap(r -> { + listener.onResponse(recreate.apply(r)); + ; + }, e -> { listener.onFailure(e); }); return actionListener; } @@ -290,7 +308,7 @@ private void validateMLInput(MLInput mlInput, boolean requireInput) { if (mlInput == null) { throw new IllegalArgumentException("ML Input can't be null"); } - if(requireInput && mlInput.getInputDataset() == null) { + if (requireInput && mlInput.getInputDataset() == null) { throw new IllegalArgumentException("input data set can't be null"); } } diff --git a/client/src/main/java/org/opensearch/ml/client/package-info.java b/client/src/main/java/org/opensearch/ml/client/package-info.java index cea2d5c387..75ba6b7b9f 100644 --- a/client/src/main/java/org/opensearch/ml/client/package-info.java +++ b/client/src/main/java/org/opensearch/ml/client/package-info.java @@ -3,4 +3,4 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.client; \ No newline at end of file +package org.opensearch.ml.client; diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 34039e46f6..4b137ac685 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -5,50 +5,51 @@ package org.opensearch.ml.client; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.input.Constants.ACTION; +import static org.opensearch.ml.common.input.Constants.ALGORITHM; +import static org.opensearch.ml.common.input.Constants.KMEANS; +import static org.opensearch.ml.common.input.Constants.TRAIN; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.core.action.ActionListener; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.ToolMetadata; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.AccessMode; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.output.MLOutput; -import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.verify; -import static org.opensearch.ml.common.input.Constants.ACTION; -import static org.opensearch.ml.common.input.Constants.ALGORITHM; -import static org.opensearch.ml.common.input.Constants.KMEANS; -import static org.opensearch.ml.common.input.Constants.TRAIN; - public class MachineLearningClientTest { - MachineLearningClient machineLearningClient; @Mock @@ -78,6 +79,8 @@ public class MachineLearningClientTest { @Mock MLCreateConnectorResponse createConnectorResponse; + @Mock + MLRegisterModelGroupResponse registerModelGroupResponse; private String modekId = "test_model_id"; private MLModel mlModel; @@ -88,24 +91,14 @@ public void setUp() { MockitoAnnotations.openMocks(this); String taskId = "taskId"; String modelId = "modelId"; - mlTask = MLTask.builder() - .taskId(taskId) - .modelId(modelId) - .functionName(FunctionName.KMEANS) - .build(); + mlTask = MLTask.builder().taskId(taskId).modelId(modelId).functionName(FunctionName.KMEANS).build(); String modelContent = "test content"; - mlModel = MLModel.builder() - .algorithm(FunctionName.KMEANS) - .name("test") - .content(modelContent) - .build(); + mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("test").content(modelContent).build(); machineLearningClient = new MachineLearningClient() { @Override - public void predict(String modelId, - MLInput mlInput, - ActionListener listener) { + public void predict(String modelId, MLInput mlInput, ActionListener listener) { listener.onResponse(output); } @@ -169,6 +162,13 @@ public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, Actio listener.onResponse(createConnectorResponse); } + public void registerModelGroup( + MLRegisterModelGroupInput mlRegisterModelGroupInput, + ActionListener listener + ) { + listener.onResponse(registerModelGroupResponse); + } + @Override public void listTools(ActionListener> listener) { listener.onResponse(null); @@ -183,39 +183,35 @@ public void getTool(String toolName, ActionListener listener) { @Test public void predict_WithAlgoAndInputData() { - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .inputDataset(new DataFrameInputDataset(input)) - .build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(new DataFrameInputDataset(input)).build(); assertEquals(output, machineLearningClient.predict(null, mlInput).actionGet()); } @Test public void predict_WithAlgoAndParametersAndInputData() { - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .parameters(mlParameters) - .inputDataset(new DataFrameInputDataset(input)) - .build(); + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.KMEANS) + .parameters(mlParameters) + .inputDataset(new DataFrameInputDataset(input)) + .build(); assertEquals(output, machineLearningClient.predict(null, mlInput).actionGet()); } @Test public void predict_WithAlgoAndParametersAndInputDataAndModelId() { - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .parameters(mlParameters) - .inputDataset(new DataFrameInputDataset(input)) - .build(); + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.KMEANS) + .parameters(mlParameters) + .inputDataset(new DataFrameInputDataset(input)) + .build(); assertEquals(output, machineLearningClient.predict("modelId", mlInput).actionGet()); } @Test public void predict_WithAlgoAndInputDataAndListener() { - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .inputDataset(new DataFrameInputDataset(input)) - .build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(new DataFrameInputDataset(input)).build(); ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class); machineLearningClient.predict(null, mlInput, dataFrameActionListener); verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture()); @@ -224,11 +220,12 @@ public void predict_WithAlgoAndInputDataAndListener() { @Test public void predict_WithAlgoAndInputDataAndParametersAndListener() { - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .parameters(mlParameters) - .inputDataset(new DataFrameInputDataset(input)) - .build(); + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.KMEANS) + .parameters(mlParameters) + .inputDataset(new DataFrameInputDataset(input)) + .build(); ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class); machineLearningClient.predict(null, mlInput, dataFrameActionListener); verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture()); @@ -237,31 +234,34 @@ public void predict_WithAlgoAndInputDataAndParametersAndListener() { @Test public void train() { - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .parameters(mlParameters) - .inputDataset(new DataFrameInputDataset(input)) - .build(); - assertEquals(modekId, ((MLTrainingOutput)machineLearningClient.train(mlInput, false).actionGet()).getModelId()); + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.KMEANS) + .parameters(mlParameters) + .inputDataset(new DataFrameInputDataset(input)) + .build(); + assertEquals(modekId, ((MLTrainingOutput) machineLearningClient.train(mlInput, false).actionGet()).getModelId()); } @Test public void trainAndPredict() { - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .parameters(mlParameters) - .inputDataset(new DataFrameInputDataset(input)) - .build(); + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.KMEANS) + .parameters(mlParameters) + .inputDataset(new DataFrameInputDataset(input)) + .build(); assertEquals(output, machineLearningClient.trainAndPredict(mlInput).actionGet()); } @Test public void execute() { - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.SAMPLE_ALGO) - .parameters(mlParameters) - .inputDataset(new DataFrameInputDataset(input)) - .build(); + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.SAMPLE_ALGO) + .parameters(mlParameters) + .inputDataset(new DataFrameInputDataset(input)) + .build(); Map args = new HashMap<>(); args.put(ACTION, TRAIN); args.put(ALGORITHM, KMEANS); @@ -283,6 +283,22 @@ public void searchModel() { assertEquals(searchResponse, machineLearningClient.searchModel(new SearchRequest()).actionGet()); } + @Test + public void registerModelGroup() { + List backendRoles = Arrays.asList("IT", "HR"); + + MLRegisterModelGroupInput mlRegisterModelGroupInput = MLRegisterModelGroupInput + .builder() + .name("test") + .description("description") + .backendRoles(backendRoles) + .modelAccessMode(AccessMode.from("public")) + .isAddAllBackendRoles(false) + .build(); + + assertEquals(registerModelGroupResponse, machineLearningClient.registerModelGroup(mlRegisterModelGroupInput).actionGet()); + } + @Test public void getTask() { assertEquals(mlTask, machineLearningClient.getTask("taskId").actionGet()); @@ -300,23 +316,25 @@ public void searchTask() { @Test public void register() { - MLModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); - MLRegisterModelInput mlInput = MLRegisterModelInput.builder() - .functionName(FunctionName.KMEANS) - .modelName("testModelName") - .version("testModelVersion") - .modelGroupId("modelGroupId") - .url("url") - .modelFormat(MLModelFormat.ONNX) - .modelConfig(config) - .deployModel(true) - .modelNodeIds(new String[]{"modelNodeIds" }) - .build(); + MLModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + MLRegisterModelInput mlInput = MLRegisterModelInput + .builder() + .functionName(FunctionName.KMEANS) + .modelName("testModelName") + .version("testModelVersion") + .modelGroupId("modelGroupId") + .url("url") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); assertEquals(registerModelResponse, machineLearningClient.register(mlInput).actionGet()); } @@ -330,19 +348,20 @@ public void createConnector() { Map params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7")); Map credentials = Map.ofEntries(Map.entry("key1", "key1"), Map.entry("key2", "key2")); - MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput.builder() - .name("test") - .description("description") - .version("testModelVersion") - .protocol("testProtocol") - .parameters(params) - .credential(credentials) - .actions(null) - .backendRoles(null) - .addAllBackendRoles(false) - .access(AccessMode.from("private")) - .dryRun(false) - .build(); + MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput + .builder() + .name("test") + .description("description") + .version("testModelVersion") + .protocol("testProtocol") + .parameters(params) + .credential(credentials) + .actions(null) + .backendRoles(null) + .addAllBackendRoles(false) + .access(AccessMode.from("private")) + .dryRun(false) + .build(); assertEquals(createConnectorResponse, machineLearningClient.createConnector(mlCreateConnectorInput).actionGet()); } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index 8ae7183f41..ccdf812195 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -5,6 +5,29 @@ package org.opensearch.ml.client; +import static org.junit.Assert.assertEquals; +import static org.mockito.Answers.RETURNS_DEEP_STUBS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.input.Constants.ACTION; +import static org.opensearch.ml.common.input.Constants.ALGORITHM; +import static org.opensearch.ml.common.input.Constants.KMEANS; +import static org.opensearch.ml.common.input.Constants.MODELID; +import static org.opensearch.ml.common.input.Constants.PREDICT; +import static org.opensearch.ml.common.input.Constants.RCF; +import static org.opensearch.ml.common.input.Constants.TRAIN; +import static org.opensearch.ml.common.input.Constants.TRAINANDPREDICT; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Rule; @@ -14,14 +37,14 @@ import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.core.action.ActionListener; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.client.node.NodeClient; -import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.ToXContent; @@ -41,8 +64,6 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; -import org.opensearch.ml.common.MLTask; -import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; @@ -58,6 +79,10 @@ import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.ml.common.transport.model.MLModelGetResponse; import org.opensearch.ml.common.transport.model.MLModelSearchAction; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.transport.register.MLRegisterModelAction; @@ -80,29 +105,6 @@ import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.suggest.Suggest; -import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.mockito.Answers.RETURNS_DEEP_STUBS; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.verify; -import static org.opensearch.ml.common.input.Constants.ACTION; -import static org.opensearch.ml.common.input.Constants.ALGORITHM; -import static org.opensearch.ml.common.input.Constants.KMEANS; -import static org.opensearch.ml.common.input.Constants.MODELID; -import static org.opensearch.ml.common.input.Constants.PREDICT; -import static org.opensearch.ml.common.input.Constants.RCF; -import static org.opensearch.ml.common.input.Constants.TRAIN; -import static org.opensearch.ml.common.input.Constants.TRAINANDPREDICT; - public class MachineLearningNodeClientTest { @Mock(answer = RETURNS_DEEP_STUBS) @@ -147,6 +149,9 @@ public class MachineLearningNodeClientTest { @Mock ActionListener createConnectorActionListener; + @Mock + ActionListener registerModelGroupResponseActionListener; + @InjectMocks MachineLearningNodeClient machineLearningNodeClient; @@ -162,36 +167,30 @@ public void setUp() { public void predict() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); - MLPredictionOutput predictionOutput = MLPredictionOutput.builder() - .status("Success") - .predictionResult(output) - .taskId("taskId") - .build(); - actionListener.onResponse(MLTaskResponse.builder() - .output(predictionOutput) - .build()); + MLPredictionOutput predictionOutput = MLPredictionOutput + .builder() + .status("Success") + .predictionResult(output) + .taskId("taskId") + .build(); + actionListener.onResponse(MLTaskResponse.builder().output(predictionOutput).build()); return null; }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class); - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .inputDataset(input) - .build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build(); machineLearningNodeClient.predict(null, mlInput, dataFrameActionListener); verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), isA(MLPredictionTaskRequest.class), any()); verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture()); - assertEquals(output, ((MLPredictionOutput)dataFrameArgumentCaptor.getValue()).getPredictionResult()); + assertEquals(output, ((MLPredictionOutput) dataFrameArgumentCaptor.getValue()).getPredictionResult()); } @Test public void predict_Exception_WithNullAlgorithm() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("algorithm can't be null"); - MLInput mlInput = MLInput.builder() - .inputDataset(input) - .build(); + MLInput mlInput = MLInput.builder().inputDataset(input).build(); machineLearningNodeClient.predict(null, mlInput, dataFrameActionListener); } @@ -199,9 +198,7 @@ public void predict_Exception_WithNullAlgorithm() { public void predict_Exception_WithNullDataSet() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("input data set can't be null"); - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).build(); machineLearningNodeClient.predict(null, mlInput, dataFrameActionListener); } @@ -211,36 +208,26 @@ public void train() { String status = "InProgress"; doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); - MLTrainingOutput output = MLTrainingOutput.builder() - .status(status) - .modelId(modelId) - .build(); - actionListener.onResponse(MLTaskResponse.builder() - .output(output) - .build()); + MLTrainingOutput output = MLTrainingOutput.builder().status(status).modelId(modelId).build(); + actionListener.onResponse(MLTaskResponse.builder().output(output).build()); return null; }).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLOutput.class); - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .inputDataset(input) - .build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build(); machineLearningNodeClient.train(mlInput, false, trainingActionListener); verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any()); verify(trainingActionListener).onResponse(argumentCaptor.capture()); - assertEquals(modelId, ((MLTrainingOutput)argumentCaptor.getValue()).getModelId()); - assertEquals(status, ((MLTrainingOutput)argumentCaptor.getValue()).getStatus()); + assertEquals(modelId, ((MLTrainingOutput) argumentCaptor.getValue()).getModelId()); + assertEquals(status, ((MLTrainingOutput) argumentCaptor.getValue()).getStatus()); } @Test public void train_Exception_WithNullDataSet() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("input data set can't be null"); - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).build(); machineLearningNodeClient.train(mlInput, false, trainingActionListener); } @@ -255,28 +242,24 @@ public void train_Exception_WithNullInput() { public void trainAndPredict() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); - MLPredictionOutput predictionOutput = MLPredictionOutput.builder() - .status(MLTaskState.COMPLETED.name()) - .predictionResult(output) - .taskId("taskId") - .build(); - actionListener.onResponse(MLTaskResponse.builder() - .output(predictionOutput) - .build()); + MLPredictionOutput predictionOutput = MLPredictionOutput + .builder() + .status(MLTaskState.COMPLETED.name()) + .predictionResult(output) + .taskId("taskId") + .build(); + actionListener.onResponse(MLTaskResponse.builder().output(predictionOutput).build()); return null; }).when(client).execute(eq(MLTrainAndPredictionTaskAction.INSTANCE), any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLOutput.class); - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .inputDataset(input) - .build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build(); machineLearningNodeClient.trainAndPredict(mlInput, trainingActionListener); verify(client).execute(eq(MLTrainAndPredictionTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any()); verify(trainingActionListener).onResponse(argumentCaptor.capture()); - assertEquals(MLTaskState.COMPLETED.name(), ((MLPredictionOutput)argumentCaptor.getValue()).getStatus()); - assertEquals(output, ((MLPredictionOutput)argumentCaptor.getValue()).getPredictionResult()); + assertEquals(MLTaskState.COMPLETED.name(), ((MLPredictionOutput) argumentCaptor.getValue()).getStatus()); + assertEquals(output, ((MLPredictionOutput) argumentCaptor.getValue()).getPredictionResult()); } @Test @@ -301,27 +284,23 @@ public void execute_predict_null_model_id() { private void execute_predict(Map args) { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); - MLPredictionOutput predictionOutput = MLPredictionOutput.builder() - .status("Success") - .predictionResult(output) - .taskId("taskId") - .build(); - actionListener.onResponse(MLTaskResponse.builder() - .output(predictionOutput) - .build()); + MLPredictionOutput predictionOutput = MLPredictionOutput + .builder() + .status("Success") + .predictionResult(output) + .taskId("taskId") + .build(); + actionListener.onResponse(MLTaskResponse.builder().output(predictionOutput).build()); return null; }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class); - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.SAMPLE_ALGO) - .inputDataset(input) - .build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.SAMPLE_ALGO).inputDataset(input).build(); machineLearningNodeClient.run(mlInput, args, dataFrameActionListener); verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), isA(MLPredictionTaskRequest.class), any()); verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture()); - assertEquals(output, ((MLPredictionOutput)dataFrameArgumentCaptor.getValue()).getPredictionResult()); + assertEquals(output, ((MLPredictionOutput) dataFrameArgumentCaptor.getValue()).getPredictionResult()); } @Test @@ -330,13 +309,8 @@ public void execute_train() { String status = "InProgress"; doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); - MLTrainingOutput output = MLTrainingOutput.builder() - .status(status) - .modelId(modelId) - .build(); - actionListener.onResponse(MLTaskResponse.builder() - .output(output) - .build()); + MLTrainingOutput output = MLTrainingOutput.builder().status(status).modelId(modelId).build(); + actionListener.onResponse(MLTaskResponse.builder().output(output).build()); return null; }).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any()); @@ -344,16 +318,13 @@ public void execute_train() { Map args = new HashMap<>(); args.put(ACTION, TRAIN); args.put(ALGORITHM, KMEANS); - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.SAMPLE_ALGO) - .inputDataset(input) - .build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.SAMPLE_ALGO).inputDataset(input).build(); machineLearningNodeClient.run(mlInput, args, trainingActionListener); verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any()); verify(trainingActionListener).onResponse(argumentCaptor.capture()); - assertEquals(modelId, ((MLTrainingOutput)argumentCaptor.getValue()).getModelId()); - assertEquals(status, ((MLTrainingOutput)argumentCaptor.getValue()).getStatus()); + assertEquals(modelId, ((MLTrainingOutput) argumentCaptor.getValue()).getModelId()); + assertEquals(status, ((MLTrainingOutput) argumentCaptor.getValue()).getStatus()); } @Test @@ -423,28 +394,24 @@ public void execute_trainandpredict_fit_rcf() { private void execute_trainandpredict(Map args) { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); - MLPredictionOutput predictionOutput = MLPredictionOutput.builder() - .status(MLTaskState.COMPLETED.name()) - .predictionResult(output) - .taskId("taskId") - .build(); - actionListener.onResponse(MLTaskResponse.builder() - .output(predictionOutput) - .build()); + MLPredictionOutput predictionOutput = MLPredictionOutput + .builder() + .status(MLTaskState.COMPLETED.name()) + .predictionResult(output) + .taskId("taskId") + .build(); + actionListener.onResponse(MLTaskResponse.builder().output(predictionOutput).build()); return null; }).when(client).execute(eq(MLTrainAndPredictionTaskAction.INSTANCE), any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLOutput.class); - MLInput mlInput = MLInput.builder() - .algorithm(FunctionName.SAMPLE_ALGO) - .inputDataset(input) - .build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.SAMPLE_ALGO).inputDataset(input).build(); machineLearningNodeClient.run(mlInput, args, trainingActionListener); verify(client).execute(eq(MLTrainAndPredictionTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any()); verify(trainingActionListener).onResponse(argumentCaptor.capture()); - assertEquals(MLTaskState.COMPLETED.name(), ((MLPredictionOutput)argumentCaptor.getValue()).getStatus()); - assertEquals(output, ((MLPredictionOutput)argumentCaptor.getValue()).getPredictionResult()); + assertEquals(MLTaskState.COMPLETED.name(), ((MLPredictionOutput) argumentCaptor.getValue()).getStatus()); + assertEquals(output, ((MLPredictionOutput) argumentCaptor.getValue()).getPredictionResult()); } @Test @@ -452,14 +419,8 @@ public void getModel() { String modelContent = "test content"; doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); - MLModel mlModel = MLModel.builder() - .algorithm(FunctionName.KMEANS) - .name("test") - .content(modelContent) - .build(); - MLModelGetResponse output = MLModelGetResponse.builder() - .mlModel(mlModel) - .build(); + MLModel mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("test").content(modelContent).build(); + MLModelGetResponse output = MLModelGetResponse.builder().mlModel(mlModel).build(); actionListener.onResponse(output); return null; }).when(client).execute(eq(MLModelGetAction.INSTANCE), any(), any()); @@ -497,11 +458,7 @@ public void searchModel() { String modelContent = "test content"; doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); - MLModel mlModel = MLModel.builder() - .algorithm(FunctionName.KMEANS) - .name("test") - .content(modelContent) - .build(); + MLModel mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("test").content(modelContent).build(); SearchResponse output = createSearchResponse(mlModel); actionListener.onResponse(output); return null; @@ -517,20 +474,48 @@ public void searchModel() { assertEquals(modelContent, source.get(MLModel.MODEL_CONTENT_FIELD)); } + @Test + public void registerModelGroup() { + + String modelGroupId = "modeGroupId"; + String status = MLTaskState.CREATED.name(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLRegisterModelGroupResponse output = new MLRegisterModelGroupResponse(modelGroupId, status); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLRegisterModelGroupAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelGroupResponse.class); + + List backendRoles = Arrays.asList("IT", "HR"); + + MLRegisterModelGroupInput mlRegisterModelGroupInput = MLRegisterModelGroupInput + .builder() + .name("test") + .description("description") + .backendRoles(backendRoles) + .modelAccessMode(AccessMode.from("public")) + .isAddAllBackendRoles(false) + .build(); + + machineLearningNodeClient.registerModelGroup(mlRegisterModelGroupInput, registerModelGroupResponseActionListener); + + verify(client).execute(eq(MLRegisterModelGroupAction.INSTANCE), isA(MLRegisterModelGroupRequest.class), any()); + verify(registerModelGroupResponseActionListener).onResponse(argumentCaptor.capture()); + assertEquals(modelGroupId, (argumentCaptor.getValue().getModelGroupId())); + assertEquals(status, (argumentCaptor.getValue().getStatus())); + } + @Test public void getTask() { String taskId = "taskId"; String modelId = "modelId"; doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); - MLTask mlTask = MLTask.builder() - .taskId(taskId) - .modelId(modelId) - .functionName(FunctionName.KMEANS) - .build(); - MLTaskGetResponse output = MLTaskGetResponse.builder() - .mlTask(mlTask) - .build(); + MLTask mlTask = MLTask.builder().taskId(taskId).modelId(modelId).functionName(FunctionName.KMEANS).build(); + MLTaskGetResponse output = MLTaskGetResponse.builder().mlTask(mlTask).build(); actionListener.onResponse(output); return null; }).when(client).execute(eq(MLTaskGetAction.INSTANCE), any(), any()); @@ -570,11 +555,7 @@ public void searchTask() { String modelId = "modelId"; doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); - MLTask mlTask = MLTask.builder() - .taskId(taskId) - .modelId(modelId) - .functionName(FunctionName.KMEANS) - .build(); + MLTask mlTask = MLTask.builder().taskId(taskId).modelId(modelId).functionName(FunctionName.KMEANS).build(); SearchResponse output = createSearchResponse(mlTask); actionListener.onResponse(output); return null; @@ -605,23 +586,25 @@ public void register() { }).when(client).execute(eq(MLRegisterModelAction.INSTANCE), any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); - MLModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); - MLRegisterModelInput mlInput = MLRegisterModelInput.builder() - .functionName(functionName) - .modelName("testModelName") - .version("testModelVersion") - .modelGroupId("modelGroupId") - .url("url") - .modelFormat(MLModelFormat.ONNX) - .modelConfig(config) - .deployModel(true) - .modelNodeIds(new String[]{"modelNodeIds" }) - .build(); + MLModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + MLRegisterModelInput mlInput = MLRegisterModelInput + .builder() + .functionName(functionName) + .modelName("testModelName") + .version("testModelVersion") + .modelGroupId("modelGroupId") + .url("url") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); machineLearningNodeClient.register(mlInput, registerModelActionListener); verify(client).execute(eq(MLRegisterModelAction.INSTANCE), isA(MLRegisterModelRequest.class), any()); @@ -655,7 +638,6 @@ public void deploy() { @Test public void createConnector() { - String connectorId = "connectorId"; doAnswer(invocation -> { @@ -671,19 +653,20 @@ public void createConnector() { Map credentials = Map.ofEntries(Map.entry("key1", "value1"), Map.entry("key2", "value2")); List backendRoles = Arrays.asList("IT", "HR"); - MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput.builder() - .name("test") - .description("description") - .version("testModelVersion") - .protocol("testProtocol") - .parameters(params) - .credential(credentials) - .actions(null) - .backendRoles(backendRoles) - .addAllBackendRoles(false) - .access(AccessMode.from("private")) - .dryRun(false) - .build(); + MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput + .builder() + .name("test") + .description("description") + .version("testModelVersion") + .protocol("testProtocol") + .parameters(params) + .credential(credentials) + .actions(null) + .backendRoles(backendRoles) + .addAllBackendRoles(false) + .access(AccessMode.from("private")) + .dryRun(false) + .build(); machineLearningNodeClient.createConnector(mlCreateConnectorInput, createConnectorActionListener); @@ -693,8 +676,6 @@ public void createConnector() { } - - private SearchResponse createSearchResponse(ToXContentObject o) throws IOException { XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); @@ -702,22 +683,22 @@ private SearchResponse createSearchResponse(ToXContentObject o) throws IOExcepti hits[0] = new SearchHit(0).sourceRef(BytesReference.bytes(content)); return new SearchResponse( - new InternalSearchResponse( - new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f), - InternalAggregations.EMPTY, - new Suggest(Collections.emptyList()), - new SearchProfileShardResults(Collections.emptyMap()), - false, - false, - 1 - ), - "", - 5, - 5, - 0, - 100, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY + new InternalSearchResponse( + new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f), + InternalAggregations.EMPTY, + new Suggest(Collections.emptyList()), + new SearchProfileShardResults(Collections.emptyMap()), + false, + false, + 1 + ), + "", + 5, + 5, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY ); } } diff --git a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java index 9e2fbb7133..070c13b204 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java @@ -11,6 +11,7 @@ import java.time.Instant; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import lombok.Builder; import lombok.Getter; import lombok.Setter; @@ -57,10 +58,7 @@ public MLModelGroup(String name, String description, int latestVersion, String modelGroupId, Instant createdTime, Instant lastUpdatedTime) { - if (name == null) { - throw new IllegalArgumentException("model group name is null"); - } - this.name = name; + this.name = Objects.requireNonNull(name, "model group name must not be null"); this.description = description; this.latestVersion = latestVersion; this.backendRoles = backendRoles; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java index 4595a16d77..c686d4bef5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Objects; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -39,10 +40,7 @@ public class MLRegisterModelGroupInput implements ToXContentObject, Writeable{ @Builder(toBuilder = true) public MLRegisterModelGroupInput(String name, String description, List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { - if (name == null) { - throw new IllegalArgumentException("model group name is null"); - } - this.name = name; + this.name = Objects.requireNonNull(name, "model group name must not be null"); this.description = description; this.backendRoles = backendRoles; this.modelAccessMode = modelAccessMode; diff --git a/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java b/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java index da8048b1cc..71f7f46cf2 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java +++ b/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java @@ -30,8 +30,8 @@ public class MLModelGroupTest { @Test public void toXContent_NullName() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("model group name is null"); + exceptionRule.expect(NullPointerException.class); + exceptionRule.expectMessage("model group name must not be null"); MLModelGroup.builder().build(); } From d6ac640a083734b926ffdc400329c7b20a00cfe9 Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Fri, 3 Nov 2023 19:30:50 +0000 Subject: [PATCH 4/5] Manual backport of 1560 Signed-off-by: Joshua Palis --- .../ml/client/MachineLearningNodeClient.java | 17 +++++++++------ .../register/MLRegisterModelResponse.java | 21 +++++++++++++++++++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 6828c111a2..c828feb92c 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -221,12 +221,7 @@ public void searchTask(SearchRequest searchRequest, ActionListener listener) { MLRegisterModelRequest registerRequest = new MLRegisterModelRequest(mlInput); - client - .execute( - MLRegisterModelAction.INSTANCE, - registerRequest, - ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); }) - ); + client.execute(MLRegisterModelAction.INSTANCE, registerRequest, getMLRegisterModelResponseActionListener(listener)); } @Override @@ -293,6 +288,16 @@ private ActionListener getMlPredictionTaskResponseActionListener return actionListener; } + private ActionListener getMLRegisterModelResponseActionListener( + ActionListener listener + ) { + ActionListener actionListener = wrapActionListener(listener, res -> { + MLRegisterModelResponse registerModelResponse = MLRegisterModelResponse.fromActionResponse(res); + return registerModelResponse; + }); + return actionListener; + } + private ActionListener wrapActionListener( final ActionListener listener, final Function recreate diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java index c7baa9b3a6..243fbe19c8 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java @@ -7,13 +7,18 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLRegisterModelResponse extends ActionResponse implements ToXContentObject { @@ -61,4 +66,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par builder.endObject(); return builder; } + + public static MLRegisterModelResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLRegisterModelResponse) { + return (MLRegisterModelResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLRegisterModelResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterModelResponse", e); + } + } } From d8b35f08e6ddb1ec110ace6de4cc464c5b833d24 Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Fri, 3 Nov 2023 19:39:04 +0000 Subject: [PATCH 5/5] manual backport of 1580 Signed-off-by: Joshua Palis --- .../ml/client/MachineLearningNodeClient.java | 44 +++++++++++++++---- .../connector/MLCreateConnectorResponse.java | 22 ++++++++++ .../deploy/MLDeployModelResponse.java | 22 ++++++++++ .../MLRegisterModelGroupResponse.java | 22 ++++++++++ 4 files changed, 102 insertions(+), 8 deletions(-) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index c828feb92c..7bdf86d109 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -187,7 +187,12 @@ public void registerModelGroup( ActionListener listener ) { MLRegisterModelGroupRequest mlRegisterModelGroupRequest = new MLRegisterModelGroupRequest(mlRegisterModelGroupInput); - client.execute(MLRegisterModelGroupAction.INSTANCE, mlRegisterModelGroupRequest, listener); + client + .execute( + MLRegisterModelGroupAction.INSTANCE, + mlRegisterModelGroupRequest, + getMlRegisterModelGroupResponseActionListener(listener) + ); } @Override @@ -227,18 +232,13 @@ public void register(MLRegisterModelInput mlInput, ActionListener listener) { MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, false); - client - .execute( - MLDeployModelAction.INSTANCE, - deployModelRequest, - ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); }) - ); + client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, getMlDeployModelResponseActionListener(listener)); } @Override public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener listener) { MLCreateConnectorRequest createConnectorRequest = new MLCreateConnectorRequest(mlCreateConnectorInput); - client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, listener); + client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, getMlCreateConnectorResponseActionListener(listener)); } @Override @@ -277,6 +277,34 @@ private ActionListener getMlGetToolResponseActionListener(Act return actionListener; } + private ActionListener getMlDeployModelResponseActionListener(ActionListener listener) { + ActionListener actionListener = wrapActionListener(listener, response -> { + MLDeployModelResponse deployModelResponse = MLDeployModelResponse.fromActionResponse(response); + return deployModelResponse; + }); + return actionListener; + } + + private ActionListener getMlCreateConnectorResponseActionListener( + ActionListener listener + ) { + ActionListener actionListener = wrapActionListener(listener, response -> { + MLCreateConnectorResponse createConnectorResponse = MLCreateConnectorResponse.fromActionResponse(response); + return createConnectorResponse; + }); + return actionListener; + } + + private ActionListener getMlRegisterModelGroupResponseActionListener( + ActionListener listener + ) { + ActionListener actionListener = wrapActionListener(listener, response -> { + MLRegisterModelGroupResponse registerModelGroupResponse = MLRegisterModelGroupResponse.fromActionResponse(response); + return registerModelGroupResponse; + }); + return actionListener; + } + private ActionListener getMlPredictionTaskResponseActionListener(ActionListener listener) { ActionListener internalListener = ActionListener.wrap(predictionResponse -> { listener.onResponse(predictionResponse.getOutput()); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java index bf7b78e775..68ce877baa 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java @@ -7,12 +7,17 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLCreateConnectorResponse extends ActionResponse implements ToXContentObject { @@ -42,4 +47,21 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + public static MLCreateConnectorResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLCreateConnectorResponse) { + return (MLCreateConnectorResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLCreateConnectorResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLCreateConnectorResponse", e); + } + + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java index ddf0104c9e..ca35af68f0 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java @@ -7,6 +7,8 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; @@ -14,7 +16,10 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLTaskType; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLDeployModelResponse extends ActionResponse implements ToXContentObject { @@ -57,4 +62,21 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par builder.endObject(); return builder; } + + public static MLDeployModelResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLDeployModelResponse) { + return (MLDeployModelResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLDeployModelResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLDeployModelResponse", e); + } + + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java index 2b70ede72f..e4d1db5511 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java @@ -7,12 +7,17 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLRegisterModelGroupResponse extends ActionResponse implements ToXContentObject { @@ -49,4 +54,21 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + public static MLRegisterModelGroupResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLRegisterModelGroupResponse) { + return (MLRegisterModelGroupResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLRegisterModelGroupResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterModelGroupResponse", e); + } + + } }