Skip to content

Commit

Permalink
connector transport actions and disable native memory CB
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Jul 10, 2023
1 parent 11b72bc commit 2d4ab84
Show file tree
Hide file tree
Showing 10 changed files with 551 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.connector;

import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class DeleteConnectorTransportAction extends HandledTransportAction<ActionRequest, DeleteResponse> {

Client client;
NamedXContentRegistry xContentRegistry;

ConnectorAccessControlHelper connectorAccessControlHelper;

@Inject
public DeleteConnectorTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
NamedXContentRegistry xContentRegistry,
ConnectorAccessControlHelper connectorAccessControlHelper
) {
super(MLConnectorDeleteAction.NAME, transportService, actionFilters, MLConnectorDeleteRequest::new);
this.client = client;
this.xContentRegistry = xContentRegistry;
this.connectorAccessControlHelper = connectorAccessControlHelper;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<DeleteResponse> actionListener) {
MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.fromActionRequest(request);
String connectorId = mlConnectorDeleteRequest.getConnectorId();
DeleteRequest deleteRequest = new DeleteRequest(ML_CONNECTOR_INDEX, connectorId);
connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(x -> {
if (Boolean.TRUE.equals(x)) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX);
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.query(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId));
searchRequest.source(sourceBuilder);
client.search(searchRequest, ActionListener.wrap(searchResponse -> {
SearchHit[] searchHits = searchResponse.getHits().getHits();
if (searchHits.length == 0) {
deleteConnector(deleteRequest, connectorId, actionListener);
} else {
log
.error(
searchHits.length + " models are still using this connector, please delete or update the models first!"
);
actionListener
.onFailure(
new MLValidationException(
searchHits.length
+ " models are still using this connector, please delete or update the models first!"
)
);
}
}, e -> {
log.error("Failed to delete ML connector: " + connectorId, e);
actionListener.onFailure(e);
}));
} catch (Exception e) {
log.error(e.getMessage(), e);
actionListener.onFailure(e);
}
} else {
actionListener.onFailure(new MLValidationException("You are not allowed to delete this connector"));
}
}, e -> {
log.error("Failed to delete ML connector: " + connectorId, e);
actionListener.onFailure(e);
}));
}

private void deleteConnector(DeleteRequest deleteRequest, String connectorId, ActionListener<DeleteResponse> actionListener) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.delete(deleteRequest, new ActionListener<>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) {
log.info("Connector id:{} not found", connectorId);
actionListener.onResponse(deleteResponse);
return;
}
log.info("Completed Delete Connector Request, connector id:{} deleted", connectorId);
actionListener.onResponse(deleteResponse);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML connector: " + connectorId, e);
actionListener.onFailure(e);
}
});
} catch (Exception e) {
log.error("Failed to delete ML connector: " + connectorId, e);
actionListener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.connector;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.connector.MLConnectorGetAction;
import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest;
import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.AccessLevel;
import lombok.experimental.FieldDefaults;
import lombok.extern.log4j.Log4j2;

@Log4j2
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class GetConnectorTransportAction extends HandledTransportAction<ActionRequest, MLConnectorGetResponse> {

Client client;
NamedXContentRegistry xContentRegistry;

ConnectorAccessControlHelper connectorAccessControlHelper;

@Inject
public GetConnectorTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
NamedXContentRegistry xContentRegistry,
ConnectorAccessControlHelper connectorAccessControlHelper
) {
super(MLConnectorGetAction.NAME, transportService, actionFilters, MLConnectorGetRequest::new);
this.client = client;
this.xContentRegistry = xContentRegistry;
this.connectorAccessControlHelper = connectorAccessControlHelper;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLConnectorGetResponse> actionListener) {
MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.fromActionRequest(request);
String connectorId = mlConnectorGetRequest.getConnectorId();
FetchSourceContext fetchSourceContext = getFetchSourceContext(mlConnectorGetRequest.isReturnContent());
GetRequest getRequest = new GetRequest(ML_CONNECTOR_INDEX).id(connectorId).fetchSourceContext(fetchSourceContext);
User user = RestActionUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
log.debug("Completed Get Connector Request, id:{}", connectorId);

if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Connector mlConnector = Connector.createConnector(parser);
mlConnector.removeCredential();
if (connectorAccessControlHelper.hasPermission(user, mlConnector)) {
actionListener.onResponse(MLConnectorGetResponse.builder().mlConnector(mlConnector).build());
} else {
actionListener.onFailure(new MLValidationException("You don't have permission to access this connector"));
}
} catch (Exception e) {
log.error("Failed to parse ml connector" + r.getId(), e);
actionListener.onFailure(e);
}
} else {
actionListener
.onFailure(new IllegalArgumentException("Failed to find connector with the provided connector id: " + connectorId));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find connector"));
} else {
log.error("Failed to get ML connector " + connectorId, e);
actionListener.onFailure(e);
}
}), context::restore));
} catch (Exception e) {
log.error("Failed to get ML connector " + connectorId, e);
actionListener.onFailure(e);
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.connector;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class SearchConnectorTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {

private final Client client;

private final ConnectorAccessControlHelper connectorAccessControlHelper;

@Inject
public SearchConnectorTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
ConnectorAccessControlHelper connectorAccessControlHelper
) {
super(MLConnectorSearchAction.NAME, transportService, actionFilters, SearchRequest::new);
this.client = client;
this.connectorAccessControlHelper = connectorAccessControlHelper;
}

@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
request.indices(CommonValue.ML_CONNECTOR_INDEX);
search(request, actionListener);
}

private void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
User user = RestActionUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
FetchSourceContext fetchSourceContext = request.source().fetchSource();
List<String> excludes = Arrays.stream(fetchSourceContext.excludes()).collect(Collectors.toList());
excludes.add(HttpConnector.CREDENTIAL_FIELD);
FetchSourceContext rebuiltFetchSourceContext = new FetchSourceContext(
fetchSourceContext.fetchSource(),
fetchSourceContext.includes(),
excludes.toArray(new String[0])
);
request.source().fetchSource(rebuiltFetchSourceContext);
if (connectorAccessControlHelper.skipConnectorAccessControl(user)) {
client.search(request, actionListener);
} else {
SearchSourceBuilder sourceBuilder = connectorAccessControlHelper.addUserBackendRolesFilter(user, request.source());
request.source(sourceBuilder);
client.search(request, actionListener);
}
} catch (Exception e) {
log.error(e.getMessage(), e);
actionListener.onFailure(e);
}
}
}
Loading

0 comments on commit 2d4ab84

Please sign in to comment.