diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java new file mode 100644 index 0000000000..cd49b335f4 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java @@ -0,0 +1,129 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.connector; + +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.exception.MLValidationException; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class DeleteConnectorTransportAction extends HandledTransportAction { + + Client client; + NamedXContentRegistry xContentRegistry; + + ConnectorAccessControlHelper connectorAccessControlHelper; + + @Inject + public DeleteConnectorTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry, + ConnectorAccessControlHelper connectorAccessControlHelper + ) { + super(MLConnectorDeleteAction.NAME, transportService, actionFilters, MLConnectorDeleteRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + this.connectorAccessControlHelper = connectorAccessControlHelper; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.fromActionRequest(request); + String connectorId = mlConnectorDeleteRequest.getConnectorId(); + DeleteRequest deleteRequest = new DeleteRequest(ML_CONNECTOR_INDEX, connectorId); + connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(x -> { + if (Boolean.TRUE.equals(x)) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId)); + searchRequest.source(sourceBuilder); + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + SearchHit[] searchHits = searchResponse.getHits().getHits(); + if (searchHits.length == 0) { + deleteConnector(deleteRequest, connectorId, actionListener); + } else { + log + .error( + searchHits.length + " models are still using this connector, please delete or update the models first!" + ); + actionListener + .onFailure( + new MLValidationException( + searchHits.length + + " models are still using this connector, please delete or update the models first!" + ) + ); + } + }, e -> { + log.error("Failed to delete ML connector: " + connectorId, e); + actionListener.onFailure(e); + })); + } catch (Exception e) { + log.error(e.getMessage(), e); + actionListener.onFailure(e); + } + } else { + actionListener.onFailure(new MLValidationException("You are not allowed to delete this connector")); + } + }, e -> { + log.error("Failed to delete ML connector: " + connectorId, e); + actionListener.onFailure(e); + })); + } + + private void deleteConnector(DeleteRequest deleteRequest, String connectorId, ActionListener actionListener) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.delete(deleteRequest, new ActionListener<>() { + @Override + public void onResponse(DeleteResponse deleteResponse) { + if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) { + log.info("Connector id:{} not found", connectorId); + actionListener.onResponse(deleteResponse); + return; + } + log.info("Completed Delete Connector Request, connector id:{} deleted", connectorId); + actionListener.onResponse(deleteResponse); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to delete ML connector: " + connectorId, e); + actionListener.onFailure(e); + } + }); + } catch (Exception e) { + log.error("Failed to delete ML connector: " + connectorId, e); + actionListener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java new file mode 100644 index 0000000000..59bab2393e --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.connector; + +import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.exception.MLValidationException; +import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; +import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; +import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class GetConnectorTransportAction extends HandledTransportAction { + + Client client; + NamedXContentRegistry xContentRegistry; + + ConnectorAccessControlHelper connectorAccessControlHelper; + + @Inject + public GetConnectorTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry, + ConnectorAccessControlHelper connectorAccessControlHelper + ) { + super(MLConnectorGetAction.NAME, transportService, actionFilters, MLConnectorGetRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + this.connectorAccessControlHelper = connectorAccessControlHelper; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.fromActionRequest(request); + String connectorId = mlConnectorGetRequest.getConnectorId(); + FetchSourceContext fetchSourceContext = getFetchSourceContext(mlConnectorGetRequest.isReturnContent()); + GetRequest getRequest = new GetRequest(ML_CONNECTOR_INDEX).id(connectorId).fetchSourceContext(fetchSourceContext); + User user = RestActionUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + log.debug("Completed Get Connector Request, id:{}", connectorId); + + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector mlConnector = Connector.createConnector(parser); + mlConnector.removeCredential(); + if (connectorAccessControlHelper.hasPermission(user, mlConnector)) { + actionListener.onResponse(MLConnectorGetResponse.builder().mlConnector(mlConnector).build()); + } else { + actionListener.onFailure(new MLValidationException("You don't have permission to access this connector")); + } + } catch (Exception e) { + log.error("Failed to parse ml connector" + r.getId(), e); + actionListener.onFailure(e); + } + } else { + actionListener + .onFailure(new IllegalArgumentException("Failed to find connector with the provided connector id: " + connectorId)); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + actionListener.onFailure(new MLResourceNotFoundException("Fail to find connector")); + } else { + log.error("Failed to get ML connector " + connectorId, e); + actionListener.onFailure(e); + } + }), context::restore)); + } catch (Exception e) { + log.error("Failed to get ML connector " + connectorId, e); + actionListener.onFailure(e); + } + + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java new file mode 100644 index 0000000000..f55291e60f --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.connector; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class SearchConnectorTransportAction extends HandledTransportAction { + + private final Client client; + + private final ConnectorAccessControlHelper connectorAccessControlHelper; + + @Inject + public SearchConnectorTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ConnectorAccessControlHelper connectorAccessControlHelper + ) { + super(MLConnectorSearchAction.NAME, transportService, actionFilters, SearchRequest::new); + this.client = client; + this.connectorAccessControlHelper = connectorAccessControlHelper; + } + + @Override + protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { + request.indices(CommonValue.ML_CONNECTOR_INDEX); + search(request, actionListener); + } + + private void search(SearchRequest request, ActionListener actionListener) { + User user = RestActionUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + FetchSourceContext fetchSourceContext = request.source().fetchSource(); + List excludes = Arrays.stream(fetchSourceContext.excludes()).collect(Collectors.toList()); + excludes.add(HttpConnector.CREDENTIAL_FIELD); + FetchSourceContext rebuiltFetchSourceContext = new FetchSourceContext( + fetchSourceContext.fetchSource(), + fetchSourceContext.includes(), + excludes.toArray(new String[0]) + ); + request.source().fetchSource(rebuiltFetchSourceContext); + if (connectorAccessControlHelper.skipConnectorAccessControl(user)) { + client.search(request, actionListener); + } else { + SearchSourceBuilder sourceBuilder = connectorAccessControlHelper.addUserBackendRolesFilter(user, request.source()); + request.source(sourceBuilder); + client.search(request, actionListener); + } + } catch (Exception e) { + log.error(e.getMessage(), e); + actionListener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java new file mode 100644 index 0000000000..aaa07ebedd --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -0,0 +1,203 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.connector; + +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; + +import java.util.HashSet; +import java.util.List; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.CollectionUtils; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; +import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.engine.exceptions.MetaDataException; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class TransportCreateConnectorAction extends HandledTransportAction { + private final MLIndicesHandler mlIndicesHandler; + private final Client client; + private final MLEngine mlEngine; + private final MLModelManager mlModelManager; + private final ConnectorAccessControlHelper connectorAccessControlHelper; + + private volatile List trustedConnectorEndpointsRegex; + + @Inject + public TransportCreateConnectorAction( + TransportService transportService, + ActionFilters actionFilters, + MLIndicesHandler mlIndicesHandler, + Client client, + MLEngine mlEngine, + ConnectorAccessControlHelper connectorAccessControlHelper, + Settings settings, + ClusterService clusterService, + MLModelManager mlModelManager + ) { + super(MLCreateConnectorAction.NAME, transportService, actionFilters, MLCreateConnectorRequest::new); + this.mlIndicesHandler = mlIndicesHandler; + this.client = client; + this.mlEngine = mlEngine; + this.connectorAccessControlHelper = connectorAccessControlHelper; + this.mlModelManager = mlModelManager; + trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, it -> trustedConnectorEndpointsRegex = it); + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.fromActionRequest(request); + MLCreateConnectorInput mlCreateConnectorInput = mlCreateConnectorRequest.getMlCreateConnectorInput(); + if (MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME.equals(mlCreateConnectorInput.getName())) { + MLCreateConnectorResponse response = new MLCreateConnectorResponse(MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME); + listener.onResponse(response); + return; + } + String connectorName = mlCreateConnectorInput.getName(); + try { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + mlCreateConnectorInput.toXContent(builder, ToXContent.EMPTY_PARAMS); + Connector connector = Connector.createConnector(builder, mlCreateConnectorInput.getProtocol()); + connector.validateConnectorURL(trustedConnectorEndpointsRegex); + + User user = RestActionUtils.getUserContext(client); + if (connectorAccessControlHelper.accessControlNotEnabled(user)) { + validateSecurityDisabledOrConnectorAccessControlDisabled(mlCreateConnectorInput); + indexConnector(connector, listener); + } else { + validateRequest4AccessControl(mlCreateConnectorInput, user); + if (Boolean.TRUE.equals(mlCreateConnectorInput.getAddAllBackendRoles())) { + mlCreateConnectorInput.setBackendRoles(user.getBackendRoles()); + } + connector.setBackendRoles(mlCreateConnectorInput.getBackendRoles()); + connector.setOwner(user); + connector.setAccess(mlCreateConnectorInput.getAccess()); + indexConnector(connector, listener); + } + } catch (MetaDataException e) { + log.error("The masterKey for credential encryption is missing in connector creation"); + listener.onFailure(e); + } catch (Exception e) { + log.error("Failed to create connector " + connectorName, e); + listener.onFailure(e); + } + } + + private void indexConnector(Connector connector, ActionListener listener) { + connector.encrypt(mlEngine::encrypt); + log.info("connector created, indexing into the connector system index"); + mlIndicesHandler.initMLConnectorIndex(ActionListener.wrap(indexCreated -> { + if (!indexCreated) { + listener.onFailure(new RuntimeException("No response to create ML Connector index")); + return; + } + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener indexResponseListener = ActionListener.wrap(r -> { + log.info("Connector saved into index, result:{}, connector id: {}", r.getResult(), r.getId()); + MLCreateConnectorResponse response = new MLCreateConnectorResponse(r.getId()); + listener.onResponse(response); + }, listener::onFailure); + + IndexRequest indexRequest = new IndexRequest(ML_CONNECTOR_INDEX); + indexRequest.source(connector.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(indexRequest, ActionListener.runBefore(indexResponseListener, context::restore)); + } catch (Exception e) { + log.error("Failed to save ML connector", e); + listener.onFailure(e); + } + }, e -> { + log.error("Failed to init ML connector index", e); + listener.onFailure(e); + })); + } + + private void validateRequest4AccessControl(MLCreateConnectorInput input, User user) { + Boolean isAddAllBackendRoles = input.getAddAllBackendRoles(); + if (connectorAccessControlHelper.isAdmin(user)) { + if (Boolean.TRUE.equals(isAddAllBackendRoles)) { + throw new IllegalArgumentException("Admin can't add all backend roles"); + } + } + AccessMode accessMode = input.getAccess(); + if (accessMode == null) { + if (!CollectionUtils.isEmpty(input.getBackendRoles()) || Boolean.TRUE.equals(isAddAllBackendRoles)) { + input.setAccess(AccessMode.RESTRICTED); + accessMode = AccessMode.RESTRICTED; + } else { + input.setAccess(AccessMode.PRIVATE); + accessMode = AccessMode.PRIVATE; + } + } + if (AccessMode.PUBLIC == accessMode || AccessMode.PRIVATE == accessMode) { + if (!CollectionUtils.isEmpty(input.getBackendRoles()) || Boolean.TRUE.equals(isAddAllBackendRoles)) { + throw new IllegalArgumentException("You can specify backend roles only for a connector with the restricted access mode."); + } + } + if (AccessMode.RESTRICTED == accessMode) { + if (Boolean.TRUE.equals(isAddAllBackendRoles)) { + if (!CollectionUtils.isEmpty(input.getBackendRoles())) { + throw new IllegalArgumentException("You can't specify backend roles and add all backend roles to true at same time."); + } + if (CollectionUtils.isEmpty(user.getBackendRoles())) { + throw new IllegalArgumentException("You must have at least one backend role to create a connector."); + } + } else { + // check backend_roles parameter + if (CollectionUtils.isEmpty(input.getBackendRoles())) { + throw new IllegalArgumentException( + "You must specify at least one backend role or make the connector public/private for registering it." + ); + } else if (!connectorAccessControlHelper.isAdmin(user) + && !new HashSet<>(user.getBackendRoles()).containsAll(input.getBackendRoles())) { + throw new IllegalArgumentException("You don't have the backend roles specified."); + } + } + } + } + + private void validateSecurityDisabledOrConnectorAccessControlDisabled(MLCreateConnectorInput input) { + if (input.getAccess() != null || input.getAddAllBackendRoles() != null || input.getBackendRoles() != null) { + throw new IllegalArgumentException( + "You cannot specify connector access control parameters because the Security plugin or connector access control is disabled on your cluster." + ); + } + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/breaker/MLCircuitBreakerService.java b/plugin/src/main/java/org/opensearch/ml/breaker/MLCircuitBreakerService.java index 07d1902e3d..f56c7915e0 100644 --- a/plugin/src/main/java/org/opensearch/ml/breaker/MLCircuitBreakerService.java +++ b/plugin/src/main/java/org/opensearch/ml/breaker/MLCircuitBreakerService.java @@ -78,8 +78,8 @@ public MLCircuitBreakerService init(Path path) { log.info("Registered ML memory breaker."); registerBreaker(BreakerName.DISK, new DiskCircuitBreaker(path.toString())); log.info("Registered ML disk breaker."); - // Register native memory circuit breaker - registerBreaker(BreakerName.NATIVE_MEMORY, new NativeMemoryCircuitBreaker(this.osService, this.settings, this.clusterService)); + // Register native memory circuit breaker, disabling due to unstability. + // registerBreaker(BreakerName.NATIVE_MEMORY, new NativeMemoryCircuitBreaker(this.osService, this.settings, this.clusterService)); log.info("Registered ML native memory breaker."); return this; diff --git a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java index 165736baf2..8b467c2452 100644 --- a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java +++ b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java @@ -5,6 +5,10 @@ package org.opensearch.ml.breaker; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.monitor.jvm.JvmService; /** @@ -15,6 +19,7 @@ public class MemoryCircuitBreaker extends ThresholdCircuitBreaker { private static final String ML_MEMORY_CB = "Memory Circuit Breaker"; public static final short DEFAULT_JVM_HEAP_USAGE_THRESHOLD = 85; private final JvmService jvmService; + private volatile Integer jvmHeapMemThreshold = 85; public MemoryCircuitBreaker(JvmService jvmService) { super(DEFAULT_JVM_HEAP_USAGE_THRESHOLD); @@ -26,6 +31,13 @@ public MemoryCircuitBreaker(short threshold, JvmService jvmService) { this.jvmService = jvmService; } + public MemoryCircuitBreaker(Settings settings, ClusterService clusterService, JvmService jvmService) { + super(DEFAULT_JVM_HEAP_USAGE_THRESHOLD); + this.jvmService = jvmService; + this.jvmHeapMemThreshold = ML_COMMONS_JVM_HEAP_MEM_THRESHOLD.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD, it -> jvmHeapMemThreshold = it); + } + @Override public String getName() { return ML_MEMORY_CB; diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java new file mode 100644 index 0000000000..fc5f9e95ee --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java @@ -0,0 +1,176 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.helper; + +import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; + +import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.CollectionUtils; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.connector.AbstractConnector; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.utils.MLNodeUtils; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.search.builder.SearchSourceBuilder; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class ConnectorAccessControlHelper { + + private volatile Boolean connectorAccessControlEnabled; + + public ConnectorAccessControlHelper(ClusterService clusterService, Settings settings) { + connectorAccessControlEnabled = ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED, it -> connectorAccessControlEnabled = it); + } + + public boolean hasPermission(User user, Connector connector) { + return accessControlNotEnabled(user) + || isAdmin(user) + || previouslyPublicConnector(connector) + || isPublicConnector(connector) + || (isPrivateConnector(connector) && isOwner(user, connector.getOwner())) + || (isRestrictedConnector(connector) && isUserHasBackendRole(user, connector)); + } + + public void validateConnectorAccess(Client client, String connectorId, ActionListener listener) { + User user = RestActionUtils.getUserContext(client); + if (isAdmin(user) || accessControlNotEnabled(user)) { + listener.onResponse(true); + return; + } + GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try ( + XContentParser parser = MLNodeUtils + .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector connector = Connector.createConnector(parser); + boolean hasPermission = hasPermission(user, connector); + wrappedListener.onResponse(hasPermission); + } catch (Exception e) { + log.error("Failed to parse connector:" + connectorId); + wrappedListener.onFailure(e); + } + } else { + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find connector:" + connectorId)); + } + }, e -> { + log.error("Fail to get connector", e); + wrappedListener.onFailure(new IllegalStateException("Fail to get connector:" + connectorId)); + })); + } catch (Exception e) { + log.error("Failed to validate Access for connector:" + connectorId, e); + listener.onFailure(e); + } + + } + + public boolean skipConnectorAccessControl(User user) { + // Case 1: user == null when 1. Security is disabled. 2. When user is super-admin + // Case 2: If Security is enabled and filter is disabled, proceed with search as + // user is already authenticated to hit this API. + // case 3: user is admin which means we don't have to check backend role filtering + return user == null || !connectorAccessControlEnabled || isAdmin(user); + } + + public boolean accessControlNotEnabled(User user) { + return user == null || !connectorAccessControlEnabled; + } + + public boolean isAdmin(User user) { + if (user == null) { + return false; + } + if (CollectionUtils.isEmpty(user.getRoles())) { + return false; + } + return user.getRoles().contains("all_access"); + } + + private boolean isOwner(User owner, User user) { + if (user == null || owner == null) { + return false; + } + return owner.getName().equals(user.getName()); + } + + private boolean isUserHasBackendRole(User user, Connector connector) { + AccessMode modelAccessMode = connector.getAccess(); + return AccessMode.RESTRICTED == modelAccessMode + && (user.getBackendRoles() != null + && connector.getBackendRoles() != null + && connector.getBackendRoles().stream().anyMatch(x -> user.getBackendRoles().contains(x))); + } + + private boolean previouslyPublicConnector(Connector connector) { + return connector.getOwner() == null; + } + + private boolean isPublicConnector(Connector connector) { + return AccessMode.PUBLIC == connector.getAccess(); + } + + private boolean isPrivateConnector(Connector connector) { + return AccessMode.PRIVATE == connector.getAccess(); + } + + private boolean isRestrictedConnector(Connector connector) { + return AccessMode.RESTRICTED == connector.getAccess(); + } + + public SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSourceBuilder searchSourceBuilder) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.should(QueryBuilders.termQuery(AbstractConnector.ACCESS_FIELD, AccessMode.PUBLIC.getValue())); + boolQueryBuilder.should(QueryBuilders.termsQuery(AbstractConnector.BACKEND_ROLES_FIELD + ".keyword", user.getBackendRoles())); + + BoolQueryBuilder privateBoolQuery = new BoolQueryBuilder(); + String ownerName = "owner.name.keyword"; + TermQueryBuilder ownerNameTermQuery = QueryBuilders.termQuery(ownerName, user.getName()); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(AbstractConnector.OWNER_FIELD, ownerNameTermQuery, ScoreMode.None); + privateBoolQuery.must(nestedQueryBuilder); + privateBoolQuery.must(QueryBuilders.termQuery(AbstractConnector.ACCESS_FIELD, AccessMode.PRIVATE.getValue())); + boolQueryBuilder.should(privateBoolQuery); + QueryBuilder query = searchSourceBuilder.query(); + if (query == null) { + searchSourceBuilder.query(boolQueryBuilder); + } else if (query instanceof BoolQueryBuilder) { + ((BoolQueryBuilder) query).filter(boolQueryBuilder); + } else { + BoolQueryBuilder rewriteQuery = new BoolQueryBuilder(); + rewriteQuery.must(query); + rewriteQuery.filter(boolQueryBuilder); + searchSourceBuilder.query(rewriteQuery); + } + return searchSourceBuilder; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java b/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java index 3438c11669..668306b763 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java +++ b/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java @@ -5,6 +5,9 @@ package org.opensearch.ml.indices; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_SCHEMA_VERSION; @@ -18,7 +21,8 @@ public enum MLIndex { MODEL_GROUP(ML_MODEL_GROUP_INDEX, false, ML_MODEL_GROUP_INDEX_MAPPING, ML_MODEL_GROUP_INDEX_SCHEMA_VERSION), MODEL(ML_MODEL_INDEX, false, ML_MODEL_INDEX_MAPPING, ML_MODEL_INDEX_SCHEMA_VERSION), - TASK(ML_TASK_INDEX, false, ML_TASK_INDEX_MAPPING, ML_TASK_INDEX_SCHEMA_VERSION); + TASK(ML_TASK_INDEX, false, ML_TASK_INDEX_MAPPING, ML_TASK_INDEX_SCHEMA_VERSION), + CONNECTOR(ML_CONNECTOR_INDEX, false, ML_CONNECTOR_INDEX_MAPPING, ML_CONNECTOR_SCHEMA_VERSION); private final String indexName; // whether we use an alias for the index diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java b/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java index f0746fdafb..3235d27f29 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java @@ -57,6 +57,10 @@ public void initMLTaskIndex(ActionListener listener) { initMLIndexIfAbsent(MLIndex.TASK, listener); } + public void initMLConnectorIndex(ActionListener listener) { + initMLIndexIfAbsent(MLIndex.CONNECTOR, listener); + } + public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) { String indexName = index.getIndexName(); String mapping = index.getMapping(); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index a9c1ed44d1..e2c7b4692e 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -563,6 +563,7 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX, MLCommonsSettings.ML_COMMONS_NATIVE_MEM_THRESHOLD, + MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD, MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES, MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN, MLCommonsSettings.ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL, @@ -571,7 +572,9 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL, MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD, MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED, - MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY + MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY, + MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED, + MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 6cb8234cc9..856b819380 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -5,8 +5,13 @@ package org.opensearch.ml.settings; +import java.util.List; +import java.util.function.Function; + import org.opensearch.common.settings.Setting; +import com.google.common.collect.ImmutableList; + public final class MLCommonsSettings { private MLCommonsSettings() {} @@ -67,6 +72,9 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_NATIVE_MEM_THRESHOLD = Setting .intSetting("plugins.ml_commons.native_memory_threshold", 90, 0, 100, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_JVM_HEAP_MEM_THRESHOLD = Setting + .intSetting("plugins.ml_commons.jvm_heap_memory_threshold", 85, 0, 100, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_EXCLUDE_NODE_NAMES = Setting .simpleString("plugins.ml_commons.exclude_nodes._name", Setting.Property.NodeScope, Setting.Property.Dynamic); public static final Setting ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN = Setting @@ -105,4 +113,21 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_MASTER_SECRET_KEY = Setting .simpleString("plugins.ml_commons.encryption.master_key", "0000000000000000", Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED = Setting + .boolSetting("plugins.ml_commons.connector_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting> ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX = Setting + .listSetting( + "plugins.ml_commons.trusted_connector_endpoints_regex", + ImmutableList + .of( + "^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", + "^https://api\\.openai\\.com/.*$", + "^https://api\\.cohere\\.ai/.*$" + ), + Function.identity(), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); }