Skip to content

Commit

Permalink
[Backport 2.x] Added Create Connector Step (#113)
Browse files Browse the repository at this point in the history
Added Create Connector Step (#107)

* Added initial implementation of create connector



* Added test for create connector



* Added more tests and updated MLClient initialization



* Addressed PR comments



* CompletedFuture exceptionally if fields are not present



---------


(cherry picked from commit 23b2f15)

Signed-off-by: Owais Kazi <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
1 parent 5762315 commit 2ff9ed6
Show file tree
Hide file tree
Showing 13 changed files with 399 additions and 97 deletions.
4 changes: 3 additions & 1 deletion src/main/java/demo/Demo.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
Expand Down Expand Up @@ -59,7 +60,8 @@ public static void main(String[] args) throws IOException {
}
ClusterService clusterService = new ClusterService(null, null, null);
Client client = new NodeClient(null, null);
WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client);
MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient);

ThreadPool threadPool = new ThreadPool(Settings.EMPTY);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/demo/TemplateParseDemo.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
Expand Down Expand Up @@ -55,7 +56,8 @@ public static void main(String[] args) throws IOException {
}
ClusterService clusterService = new ClusterService(null, null, null);
Client client = new NodeClient(null, null);
WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client);
MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient);
ThreadPool threadPool = new ThreadPool(Settings.EMPTY);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.flowframework.workflow.CreateIndexStep;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.repositories.RepositoriesService;
Expand Down Expand Up @@ -76,7 +77,8 @@ public Collection<Object> createComponents(
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<RepositoriesService> repositoriesServiceSupplier
) {
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client);
MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);

// TODO : Refactor, move system index creation/associated methods outside of the CreateIndexStep
Expand Down
34 changes: 0 additions & 34 deletions src/main/java/org/opensearch/flowframework/client/MLClient.java

This file was deleted.

14 changes: 12 additions & 2 deletions src/main/java/org/opensearch/flowframework/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ private CommonValue() {}
public static final String MODEL_ID = "model_id";
/** Function Name field */
public static final String FUNCTION_NAME = "function_name";
/** Model Name field */
public static final String MODEL_NAME = "name";
/** Name field */
public static final String NAME_FIELD = "name";
/** Model Version field */
public static final String MODEL_VERSION = "model_version";
/** Model Group Id field */
Expand All @@ -62,4 +62,14 @@ private CommonValue() {}
public static final String MODEL_FORMAT = "model_format";
/** Model config field */
public static final String MODEL_CONFIG = "model_config";
/** Version field */
public static final String VERSION_FIELD = "version";
/** Connector protocol field */
public static final String PROTOCOL_FIELD = "protocol";
/** Connector parameters field */
public static final String PARAMETERS_FIELD = "parameters";
/** Connector credentials field */
public static final String CREDENTIALS_FIELD = "credentials";
/** Connector actions field */
public static final String ACTIONS_FIELD = "actions";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.CREDENTIALS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;

/**
* Step to create a connector for a remote model
*/
public class CreateConnectorStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class);

private MachineLearningNodeClient mlClient;

static final String NAME = "create_connector";

/**
* Instantiate this class
* @param mlClient client to instantiate MLClient
*/
public CreateConnectorStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws IOException {
CompletableFuture<WorkflowData> createConnectorFuture = new CompletableFuture<>();

ActionListener<MLCreateConnectorResponse> actionListener = new ActionListener<>() {

@Override
public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) {
logger.info("Created connector successfully");
// TODO Add the response to Global Context
createConnectorFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId())))
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to create connector");
createConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

String name = null;
String description = null;
String version = null;
String protocol = null;
Map<String, String> parameters = new HashMap<>();
Map<String, String> credentials = new HashMap<>();
List<ConnectorAction> actions = null;

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
name = (String) content.get(NAME_FIELD);
break;
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
break;
case VERSION_FIELD:
version = (String) content.get(VERSION_FIELD);
break;
case PROTOCOL_FIELD:
protocol = (String) content.get(PROTOCOL_FIELD);
break;
case PARAMETERS_FIELD:
parameters = getParameterMap((Map<String, String>) content.get(PARAMETERS_FIELD));
break;
case CREDENTIALS_FIELD:
credentials = (Map<String, String>) content.get(CREDENTIALS_FIELD);
break;
case ACTIONS_FIELD:
actions = (List<ConnectorAction>) content.get(ACTIONS_FIELD);
break;
}

}
}

if (Stream.of(name, description, version, protocol, parameters, credentials, actions).allMatch(x -> x != null)) {
MLCreateConnectorInput mlInput = MLCreateConnectorInput.builder()
.name(name)
.description(description)
.version(version)
.protocol(protocol)
.parameters(parameters)
.credential(credentials)
.actions(actions)
.build();

mlClient.createConnector(mlInput, actionListener);
} else {
createConnectorFuture.completeExceptionally(
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)
);
}

return createConnectorFuture;
}

@Override
public String getName() {
return NAME;
}

private static Map<String, String> getParameterMap(Map<String, String> params) {

Map<String, String> parameters = new HashMap<>();
for (String key : params.keySet()) {
String value = params.get(key);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
parameters.put(key, value);
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
}
return parameters;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.client.MLClient;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;

Expand All @@ -28,24 +29,22 @@
public class DeployModelStep implements WorkflowStep {
private static final Logger logger = LogManager.getLogger(DeployModelStep.class);

private Client client;
private MachineLearningNodeClient mlClient;
static final String NAME = "deploy_model";

/**
* Instantiate this class
* @param client client to instantiate MLClient
* @param mlClient client to instantiate MLClient
*/
public DeployModelStep(Client client) {
this.client = client;
public DeployModelStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> deployModelFuture = new CompletableFuture<>();

MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client);

ActionListener<MLDeployModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
Expand All @@ -57,8 +56,8 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) {

@Override
public void onFailure(Exception e) {
logger.error("Model deployment failed");
deployModelFuture.completeExceptionally(e);
logger.error("Failed to deploy model");
deployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

Expand All @@ -70,7 +69,13 @@ public void onFailure(Exception e) {
break;
}
}
machineLearningNodeClient.deploy(modelId, actionListener);

if (modelId != null) {
mlClient.deploy(modelId, actionListener);
} else {
deployModelFuture.completeExceptionally(new FlowFrameworkException("Model ID is not provided", RestStatus.BAD_REQUEST));
}

return deployModelFuture;
}

Expand Down
Loading

0 comments on commit 2ff9ed6

Please sign in to comment.