-
Notifications
You must be signed in to change notification settings - Fork 138
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
connector transport actions and disable native memory CB
Signed-off-by: Xun Zhang <[email protected]>
- Loading branch information
1 parent
e2c9948
commit 1f824fe
Showing
11 changed files
with
749 additions
and
4 deletions.
There are no files selected for viewing
129 changes: 129 additions & 0 deletions
129
plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.action.connector; | ||
|
||
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; | ||
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; | ||
|
||
import org.opensearch.action.ActionListener; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.DocWriteResponse; | ||
import org.opensearch.action.delete.DeleteRequest; | ||
import org.opensearch.action.delete.DeleteResponse; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.index.query.QueryBuilders; | ||
import org.opensearch.ml.common.MLModel; | ||
import org.opensearch.ml.common.exception.MLValidationException; | ||
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; | ||
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; | ||
import org.opensearch.ml.helper.ConnectorAccessControlHelper; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.transport.TransportService; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
public class DeleteConnectorTransportAction extends HandledTransportAction<ActionRequest, DeleteResponse> { | ||
|
||
Client client; | ||
NamedXContentRegistry xContentRegistry; | ||
|
||
ConnectorAccessControlHelper connectorAccessControlHelper; | ||
|
||
@Inject | ||
public DeleteConnectorTransportAction( | ||
TransportService transportService, | ||
ActionFilters actionFilters, | ||
Client client, | ||
NamedXContentRegistry xContentRegistry, | ||
ConnectorAccessControlHelper connectorAccessControlHelper | ||
) { | ||
super(MLConnectorDeleteAction.NAME, transportService, actionFilters, MLConnectorDeleteRequest::new); | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
this.connectorAccessControlHelper = connectorAccessControlHelper; | ||
} | ||
|
||
@Override | ||
protected void doExecute(Task task, ActionRequest request, ActionListener<DeleteResponse> actionListener) { | ||
MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.fromActionRequest(request); | ||
String connectorId = mlConnectorDeleteRequest.getConnectorId(); | ||
DeleteRequest deleteRequest = new DeleteRequest(ML_CONNECTOR_INDEX, connectorId); | ||
connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(x -> { | ||
if (Boolean.TRUE.equals(x)) { | ||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { | ||
SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX); | ||
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); | ||
sourceBuilder.query(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId)); | ||
searchRequest.source(sourceBuilder); | ||
client.search(searchRequest, ActionListener.wrap(searchResponse -> { | ||
SearchHit[] searchHits = searchResponse.getHits().getHits(); | ||
if (searchHits.length == 0) { | ||
deleteConnector(deleteRequest, connectorId, actionListener); | ||
} else { | ||
log | ||
.error( | ||
searchHits.length + " models are still using this connector, please delete or update the models first!" | ||
); | ||
actionListener | ||
.onFailure( | ||
new MLValidationException( | ||
searchHits.length | ||
+ " models are still using this connector, please delete or update the models first!" | ||
) | ||
); | ||
} | ||
}, e -> { | ||
log.error("Failed to delete ML connector: " + connectorId, e); | ||
actionListener.onFailure(e); | ||
})); | ||
} catch (Exception e) { | ||
log.error(e.getMessage(), e); | ||
actionListener.onFailure(e); | ||
} | ||
} else { | ||
actionListener.onFailure(new MLValidationException("You are not allowed to delete this connector")); | ||
} | ||
}, e -> { | ||
log.error("Failed to delete ML connector: " + connectorId, e); | ||
actionListener.onFailure(e); | ||
})); | ||
} | ||
|
||
private void deleteConnector(DeleteRequest deleteRequest, String connectorId, ActionListener<DeleteResponse> actionListener) { | ||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { | ||
client.delete(deleteRequest, new ActionListener<>() { | ||
@Override | ||
public void onResponse(DeleteResponse deleteResponse) { | ||
if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) { | ||
log.info("Connector id:{} not found", connectorId); | ||
actionListener.onResponse(deleteResponse); | ||
return; | ||
} | ||
log.info("Completed Delete Connector Request, connector id:{} deleted", connectorId); | ||
actionListener.onResponse(deleteResponse); | ||
} | ||
|
||
@Override | ||
public void onFailure(Exception e) { | ||
log.error("Failed to delete ML connector: " + connectorId, e); | ||
actionListener.onFailure(e); | ||
} | ||
}); | ||
} catch (Exception e) { | ||
log.error("Failed to delete ML connector: " + connectorId, e); | ||
actionListener.onFailure(e); | ||
} | ||
} | ||
} |
107 changes: 107 additions & 0 deletions
107
plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.action.connector; | ||
|
||
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; | ||
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; | ||
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; | ||
|
||
import org.opensearch.action.ActionListener; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.get.GetRequest; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.commons.authuser.User; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.index.IndexNotFoundException; | ||
import org.opensearch.ml.common.connector.Connector; | ||
import org.opensearch.ml.common.exception.MLResourceNotFoundException; | ||
import org.opensearch.ml.common.exception.MLValidationException; | ||
import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; | ||
import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; | ||
import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse; | ||
import org.opensearch.ml.helper.ConnectorAccessControlHelper; | ||
import org.opensearch.ml.utils.RestActionUtils; | ||
import org.opensearch.search.fetch.subphase.FetchSourceContext; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.transport.TransportService; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.experimental.FieldDefaults; | ||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) | ||
public class GetConnectorTransportAction extends HandledTransportAction<ActionRequest, MLConnectorGetResponse> { | ||
|
||
Client client; | ||
NamedXContentRegistry xContentRegistry; | ||
|
||
ConnectorAccessControlHelper connectorAccessControlHelper; | ||
|
||
@Inject | ||
public GetConnectorTransportAction( | ||
TransportService transportService, | ||
ActionFilters actionFilters, | ||
Client client, | ||
NamedXContentRegistry xContentRegistry, | ||
ConnectorAccessControlHelper connectorAccessControlHelper | ||
) { | ||
super(MLConnectorGetAction.NAME, transportService, actionFilters, MLConnectorGetRequest::new); | ||
this.client = client; | ||
this.xContentRegistry = xContentRegistry; | ||
this.connectorAccessControlHelper = connectorAccessControlHelper; | ||
} | ||
|
||
@Override | ||
protected void doExecute(Task task, ActionRequest request, ActionListener<MLConnectorGetResponse> actionListener) { | ||
MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.fromActionRequest(request); | ||
String connectorId = mlConnectorGetRequest.getConnectorId(); | ||
FetchSourceContext fetchSourceContext = getFetchSourceContext(mlConnectorGetRequest.isReturnContent()); | ||
GetRequest getRequest = new GetRequest(ML_CONNECTOR_INDEX).id(connectorId).fetchSourceContext(fetchSourceContext); | ||
User user = RestActionUtils.getUserContext(client); | ||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { | ||
client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { | ||
log.debug("Completed Get Connector Request, id:{}", connectorId); | ||
|
||
if (r != null && r.isExists()) { | ||
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { | ||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); | ||
Connector mlConnector = Connector.createConnector(parser); | ||
mlConnector.removeCredential(); | ||
if (connectorAccessControlHelper.hasPermission(user, mlConnector)) { | ||
actionListener.onResponse(MLConnectorGetResponse.builder().mlConnector(mlConnector).build()); | ||
} else { | ||
actionListener.onFailure(new MLValidationException("You don't have permission to access this connector")); | ||
} | ||
} catch (Exception e) { | ||
log.error("Failed to parse ml connector" + r.getId(), e); | ||
actionListener.onFailure(e); | ||
} | ||
} else { | ||
actionListener | ||
.onFailure(new IllegalArgumentException("Failed to find connector with the provided connector id: " + connectorId)); | ||
} | ||
}, e -> { | ||
if (e instanceof IndexNotFoundException) { | ||
actionListener.onFailure(new MLResourceNotFoundException("Fail to find connector")); | ||
} else { | ||
log.error("Failed to get ML connector " + connectorId, e); | ||
actionListener.onFailure(e); | ||
} | ||
}), context::restore)); | ||
} catch (Exception e) { | ||
log.error("Failed to get ML connector " + connectorId, e); | ||
actionListener.onFailure(e); | ||
} | ||
|
||
} | ||
} |
82 changes: 82 additions & 0 deletions
82
plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.action.connector; | ||
|
||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.stream.Collectors; | ||
|
||
import org.opensearch.action.ActionListener; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.commons.authuser.User; | ||
import org.opensearch.ml.common.CommonValue; | ||
import org.opensearch.ml.common.connector.HttpConnector; | ||
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; | ||
import org.opensearch.ml.helper.ConnectorAccessControlHelper; | ||
import org.opensearch.ml.utils.RestActionUtils; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
import org.opensearch.search.fetch.subphase.FetchSourceContext; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.transport.TransportService; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
public class SearchConnectorTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> { | ||
|
||
private final Client client; | ||
|
||
private final ConnectorAccessControlHelper connectorAccessControlHelper; | ||
|
||
@Inject | ||
public SearchConnectorTransportAction( | ||
TransportService transportService, | ||
ActionFilters actionFilters, | ||
Client client, | ||
ConnectorAccessControlHelper connectorAccessControlHelper | ||
) { | ||
super(MLConnectorSearchAction.NAME, transportService, actionFilters, SearchRequest::new); | ||
this.client = client; | ||
this.connectorAccessControlHelper = connectorAccessControlHelper; | ||
} | ||
|
||
@Override | ||
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) { | ||
request.indices(CommonValue.ML_CONNECTOR_INDEX); | ||
search(request, actionListener); | ||
} | ||
|
||
private void search(SearchRequest request, ActionListener<SearchResponse> actionListener) { | ||
User user = RestActionUtils.getUserContext(client); | ||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { | ||
FetchSourceContext fetchSourceContext = request.source().fetchSource(); | ||
List<String> excludes = Arrays.stream(fetchSourceContext.excludes()).collect(Collectors.toList()); | ||
excludes.add(HttpConnector.CREDENTIAL_FIELD); | ||
FetchSourceContext rebuiltFetchSourceContext = new FetchSourceContext( | ||
fetchSourceContext.fetchSource(), | ||
fetchSourceContext.includes(), | ||
excludes.toArray(new String[0]) | ||
); | ||
request.source().fetchSource(rebuiltFetchSourceContext); | ||
if (connectorAccessControlHelper.skipConnectorAccessControl(user)) { | ||
client.search(request, actionListener); | ||
} else { | ||
SearchSourceBuilder sourceBuilder = connectorAccessControlHelper.addUserBackendRolesFilter(user, request.source()); | ||
request.source(sourceBuilder); | ||
client.search(request, actionListener); | ||
} | ||
} catch (Exception e) { | ||
log.error(e.getMessage(), e); | ||
actionListener.onFailure(e); | ||
} | ||
} | ||
} |
Oops, something went wrong.