-
Notifications
You must be signed in to change notification settings - Fork 143
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Xun Zhang <[email protected]>
- Loading branch information
1 parent
47bf678
commit 9bec7c9
Showing
6 changed files
with
300 additions
and
2 deletions.
There are no files selected for viewing
16 changes: 16 additions & 0 deletions
16
...n/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.connector; | ||
|
||
import org.opensearch.action.ActionType; | ||
import org.opensearch.action.update.UpdateResponse; | ||
|
||
public class MLUpdateConnectorAction extends ActionType<UpdateResponse> { | ||
public static final MLUpdateConnectorAction INSTANCE = new MLUpdateConnectorAction(); | ||
public static final String NAME = "cluster:admin/opensearch/ml/connectors/update"; | ||
|
||
private MLUpdateConnectorAction() { super(NAME, UpdateResponse::new);} | ||
} |
83 changes: 83 additions & 0 deletions
83
.../src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.transport.connector; | ||
|
||
import lombok.Builder; | ||
import lombok.Getter; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.core.common.io.stream.InputStreamStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
|
||
import java.io.ByteArrayInputStream; | ||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
import java.util.Map; | ||
|
||
import static org.opensearch.action.ValidateActions.addValidationError; | ||
|
||
@Getter | ||
public class MLUpdateConnectorRequest extends ActionRequest { | ||
String connectorId; | ||
Map<String, Object> updateContent; | ||
|
||
@Builder | ||
public MLUpdateConnectorRequest(String connectorId, Map<String, Object> updateContent) { | ||
this.connectorId = connectorId; | ||
this.updateContent = updateContent; | ||
} | ||
|
||
public MLUpdateConnectorRequest(StreamInput in) throws IOException { | ||
super(in); | ||
this.connectorId = in.readString(); | ||
this.updateContent = in.readMap(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(this.connectorId); | ||
out.writeMap(this.getUpdateContent()); | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
ActionRequestValidationException exception = null; | ||
|
||
if (this.connectorId == null) { | ||
exception = addValidationError("ML connector id can't be null", exception); | ||
} | ||
|
||
return exception; | ||
} | ||
|
||
public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId) throws IOException { | ||
Map<String, Object> dataAsMap = null; | ||
dataAsMap = parser.map(); | ||
|
||
return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(dataAsMap).build(); | ||
} | ||
|
||
public static MLUpdateConnectorRequest fromActionRequest(ActionRequest actionRequest) { | ||
if (actionRequest instanceof MLUpdateConnectorRequest) { | ||
return (MLUpdateConnectorRequest) actionRequest; | ||
} | ||
|
||
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); | ||
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { | ||
actionRequest.writeTo(osso); | ||
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { | ||
return new MLUpdateConnectorRequest(input); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException("failed to parse ActionRequest into MLUpdateConnectorRequest", e); | ||
} | ||
} | ||
} |
99 changes: 99 additions & 0 deletions
99
plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.action.connector; | ||
|
||
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; | ||
|
||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.DocWriteResponse; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.action.update.UpdateRequest; | ||
import org.opensearch.action.update.UpdateResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; | ||
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; | ||
import org.opensearch.ml.helper.ConnectorAccessControlHelper; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.transport.TransportService; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.experimental.FieldDefaults; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) | ||
public class UpdateConnectorTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> { | ||
Client client; | ||
NamedXContentRegistry xContentRegistry; | ||
|
||
ConnectorAccessControlHelper connectorAccessControlHelper; | ||
|
||
@Inject | ||
public UpdateConnectorTransportAction( | ||
TransportService transportService, | ||
ActionFilters actionFilters, | ||
Client client, | ||
NamedXContentRegistry xContentRegistry, | ||
ConnectorAccessControlHelper connectorAccessControlHelper | ||
) { | ||
super(MLUpdateConnectorAction.NAME, transportService, actionFilters, MLUpdateConnectorRequest::new); | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
this.connectorAccessControlHelper = connectorAccessControlHelper; | ||
} | ||
|
||
@Override | ||
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) { | ||
MLUpdateConnectorRequest mlUpdateConnectorAction = MLUpdateConnectorRequest.fromActionRequest(request); | ||
String connectorId = mlUpdateConnectorAction.getConnectorId(); | ||
UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId); | ||
updateRequest.doc(mlUpdateConnectorAction.getUpdateContent()); | ||
updateRequest.docAsUpsert(true); | ||
|
||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { | ||
connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(hasPermission -> { | ||
if (Boolean.TRUE.equals(hasPermission)) { | ||
client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); | ||
} else { | ||
listener | ||
.onFailure( | ||
new IllegalArgumentException("You don't have permission to update the connector, connector id: " + connectorId) | ||
); | ||
} | ||
}, exception -> { | ||
log.error("You don't have permission to update the connector for connector id: " + connectorId, exception); | ||
listener.onFailure(exception); | ||
})); | ||
} catch (Exception e) { | ||
log.error("Failed to update ML connector " + connectorId, e); | ||
listener.onFailure(e); | ||
} | ||
} | ||
|
||
private ActionListener<UpdateResponse> getUpdateResponseListener( | ||
String connectorId, | ||
ActionListener<UpdateResponse> actionListener, | ||
ThreadContext.StoredContext context | ||
) { | ||
return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { | ||
if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { | ||
log.info("Connector id:{} failed update", connectorId); | ||
actionListener.onResponse(updateResponse); | ||
return; | ||
} | ||
log.info("Completed Update Connector Request, connector id:{} updated", connectorId); | ||
actionListener.onResponse(updateResponse); | ||
}, exception -> { | ||
log.error("Failed to update ML connector: " + connectorId, exception); | ||
actionListener.onFailure(exception); | ||
}), context::restore); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
82 changes: 82 additions & 0 deletions
82
plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.rest; | ||
|
||
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; | ||
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; | ||
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID; | ||
import static org.opensearch.ml.utils.RestActionUtils.getParameterId; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
import java.util.Locale; | ||
|
||
import org.apache.logging.log4j.util.Strings; | ||
import org.opensearch.client.node.NodeClient; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; | ||
import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; | ||
import org.opensearch.ml.settings.MLFeatureEnabledSetting; | ||
import org.opensearch.rest.BaseRestHandler; | ||
import org.opensearch.rest.RestRequest; | ||
import org.opensearch.rest.action.RestToXContentListener; | ||
|
||
import com.google.common.annotations.VisibleForTesting; | ||
import com.google.common.collect.ImmutableList; | ||
|
||
public class RestMLUpdateConnectorAction extends BaseRestHandler { | ||
private static final String ML_UPDATE_CONNECTOR_ACTION = "ml_update_connector_action"; | ||
private MLFeatureEnabledSetting mlFeatureEnabledSetting; | ||
|
||
public RestMLUpdateConnectorAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { | ||
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return ML_UPDATE_CONNECTOR_ACTION; | ||
} | ||
|
||
@Override | ||
public List<Route> routes() { | ||
return ImmutableList | ||
.of( | ||
new Route( | ||
RestRequest.Method.POST, | ||
String.format(Locale.ROOT, "%s/connectors/_update/{%s}", ML_BASE_URI, PARAMETER_CONNECTOR_ID) | ||
) | ||
); | ||
} | ||
|
||
@Override | ||
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { | ||
MLUpdateConnectorRequest mlUpdateConnectorRequest = getRequest(request); | ||
return restChannel -> client | ||
.execute(MLUpdateConnectorAction.INSTANCE, mlUpdateConnectorRequest, new RestToXContentListener<>(restChannel)); | ||
} | ||
|
||
@VisibleForTesting | ||
private MLUpdateConnectorRequest getRequest(RestRequest request) throws IOException { | ||
if (!mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { | ||
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); | ||
} | ||
|
||
if (!request.hasContent()) { | ||
throw new IOException("Update Connector request has empty body"); | ||
} | ||
|
||
String connectorId = getParameterId(request, PARAMETER_CONNECTOR_ID); | ||
if (Strings.isBlank(connectorId)) { | ||
throw new IOException("Update Connector request has no connector Id"); | ||
} | ||
|
||
XContentParser parser = request.contentParser(); | ||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); | ||
|
||
return MLUpdateConnectorRequest.parse(parser, connectorId); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters