Skip to content

Commit

Permalink
update connector API
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Aug 21, 2023
1 parent 47bf678 commit 9bec7c9
Show file tree
Hide file tree
Showing 6 changed files with 300 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.opensearch.action.ActionType;
import org.opensearch.action.update.UpdateResponse;

public class MLUpdateConnectorAction extends ActionType<UpdateResponse> {
public static final MLUpdateConnectorAction INSTANCE = new MLUpdateConnectorAction();
public static final String NAME = "cluster:admin/opensearch/ml/connectors/update";

private MLUpdateConnectorAction() { super(NAME, UpdateResponse::new);}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import lombok.Builder;
import lombok.Getter;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
public class MLUpdateConnectorRequest extends ActionRequest {
String connectorId;
Map<String, Object> updateContent;

@Builder
public MLUpdateConnectorRequest(String connectorId, Map<String, Object> updateContent) {
this.connectorId = connectorId;
this.updateContent = updateContent;
}

public MLUpdateConnectorRequest(StreamInput in) throws IOException {
super(in);
this.connectorId = in.readString();
this.updateContent = in.readMap();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.connectorId);
out.writeMap(this.getUpdateContent());
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;

if (this.connectorId == null) {
exception = addValidationError("ML connector id can't be null", exception);
}

return exception;
}

public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId) throws IOException {
Map<String, Object> dataAsMap = null;
dataAsMap = parser.map();

return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(dataAsMap).build();
}

public static MLUpdateConnectorRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLUpdateConnectorRequest) {
return (MLUpdateConnectorRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLUpdateConnectorRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLUpdateConnectorRequest", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.connector;

import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.AccessLevel;
import lombok.experimental.FieldDefaults;
import lombok.extern.log4j.Log4j2;

@Log4j2
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class UpdateConnectorTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
Client client;
NamedXContentRegistry xContentRegistry;

ConnectorAccessControlHelper connectorAccessControlHelper;

@Inject
public UpdateConnectorTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
NamedXContentRegistry xContentRegistry,
ConnectorAccessControlHelper connectorAccessControlHelper
) {
super(MLUpdateConnectorAction.NAME, transportService, actionFilters, MLUpdateConnectorRequest::new);
this.client = client;
this.xContentRegistry = xContentRegistry;
this.connectorAccessControlHelper = connectorAccessControlHelper;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
MLUpdateConnectorRequest mlUpdateConnectorAction = MLUpdateConnectorRequest.fromActionRequest(request);
String connectorId = mlUpdateConnectorAction.getConnectorId();
UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId);
updateRequest.doc(mlUpdateConnectorAction.getUpdateContent());
updateRequest.docAsUpsert(true);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(hasPermission -> {
if (Boolean.TRUE.equals(hasPermission)) {
client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context));
} else {
listener
.onFailure(
new IllegalArgumentException("You don't have permission to update the connector, connector id: " + connectorId)
);
}
}, exception -> {
log.error("You don't have permission to update the connector for connector id: " + connectorId, exception);
listener.onFailure(exception);
}));
} catch (Exception e) {
log.error("Failed to update ML connector " + connectorId, e);
listener.onFailure(e);
}
}

private ActionListener<UpdateResponse> getUpdateResponseListener(
String connectorId,
ActionListener<UpdateResponse> actionListener,
ThreadContext.StoredContext context
) {
return ActionListener.runBefore(ActionListener.wrap(updateResponse -> {
if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) {
log.info("Connector id:{} failed update", connectorId);
actionListener.onResponse(updateResponse);
return;
}
log.info("Completed Update Connector Request, connector id:{} updated", connectorId);
actionListener.onResponse(updateResponse);
}, exception -> {
log.error("Failed to update ML connector: " + connectorId, exception);
actionListener.onFailure(exception);
}), context::restore);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.ml.action.connector.GetConnectorTransportAction;
import org.opensearch.ml.action.connector.SearchConnectorTransportAction;
import org.opensearch.ml.action.connector.TransportCreateConnectorAction;
import org.opensearch.ml.action.connector.UpdateConnectorTransportAction;
import org.opensearch.ml.action.deploy.TransportDeployModelAction;
import org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction;
import org.opensearch.ml.action.execute.TransportExecuteTaskAction;
Expand Down Expand Up @@ -89,6 +90,7 @@
import org.opensearch.ml.common.transport.connector.MLConnectorGetAction;
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelOnNodeAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
Expand Down Expand Up @@ -149,6 +151,7 @@
import org.opensearch.ml.rest.RestMLTrainAndPredictAction;
import org.opensearch.ml.rest.RestMLTrainingAction;
import org.opensearch.ml.rest.RestMLUndeployModelAction;
import org.opensearch.ml.rest.RestMLUpdateConnectorAction;
import org.opensearch.ml.rest.RestMLUpdateModelGroupAction;
import org.opensearch.ml.rest.RestMLUploadModelChunkAction;
import org.opensearch.ml.settings.MLCommonsSettings;
Expand Down Expand Up @@ -256,7 +259,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin {
new ActionHandler<>(MLCreateConnectorAction.INSTANCE, TransportCreateConnectorAction.class),
new ActionHandler<>(MLConnectorGetAction.INSTANCE, GetConnectorTransportAction.class),
new ActionHandler<>(MLConnectorDeleteAction.INSTANCE, DeleteConnectorTransportAction.class),
new ActionHandler<>(MLConnectorSearchAction.INSTANCE, SearchConnectorTransportAction.class)
new ActionHandler<>(MLConnectorSearchAction.INSTANCE, SearchConnectorTransportAction.class),
new ActionHandler<>(MLUpdateConnectorAction.INSTANCE, UpdateConnectorTransportAction.class)
);
}

Expand Down Expand Up @@ -493,6 +497,7 @@ public List<RestHandler> getRestHandlers(
RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction();
RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction();
RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction();
RestMLUpdateConnectorAction restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting);
return ImmutableList
.of(
restMLStatsAction,
Expand All @@ -519,7 +524,8 @@ public List<RestHandler> getRestHandlers(
restMLCreateConnectorAction,
restMLGetConnectorAction,
restMLDeleteConnectorAction,
restMLSearchConnectorAction
restMLSearchConnectorAction,
restMLUpdateConnectorAction
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID;
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;

import java.io.IOException;
import java.util.List;
import java.util.Locale;

import org.apache.logging.log4j.util.Strings;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;

public class RestMLUpdateConnectorAction extends BaseRestHandler {
private static final String ML_UPDATE_CONNECTOR_ACTION = "ml_update_connector_action";
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

public RestMLUpdateConnectorAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
return ML_UPDATE_CONNECTOR_ACTION;
}

@Override
public List<Route> routes() {
return ImmutableList
.of(
new Route(
RestRequest.Method.POST,
String.format(Locale.ROOT, "%s/connectors/_update/{%s}", ML_BASE_URI, PARAMETER_CONNECTOR_ID)
)
);
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
MLUpdateConnectorRequest mlUpdateConnectorRequest = getRequest(request);
return restChannel -> client
.execute(MLUpdateConnectorAction.INSTANCE, mlUpdateConnectorRequest, new RestToXContentListener<>(restChannel));
}

@VisibleForTesting
private MLUpdateConnectorRequest getRequest(RestRequest request) throws IOException {
if (!mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}

if (!request.hasContent()) {
throw new IOException("Update Connector request has empty body");
}

String connectorId = getParameterId(request, PARAMETER_CONNECTOR_ID);
if (Strings.isBlank(connectorId)) {
throw new IOException("Update Connector request has no connector Id");
}

XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);

return MLUpdateConnectorRequest.parse(parser, connectorId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.util.*;

import org.junit.Before;
import org.junit.Ignore;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
Expand All @@ -34,6 +37,7 @@
import org.opensearch.threadpool.ThreadPool;

import com.google.gson.Gson;
import com.google.gson.JsonParser;

public class RestMLDeployModelActionTests extends OpenSearchTestCase {

Expand Down Expand Up @@ -129,6 +133,14 @@ private RestRequest getRestRequest() {
.withParams(params)
.withContent(new BytesArray(requestContent), XContentType.JSON)
.build();
UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, "12222");
updateRequest.doc(model);

UpdateRequest updateRequest1 = new UpdateRequest(ML_CONNECTOR_INDEX, "12222");
updateRequest.doc(gson.fromJson(JsonParser.parseString(requestContent), Map.class));

System.out.println(updateRequest);
System.out.println(updateRequest1);
return request;
}
}

0 comments on commit 9bec7c9

Please sign in to comment.