From 50ca665a3abaeb0562322da49e11b65e066454ac Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Mon, 10 Jul 2023 12:06:23 -0700 Subject: [PATCH] restful connector actions and UT Signed-off-by: Xun Zhang --- plugin/build.gradle | 3 +- .../ml/plugin/MachineLearningPlugin.java | 28 ++- .../ml/rest/RestMLCreateConnectorAction.java | 67 ++++++ .../ml/rest/RestMLDeleteConnectorAction.java | 53 +++++ .../ml/rest/RestMLGetConnectorAction.java | 65 ++++++ .../ml/rest/RestMLSearchConnectorAction.java | 28 +++ .../opensearch/ml/utils/RestActionUtils.java | 1 + .../GetConnectorTransportActionTests.java | 177 ++++++++++++++++ .../RestMLCreateConnectorActionTests.java | 110 ++++++++++ .../RestMLDeleteConnectorActionTests.java | 106 ++++++++++ .../rest/RestMLGetConnectorActionTests.java | 110 ++++++++++ .../RestMLSearchConnectorActionTests.java | 194 ++++++++++++++++++ .../org/opensearch/ml/utils/TestHelper.java | 45 ++++ 13 files changed, 984 insertions(+), 3 deletions(-) create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateConnectorAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteConnectorAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConnectorAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchConnectorAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateConnectorActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteConnectorActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConnectorActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java diff --git a/plugin/build.gradle b/plugin/build.gradle index c9beb7c943..cdd91577a0 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -292,7 +292,8 @@ List jacocoExclusions = [ 'org.opensearch.ml.action.connector.DeleteConnectorTransportAction', 'org.opensearch.ml.action.connector.DeleteConnectorTransportAction.1', 'org.opensearch.ml.action.connector.TransportCreateConnectorAction', - 'org.opensearch.ml.action.connector.SearchConnectorTransportAction' + 'org.opensearch.ml.action.connector.SearchConnectorTransportAction', + 'org.opensearch.ml.rest.RestMLCreateConnectorAction' ] jacocoTestCoverageVerification { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 9622579bc0..6b525d8695 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -32,6 +32,10 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; +import org.opensearch.ml.action.connector.DeleteConnectorTransportAction; +import org.opensearch.ml.action.connector.GetConnectorTransportAction; +import org.opensearch.ml.action.connector.SearchConnectorTransportAction; +import org.opensearch.ml.action.connector.TransportCreateConnectorAction; import org.opensearch.ml.action.deploy.TransportDeployModelAction; import org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction; import org.opensearch.ml.action.execute.TransportExecuteTaskAction; @@ -79,6 +83,10 @@ import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams; import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; +import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; +import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelOnNodeAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; @@ -116,11 +124,14 @@ import org.opensearch.ml.indices.MLInputDatasetHandler; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.rest.RestMLCreateConnectorAction; +import org.opensearch.ml.rest.RestMLDeleteConnectorAction; import org.opensearch.ml.rest.RestMLDeleteModelAction; import org.opensearch.ml.rest.RestMLDeleteModelGroupAction; import org.opensearch.ml.rest.RestMLDeleteTaskAction; import org.opensearch.ml.rest.RestMLDeployModelAction; import org.opensearch.ml.rest.RestMLExecuteAction; +import org.opensearch.ml.rest.RestMLGetConnectorAction; import org.opensearch.ml.rest.RestMLGetModelAction; import org.opensearch.ml.rest.RestMLGetTaskAction; import org.opensearch.ml.rest.RestMLPredictionAction; @@ -128,6 +139,7 @@ import org.opensearch.ml.rest.RestMLRegisterModelAction; import org.opensearch.ml.rest.RestMLRegisterModelGroupAction; import org.opensearch.ml.rest.RestMLRegisterModelMetaAction; +import org.opensearch.ml.rest.RestMLSearchConnectorAction; import org.opensearch.ml.rest.RestMLSearchModelAction; import org.opensearch.ml.rest.RestMLSearchModelGroupAction; import org.opensearch.ml.rest.RestMLSearchTaskAction; @@ -235,7 +247,11 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin { new ActionHandler<>(MLRegisterModelGroupAction.INSTANCE, TransportRegisterModelGroupAction.class), new ActionHandler<>(MLUpdateModelGroupAction.INSTANCE, TransportUpdateModelGroupAction.class), new ActionHandler<>(MLModelGroupSearchAction.INSTANCE, SearchModelGroupTransportAction.class), - new ActionHandler<>(MLModelGroupDeleteAction.INSTANCE, DeleteModelGroupTransportAction.class) + new ActionHandler<>(MLModelGroupDeleteAction.INSTANCE, DeleteModelGroupTransportAction.class), + new ActionHandler<>(MLCreateConnectorAction.INSTANCE, TransportCreateConnectorAction.class), + new ActionHandler<>(MLConnectorGetAction.INSTANCE, GetConnectorTransportAction.class), + new ActionHandler<>(MLConnectorDeleteAction.INSTANCE, DeleteConnectorTransportAction.class), + new ActionHandler<>(MLConnectorSearchAction.INSTANCE, SearchConnectorTransportAction.class) ); } @@ -453,6 +469,10 @@ public List getRestHandlers( RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction(); RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction(); RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction(); + RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(); + RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction(); + RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(); + RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction(); return ImmutableList .of( restMLStatsAction, @@ -475,7 +495,11 @@ public List getRestHandlers( restMLCreateModelGroupAction, restMLUpdateModelGroupAction, restMLSearchModelGroupAction, - restMLDeleteModelGroupAction + restMLDeleteModelGroupAction, + restMLCreateConnectorAction, + restMLGetConnectorAction, + restMLDeleteConnectorAction, + restMLSearchConnectorAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateConnectorAction.java new file mode 100644 index 0000000000..9b0ae00da8 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateConnectorAction.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLCreateConnectorAction extends BaseRestHandler { + private static final String ML_CREATE_CONNECTOR_ACTION = "ml_create_connector_action"; + + /** + * Constructor * + */ + public RestMLCreateConnectorAction() {} + + @Override + public String getName() { + return ML_CREATE_CONNECTOR_ACTION; + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/connectors/_create", ML_BASE_URI))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLCreateConnectorRequest mlCreateConnectorRequest = getRequest(request); + return channel -> client.execute(MLCreateConnectorAction.INSTANCE, mlCreateConnectorRequest, new RestToXContentListener<>(channel)); + } + + /** + * * Creates a MLCreateConnectorRequest from a RestRequest + * @param request + * @return MLCreateConnectorRequest + * @throws IOException + */ + @VisibleForTesting + MLCreateConnectorRequest getRequest(RestRequest request) throws IOException { + if (!request.hasContent()) { + throw new IOException("Create Connector request has empty body"); + } + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput.parse(parser); + return new MLCreateConnectorRequest(mlCreateConnectorInput); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteConnectorAction.java new file mode 100644 index 0000000000..532cd26123 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteConnectorAction.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to delete ML Connector. + */ +public class RestMLDeleteConnectorAction extends BaseRestHandler { + private static final String ML_DELETE_CONNECTOR_ACTION = "ml_delete_connector_action"; + + public void RestMLDeleteConnectorAction() {} + + @Override + public String getName() { + return ML_DELETE_CONNECTOR_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route(RestRequest.Method.DELETE, String.format(Locale.ROOT, "%s/connectors/{%s}", ML_BASE_URI, PARAMETER_CONNECTOR_ID)) + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + String connectorId = request.param(PARAMETER_CONNECTOR_ID); + + MLConnectorDeleteRequest mlConnectorDeleteRequest = new MLConnectorDeleteRequest(connectorId); + return channel -> client.execute(MLConnectorDeleteAction.INSTANCE, mlConnectorDeleteRequest, new RestToXContentListener<>(channel)); + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConnectorAction.java new file mode 100644 index 0000000000..0c1e124e4c --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetConnectorAction.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; +import static org.opensearch.ml.utils.RestActionUtils.returnContent; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; +import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLGetConnectorAction extends BaseRestHandler { + private static final String ML_GET_CONNECTOR_ACTION = "ml_get_connector_action"; + + /** + * Constructor + */ + public RestMLGetConnectorAction() {} + + @Override + public String getName() { + return ML_GET_CONNECTOR_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/connectors/{%s}", ML_BASE_URI, PARAMETER_CONNECTOR_ID))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLConnectorGetRequest mlConnectorGetRequest = getRequest(request); + return channel -> client.execute(MLConnectorGetAction.INSTANCE, mlConnectorGetRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLConnectorGetRequest from a RestRequest + * + * @param request RestRequest + * @return MLConnectorGetRequest + */ + @VisibleForTesting + MLConnectorGetRequest getRequest(RestRequest request) throws IOException { + String connectorId = getParameterId(request, PARAMETER_CONNECTOR_ID); + boolean returnContent = returnContent(request); + + return new MLConnectorGetRequest(connectorId, returnContent); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchConnectorAction.java new file mode 100644 index 0000000000..517882a7a3 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchConnectorAction.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; + +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; + +import com.google.common.collect.ImmutableList; + +public class RestMLSearchConnectorAction extends AbstractMLSearchAction { + private static final String ML_SEARCH_CONNECTOR_ACTION = "ml_search_connector_action"; + private static final String SEARCH_CONNECTOR_PATH = ML_BASE_URI + "/connectors/_search"; + + public RestMLSearchConnectorAction() { + super(ImmutableList.of(SEARCH_CONNECTOR_PATH), ML_CONNECTOR_INDEX, Connector.class, MLConnectorSearchAction.INSTANCE); + } + + @Override + public String getName() { + return ML_SEARCH_CONNECTOR_ACTION; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index bdbb85f9ed..156057bc58 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -43,6 +43,7 @@ public class RestActionUtils { public static final String PARAMETER_MODEL_GROUP_NAME = "model_group_name"; public static final String PARAMETER_MODEL_ID = "model_id"; public static final String PARAMETER_TASK_ID = "task_id"; + public static final String PARAMETER_CONNECTOR_ID = "connector_id"; public static final String PARAMETER_DEPLOY_MODEL = "deploy"; public static final String PARAMETER_VERSION = "version"; public static final String PARAMETER_MODEL_GROUP_ID = "model_group_id"; diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java new file mode 100644 index 0000000000..249b46b4e3 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java @@ -0,0 +1,177 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.connector; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.get.GetResult; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; +import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class GetConnectorTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + GetConnectorTransportAction getConnectorTransportAction; + MLConnectorGetRequest mlConnectorGetRequest; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId("connector_id").build(); + Settings settings = Settings.builder().build(); + + getConnectorTransportAction = spy( + new GetConnectorTransportAction(transportService, actionFilters, client, xContentRegistry, connectorAccessControlHelper) + ); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); + + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + public void testGetConnector_UserHasNodeAccess() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); + + GetResponse getResponse = prepareConnector(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage()); + } + + public void testGetConnector_ValidateAccessFailed() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new Exception("Failed to validate access")); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any()); + + GetResponse getResponse = prepareConnector(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage()); + } + + public void testGetConnector_NullResponse() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(), any()); + getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find connector with the provided connector id: connector_id", argumentCaptor.getValue().getMessage()); + } + + public void testGetConnector_IndexNotFoundException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IndexNotFoundException("Fail to find model")); + return null; + }).when(client).get(any(), any()); + getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Fail to find connector", argumentCaptor.getValue().getMessage()); + } + + public void testGetConnector_RuntimeException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("errorMessage")); + return null; + }).when(client).get(any(), any()); + getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + } + + public GetResponse prepareConnector() throws IOException { + HttpConnector httpConnector = HttpConnector.builder().name("test_connector").protocol("http").build(); + + XContentBuilder content = httpConnector.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + GetResponse getResponse = new GetResponse(getResult); + return getResponse; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateConnectorActionTests.java new file mode 100644 index 0000000000..650bb4a90e --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateConnectorActionTests.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.utils.TestHelper.getCreateConnectorRestRequest; +import static org.opensearch.ml.utils.TestHelper.verifyParsedCreateConnectorInput; + +import java.io.IOException; +import java.util.List; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.action.ActionListener; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLCreateConnectorActionTests extends OpenSearchTestCase { + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private RestMLCreateConnectorAction restMLCreateConnectorAction; + + NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + restMLCreateConnectorAction = new RestMLCreateConnectorAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLCreateConnectorAction.INSTANCE), any(), any()); + + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLCreateConnectorAction mlCreateConnectorAction = new RestMLCreateConnectorAction(); + assertNotNull(mlCreateConnectorAction); + } + + public void testGetName() { + String actionName = restMLCreateConnectorAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_create_connector_action", actionName); + } + + public void testRoutes() { + List routes = restMLCreateConnectorAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.POST, route.getMethod()); + assertEquals("/_plugins/_ml/connectors/_create", route.getPath()); + } + + public void testGetRequest() throws IOException { + RestRequest request = getCreateConnectorRestRequest(); + MLCreateConnectorRequest mlCreateConnectorRequest = restMLCreateConnectorAction.getRequest(request); + + MLCreateConnectorInput mlCreateConnectorInput = mlCreateConnectorRequest.getMlCreateConnectorInput(); + verifyParsedCreateConnectorInput(mlCreateConnectorInput); + } + + public void testPrepareRequest() throws Exception { + RestRequest request = getCreateConnectorRestRequest(); + restMLCreateConnectorAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLCreateConnectorRequest.class); + verify(client, times(1)).execute(eq(MLCreateConnectorAction.INSTANCE), argumentCaptor.capture(), any()); + MLCreateConnectorInput mlCreateConnectorInput = argumentCaptor.getValue().getMlCreateConnectorInput(); + verifyParsedCreateConnectorInput(mlCreateConnectorInput); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteConnectorActionTests.java new file mode 100644 index 0000000000..e32b0cef66 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteConnectorActionTests.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.action.ActionListener; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLDeleteConnectorActionTests extends OpenSearchTestCase { + + private RestMLDeleteConnectorAction restMLDeleteConnectorAction; + + NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLConnectorDeleteAction.INSTANCE), any(), any()); + + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLDeleteConnectorAction mlDeleteConnectorAction = new RestMLDeleteConnectorAction(); + assertNotNull(mlDeleteConnectorAction); + } + + public void testGetName() { + String actionName = restMLDeleteConnectorAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_delete_connector_action", actionName); + } + + public void testRoutes() { + List routes = restMLDeleteConnectorAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.DELETE, route.getMethod()); + assertEquals("/_plugins/_ml/connectors/{connector_id}", route.getPath()); + } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLDeleteConnectorAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLConnectorDeleteRequest.class); + verify(client, times(1)).execute(eq(MLConnectorDeleteAction.INSTANCE), argumentCaptor.capture(), any()); + String connectorId = argumentCaptor.getValue().getConnectorId(); + assertEquals(connectorId, "connector_id"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_CONNECTOR_ID, "connector_id"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + return request; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConnectorActionTests.java new file mode 100644 index 0000000000..1b27ec718b --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetConnectorActionTests.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.action.ActionListener; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; +import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; +import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLGetConnectorActionTests extends OpenSearchTestCase { + @Rule + public ExpectedException thrown = ExpectedException.none(); + + private RestMLGetConnectorAction restMLGetConnectorAction; + + NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + @Before + public void setup() { + restMLGetConnectorAction = new RestMLGetConnectorAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLConnectorGetAction.INSTANCE), any(), any()); + + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLGetConnectorAction mlGetConnectorAction = new RestMLGetConnectorAction(); + assertNotNull(mlGetConnectorAction); + } + + public void testGetName() { + String actionName = restMLGetConnectorAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_get_connector_action", actionName); + } + + public void testRoutes() { + List routes = restMLGetConnectorAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.GET, route.getMethod()); + assertEquals("/_plugins/_ml/connectors/{connector_id}", route.getPath()); + } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLGetConnectorAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLConnectorGetRequest.class); + verify(client, times(1)).execute(eq(MLConnectorGetAction.INSTANCE), argumentCaptor.capture(), any()); + String taskId = argumentCaptor.getValue().getConnectorId(); + assertEquals(taskId, "connector_id"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_CONNECTOR_ID, "connector_id"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + return request; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java new file mode 100644 index 0000000000..74f1349f7c --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchConnectorActionTests.java @@ -0,0 +1,194 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.utils.TestHelper.getSearchAllRestRequest; + +import java.io.IOException; +import java.util.List; + +import org.apache.lucene.search.TotalHits; +import org.hamcrest.Matchers; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; +import org.opensearch.rest.RestStatus; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class RestMLSearchConnectorActionTests extends OpenSearchTestCase { + + private RestMLSearchConnectorAction restMLSearchConnectorAction; + + NodeClient client; + private ThreadPool threadPool; + @Mock + RestChannel channel; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + restMLSearchConnectorAction = new RestMLSearchConnectorAction(); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + + doReturn(builder).when(channel).newBuilder(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + + String connectorContent = "{\"name\":\"test_connector\",\"protocol\":\"http\",\"version\":1}"; + SearchHit connector = SearchHit.fromXContent(TestHelper.parser(connectorContent)); + SearchHits hits = new SearchHits(new SearchHit[] { connector }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections( + hits, + InternalAggregations.EMPTY, + null, + false, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + actionListener.onResponse(searchResponse); + return null; + }).when(client).execute(eq(MLConnectorSearchAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLSearchConnectorAction mlSearchConnectorAction = new RestMLSearchConnectorAction(); + assertNotNull(mlSearchConnectorAction); + } + + public void testGetName() { + String actionName = restMLSearchConnectorAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_search_connector_action", actionName); + } + + public void testRoutes() { + List routes = restMLSearchConnectorAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route postRoute = routes.get(0); + assertEquals(RestRequest.Method.POST, postRoute.getMethod()); + assertThat(postRoute.getMethod(), Matchers.either(Matchers.is(RestRequest.Method.POST)).or(Matchers.is(RestRequest.Method.GET))); + assertEquals("/_plugins/_ml/connectors/_search", postRoute.getPath()); + } + + public void testPrepareRequest() throws Exception { + RestRequest request = getSearchAllRestRequest(); + restMLSearchConnectorAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(client, times(1)).execute(eq(MLConnectorSearchAction.INSTANCE), argumentCaptor.capture(), any()); + verify(channel, times(1)).sendResponse(responseCaptor.capture()); + SearchRequest searchRequest = argumentCaptor.getValue(); + String[] indices = searchRequest.indices(); + assertArrayEquals(new String[] { ML_CONNECTOR_INDEX }, indices); + System.out.println(searchRequest); + assertEquals( + "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}", + searchRequest.source().toString() + ); + RestResponse restResponse = responseCaptor.getValue(); + assertNotEquals(RestStatus.REQUEST_TIMEOUT, restResponse.status()); + } + + public void testPrepareRequest_timeout() throws Exception { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + + SearchHits hits = new SearchHits(new SearchHit[0], null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections( + hits, + InternalAggregations.EMPTY, + null, + true, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + actionListener.onResponse(searchResponse); + return null; + }).when(client).execute(eq(MLConnectorSearchAction.INSTANCE), any(), any()); + + RestRequest request = getSearchAllRestRequest(); + restMLSearchConnectorAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(RestResponse.class); + verify(client, times(1)).execute(eq(MLConnectorSearchAction.INSTANCE), argumentCaptor.capture(), any()); + verify(channel, times(1)).sendResponse(responseCaptor.capture()); + SearchRequest searchRequest = argumentCaptor.getValue(); + String[] indices = searchRequest.indices(); + assertArrayEquals(new String[] { ML_CONNECTOR_INDEX }, indices); + assertEquals( + "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}", + searchRequest.source().toString() + ); + RestResponse restResponse = responseCaptor.getValue(); + assertEquals(RestStatus.REQUEST_TIMEOUT, restResponse.status()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index a1b2297b4b..2e8cbb93a6 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -7,6 +7,7 @@ import static org.apache.http.entity.ContentType.APPLICATION_JSON; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; import static org.opensearch.cluster.node.DiscoveryNodeRole.DATA_ROLE; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; @@ -70,6 +71,7 @@ import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput; import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput; import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.profile.MLProfileInput; import org.opensearch.ml.stats.MLStatsInput; import org.opensearch.rest.RestRequest; @@ -203,6 +205,49 @@ public static RestRequest getKMeansRestRequest() { return request; } + public static RestRequest getCreateConnectorRestRequest() { + final String requestContent = "{\n" + + " \"name\": \"OpenAI Connector\",\n" + + " \"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": 7,\n" + + " \"temperature\": 0,\n" + + " \"model\": \"text-davinci-003\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"xxxxxxxx\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n" + + " \"headers\": {\n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n" + + " }\n" + + " ],\n" + + " \"access_mode\": \"public\"\n" + + "}"; + RestRequest request = new FakeRestRequest.Builder(getXContentRegistry()) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + public static void verifyParsedCreateConnectorInput(MLCreateConnectorInput mlCreateConnectorInput) { + assertEquals("OpenAI Connector", mlCreateConnectorInput.getName()); + assertEquals("http", mlCreateConnectorInput.getProtocol()); + assertNotNull(mlCreateConnectorInput.getActions()); + assertNotNull(mlCreateConnectorInput.getCredential()); + } + public static RestRequest getStatsRestRequest(MLStatsInput input) throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder(); input.toXContent(builder, ToXContent.EMPTY_PARAMS);