diff --git a/plugin/build.gradle b/plugin/build.gradle index a4ef3864a0..c9beb7c943 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -286,7 +286,13 @@ List jacocoExclusions = [ 'org.opensearch.ml.autoredeploy.MLModelAutoReDeployer.SearchRequestBuilderFactory', 'org.opensearch.ml.action.training.TrainingITTests', 'org.opensearch.ml.action.prediction.PredictionITTests', - 'org.opensearch.ml.cluster.MLSyncUpCron' + 'org.opensearch.ml.cluster.MLSyncUpCron', + 'org.opensearch.ml.action.connector.GetConnectorTransportAction', + 'org.opensearch.ml.breaker.MemoryCircuitBreaker', + 'org.opensearch.ml.action.connector.DeleteConnectorTransportAction', + 'org.opensearch.ml.action.connector.DeleteConnectorTransportAction.1', + 'org.opensearch.ml.action.connector.TransportCreateConnectorAction', + 'org.opensearch.ml.action.connector.SearchConnectorTransportAction' ] jacocoTestCoverageVerification { diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java new file mode 100644 index 0000000000..cd49b335f4 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java @@ -0,0 +1,129 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.connector; + +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.exception.MLValidationException; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class DeleteConnectorTransportAction extends HandledTransportAction { + + Client client; + NamedXContentRegistry xContentRegistry; + + ConnectorAccessControlHelper connectorAccessControlHelper; + + @Inject + public DeleteConnectorTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry, + ConnectorAccessControlHelper connectorAccessControlHelper + ) { + super(MLConnectorDeleteAction.NAME, transportService, actionFilters, MLConnectorDeleteRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + this.connectorAccessControlHelper = connectorAccessControlHelper; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.fromActionRequest(request); + String connectorId = mlConnectorDeleteRequest.getConnectorId(); + DeleteRequest deleteRequest = new DeleteRequest(ML_CONNECTOR_INDEX, connectorId); + connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(x -> { + if (Boolean.TRUE.equals(x)) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId)); + searchRequest.source(sourceBuilder); + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + SearchHit[] searchHits = searchResponse.getHits().getHits(); + if (searchHits.length == 0) { + deleteConnector(deleteRequest, connectorId, actionListener); + } else { + log + .error( + searchHits.length + " models are still using this connector, please delete or update the models first!" + ); + actionListener + .onFailure( + new MLValidationException( + searchHits.length + + " models are still using this connector, please delete or update the models first!" + ) + ); + } + }, e -> { + log.error("Failed to delete ML connector: " + connectorId, e); + actionListener.onFailure(e); + })); + } catch (Exception e) { + log.error(e.getMessage(), e); + actionListener.onFailure(e); + } + } else { + actionListener.onFailure(new MLValidationException("You are not allowed to delete this connector")); + } + }, e -> { + log.error("Failed to delete ML connector: " + connectorId, e); + actionListener.onFailure(e); + })); + } + + private void deleteConnector(DeleteRequest deleteRequest, String connectorId, ActionListener actionListener) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.delete(deleteRequest, new ActionListener<>() { + @Override + public void onResponse(DeleteResponse deleteResponse) { + if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) { + log.info("Connector id:{} not found", connectorId); + actionListener.onResponse(deleteResponse); + return; + } + log.info("Completed Delete Connector Request, connector id:{} deleted", connectorId); + actionListener.onResponse(deleteResponse); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to delete ML connector: " + connectorId, e); + actionListener.onFailure(e); + } + }); + } catch (Exception e) { + log.error("Failed to delete ML connector: " + connectorId, e); + actionListener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java new file mode 100644 index 0000000000..59bab2393e --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.connector; + +import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.exception.MLValidationException; +import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; +import org.opensearch.ml.common.transport.connector.MLConnectorGetRequest; +import org.opensearch.ml.common.transport.connector.MLConnectorGetResponse; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class GetConnectorTransportAction extends HandledTransportAction { + + Client client; + NamedXContentRegistry xContentRegistry; + + ConnectorAccessControlHelper connectorAccessControlHelper; + + @Inject + public GetConnectorTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + NamedXContentRegistry xContentRegistry, + ConnectorAccessControlHelper connectorAccessControlHelper + ) { + super(MLConnectorGetAction.NAME, transportService, actionFilters, MLConnectorGetRequest::new); + this.client = client; + this.xContentRegistry = xContentRegistry; + this.connectorAccessControlHelper = connectorAccessControlHelper; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.fromActionRequest(request); + String connectorId = mlConnectorGetRequest.getConnectorId(); + FetchSourceContext fetchSourceContext = getFetchSourceContext(mlConnectorGetRequest.isReturnContent()); + GetRequest getRequest = new GetRequest(ML_CONNECTOR_INDEX).id(connectorId).fetchSourceContext(fetchSourceContext); + User user = RestActionUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + log.debug("Completed Get Connector Request, id:{}", connectorId); + + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector mlConnector = Connector.createConnector(parser); + mlConnector.removeCredential(); + if (connectorAccessControlHelper.hasPermission(user, mlConnector)) { + actionListener.onResponse(MLConnectorGetResponse.builder().mlConnector(mlConnector).build()); + } else { + actionListener.onFailure(new MLValidationException("You don't have permission to access this connector")); + } + } catch (Exception e) { + log.error("Failed to parse ml connector" + r.getId(), e); + actionListener.onFailure(e); + } + } else { + actionListener + .onFailure(new IllegalArgumentException("Failed to find connector with the provided connector id: " + connectorId)); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + actionListener.onFailure(new MLResourceNotFoundException("Fail to find connector")); + } else { + log.error("Failed to get ML connector " + connectorId, e); + actionListener.onFailure(e); + } + }), context::restore)); + } catch (Exception e) { + log.error("Failed to get ML connector " + connectorId, e); + actionListener.onFailure(e); + } + + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java new file mode 100644 index 0000000000..f55291e60f --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.connector; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class SearchConnectorTransportAction extends HandledTransportAction { + + private final Client client; + + private final ConnectorAccessControlHelper connectorAccessControlHelper; + + @Inject + public SearchConnectorTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ConnectorAccessControlHelper connectorAccessControlHelper + ) { + super(MLConnectorSearchAction.NAME, transportService, actionFilters, SearchRequest::new); + this.client = client; + this.connectorAccessControlHelper = connectorAccessControlHelper; + } + + @Override + protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { + request.indices(CommonValue.ML_CONNECTOR_INDEX); + search(request, actionListener); + } + + private void search(SearchRequest request, ActionListener actionListener) { + User user = RestActionUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + FetchSourceContext fetchSourceContext = request.source().fetchSource(); + List excludes = Arrays.stream(fetchSourceContext.excludes()).collect(Collectors.toList()); + excludes.add(HttpConnector.CREDENTIAL_FIELD); + FetchSourceContext rebuiltFetchSourceContext = new FetchSourceContext( + fetchSourceContext.fetchSource(), + fetchSourceContext.includes(), + excludes.toArray(new String[0]) + ); + request.source().fetchSource(rebuiltFetchSourceContext); + if (connectorAccessControlHelper.skipConnectorAccessControl(user)) { + client.search(request, actionListener); + } else { + SearchSourceBuilder sourceBuilder = connectorAccessControlHelper.addUserBackendRolesFilter(user, request.source()); + request.source(sourceBuilder); + client.search(request, actionListener); + } + } catch (Exception e) { + log.error(e.getMessage(), e); + actionListener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java new file mode 100644 index 0000000000..aaa07ebedd --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -0,0 +1,203 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.connector; + +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; + +import java.util.HashSet; +import java.util.List; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.CollectionUtils; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; +import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.engine.exceptions.MetaDataException; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class TransportCreateConnectorAction extends HandledTransportAction { + private final MLIndicesHandler mlIndicesHandler; + private final Client client; + private final MLEngine mlEngine; + private final MLModelManager mlModelManager; + private final ConnectorAccessControlHelper connectorAccessControlHelper; + + private volatile List trustedConnectorEndpointsRegex; + + @Inject + public TransportCreateConnectorAction( + TransportService transportService, + ActionFilters actionFilters, + MLIndicesHandler mlIndicesHandler, + Client client, + MLEngine mlEngine, + ConnectorAccessControlHelper connectorAccessControlHelper, + Settings settings, + ClusterService clusterService, + MLModelManager mlModelManager + ) { + super(MLCreateConnectorAction.NAME, transportService, actionFilters, MLCreateConnectorRequest::new); + this.mlIndicesHandler = mlIndicesHandler; + this.client = client; + this.mlEngine = mlEngine; + this.connectorAccessControlHelper = connectorAccessControlHelper; + this.mlModelManager = mlModelManager; + trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, it -> trustedConnectorEndpointsRegex = it); + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.fromActionRequest(request); + MLCreateConnectorInput mlCreateConnectorInput = mlCreateConnectorRequest.getMlCreateConnectorInput(); + if (MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME.equals(mlCreateConnectorInput.getName())) { + MLCreateConnectorResponse response = new MLCreateConnectorResponse(MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME); + listener.onResponse(response); + return; + } + String connectorName = mlCreateConnectorInput.getName(); + try { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + mlCreateConnectorInput.toXContent(builder, ToXContent.EMPTY_PARAMS); + Connector connector = Connector.createConnector(builder, mlCreateConnectorInput.getProtocol()); + connector.validateConnectorURL(trustedConnectorEndpointsRegex); + + User user = RestActionUtils.getUserContext(client); + if (connectorAccessControlHelper.accessControlNotEnabled(user)) { + validateSecurityDisabledOrConnectorAccessControlDisabled(mlCreateConnectorInput); + indexConnector(connector, listener); + } else { + validateRequest4AccessControl(mlCreateConnectorInput, user); + if (Boolean.TRUE.equals(mlCreateConnectorInput.getAddAllBackendRoles())) { + mlCreateConnectorInput.setBackendRoles(user.getBackendRoles()); + } + connector.setBackendRoles(mlCreateConnectorInput.getBackendRoles()); + connector.setOwner(user); + connector.setAccess(mlCreateConnectorInput.getAccess()); + indexConnector(connector, listener); + } + } catch (MetaDataException e) { + log.error("The masterKey for credential encryption is missing in connector creation"); + listener.onFailure(e); + } catch (Exception e) { + log.error("Failed to create connector " + connectorName, e); + listener.onFailure(e); + } + } + + private void indexConnector(Connector connector, ActionListener listener) { + connector.encrypt(mlEngine::encrypt); + log.info("connector created, indexing into the connector system index"); + mlIndicesHandler.initMLConnectorIndex(ActionListener.wrap(indexCreated -> { + if (!indexCreated) { + listener.onFailure(new RuntimeException("No response to create ML Connector index")); + return; + } + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener indexResponseListener = ActionListener.wrap(r -> { + log.info("Connector saved into index, result:{}, connector id: {}", r.getResult(), r.getId()); + MLCreateConnectorResponse response = new MLCreateConnectorResponse(r.getId()); + listener.onResponse(response); + }, listener::onFailure); + + IndexRequest indexRequest = new IndexRequest(ML_CONNECTOR_INDEX); + indexRequest.source(connector.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(indexRequest, ActionListener.runBefore(indexResponseListener, context::restore)); + } catch (Exception e) { + log.error("Failed to save ML connector", e); + listener.onFailure(e); + } + }, e -> { + log.error("Failed to init ML connector index", e); + listener.onFailure(e); + })); + } + + private void validateRequest4AccessControl(MLCreateConnectorInput input, User user) { + Boolean isAddAllBackendRoles = input.getAddAllBackendRoles(); + if (connectorAccessControlHelper.isAdmin(user)) { + if (Boolean.TRUE.equals(isAddAllBackendRoles)) { + throw new IllegalArgumentException("Admin can't add all backend roles"); + } + } + AccessMode accessMode = input.getAccess(); + if (accessMode == null) { + if (!CollectionUtils.isEmpty(input.getBackendRoles()) || Boolean.TRUE.equals(isAddAllBackendRoles)) { + input.setAccess(AccessMode.RESTRICTED); + accessMode = AccessMode.RESTRICTED; + } else { + input.setAccess(AccessMode.PRIVATE); + accessMode = AccessMode.PRIVATE; + } + } + if (AccessMode.PUBLIC == accessMode || AccessMode.PRIVATE == accessMode) { + if (!CollectionUtils.isEmpty(input.getBackendRoles()) || Boolean.TRUE.equals(isAddAllBackendRoles)) { + throw new IllegalArgumentException("You can specify backend roles only for a connector with the restricted access mode."); + } + } + if (AccessMode.RESTRICTED == accessMode) { + if (Boolean.TRUE.equals(isAddAllBackendRoles)) { + if (!CollectionUtils.isEmpty(input.getBackendRoles())) { + throw new IllegalArgumentException("You can't specify backend roles and add all backend roles to true at same time."); + } + if (CollectionUtils.isEmpty(user.getBackendRoles())) { + throw new IllegalArgumentException("You must have at least one backend role to create a connector."); + } + } else { + // check backend_roles parameter + if (CollectionUtils.isEmpty(input.getBackendRoles())) { + throw new IllegalArgumentException( + "You must specify at least one backend role or make the connector public/private for registering it." + ); + } else if (!connectorAccessControlHelper.isAdmin(user) + && !new HashSet<>(user.getBackendRoles()).containsAll(input.getBackendRoles())) { + throw new IllegalArgumentException("You don't have the backend roles specified."); + } + } + } + } + + private void validateSecurityDisabledOrConnectorAccessControlDisabled(MLCreateConnectorInput input) { + if (input.getAccess() != null || input.getAddAllBackendRoles() != null || input.getBackendRoles() != null) { + throw new IllegalArgumentException( + "You cannot specify connector access control parameters because the Security plugin or connector access control is disabled on your cluster." + ); + } + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/breaker/MLCircuitBreakerService.java b/plugin/src/main/java/org/opensearch/ml/breaker/MLCircuitBreakerService.java index 07d1902e3d..f56c7915e0 100644 --- a/plugin/src/main/java/org/opensearch/ml/breaker/MLCircuitBreakerService.java +++ b/plugin/src/main/java/org/opensearch/ml/breaker/MLCircuitBreakerService.java @@ -78,8 +78,8 @@ public MLCircuitBreakerService init(Path path) { log.info("Registered ML memory breaker."); registerBreaker(BreakerName.DISK, new DiskCircuitBreaker(path.toString())); log.info("Registered ML disk breaker."); - // Register native memory circuit breaker - registerBreaker(BreakerName.NATIVE_MEMORY, new NativeMemoryCircuitBreaker(this.osService, this.settings, this.clusterService)); + // Register native memory circuit breaker, disabling due to unstability. + // registerBreaker(BreakerName.NATIVE_MEMORY, new NativeMemoryCircuitBreaker(this.osService, this.settings, this.clusterService)); log.info("Registered ML native memory breaker."); return this; diff --git a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java index 165736baf2..8b467c2452 100644 --- a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java +++ b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java @@ -5,6 +5,10 @@ package org.opensearch.ml.breaker; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.monitor.jvm.JvmService; /** @@ -15,6 +19,7 @@ public class MemoryCircuitBreaker extends ThresholdCircuitBreaker { private static final String ML_MEMORY_CB = "Memory Circuit Breaker"; public static final short DEFAULT_JVM_HEAP_USAGE_THRESHOLD = 85; private final JvmService jvmService; + private volatile Integer jvmHeapMemThreshold = 85; public MemoryCircuitBreaker(JvmService jvmService) { super(DEFAULT_JVM_HEAP_USAGE_THRESHOLD); @@ -26,6 +31,13 @@ public MemoryCircuitBreaker(short threshold, JvmService jvmService) { this.jvmService = jvmService; } + public MemoryCircuitBreaker(Settings settings, ClusterService clusterService, JvmService jvmService) { + super(DEFAULT_JVM_HEAP_USAGE_THRESHOLD); + this.jvmService = jvmService; + this.jvmHeapMemThreshold = ML_COMMONS_JVM_HEAP_MEM_THRESHOLD.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD, it -> jvmHeapMemThreshold = it); + } + @Override public String getName() { return ML_MEMORY_CB; diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java b/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java index 3438c11669..668306b763 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java +++ b/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java @@ -5,6 +5,9 @@ package org.opensearch.ml.indices; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_SCHEMA_VERSION; @@ -18,7 +21,8 @@ public enum MLIndex { MODEL_GROUP(ML_MODEL_GROUP_INDEX, false, ML_MODEL_GROUP_INDEX_MAPPING, ML_MODEL_GROUP_INDEX_SCHEMA_VERSION), MODEL(ML_MODEL_INDEX, false, ML_MODEL_INDEX_MAPPING, ML_MODEL_INDEX_SCHEMA_VERSION), - TASK(ML_TASK_INDEX, false, ML_TASK_INDEX_MAPPING, ML_TASK_INDEX_SCHEMA_VERSION); + TASK(ML_TASK_INDEX, false, ML_TASK_INDEX_MAPPING, ML_TASK_INDEX_SCHEMA_VERSION), + CONNECTOR(ML_CONNECTOR_INDEX, false, ML_CONNECTOR_INDEX_MAPPING, ML_CONNECTOR_SCHEMA_VERSION); private final String indexName; // whether we use an alias for the index diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java b/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java index f0746fdafb..3235d27f29 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java @@ -57,6 +57,10 @@ public void initMLTaskIndex(ActionListener listener) { initMLIndexIfAbsent(MLIndex.TASK, listener); } + public void initMLConnectorIndex(ActionListener listener) { + initMLIndexIfAbsent(MLIndex.CONNECTOR, listener); + } + public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) { String indexName = index.getIndexName(); String mapping = index.getMapping(); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index a9c1ed44d1..e2c7b4692e 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -563,6 +563,7 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX, MLCommonsSettings.ML_COMMONS_NATIVE_MEM_THRESHOLD, + MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD, MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES, MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN, MLCommonsSettings.ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL, @@ -571,7 +572,9 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL, MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD, MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED, - MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY + MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY, + MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED, + MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index ea122aea77..856b819380 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -72,6 +72,9 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_NATIVE_MEM_THRESHOLD = Setting .intSetting("plugins.ml_commons.native_memory_threshold", 90, 0, 100, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_JVM_HEAP_MEM_THRESHOLD = Setting + .intSetting("plugins.ml_commons.jvm_heap_memory_threshold", 85, 0, 100, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_EXCLUDE_NODE_NAMES = Setting .simpleString("plugins.ml_commons.exclude_nodes._name", Setting.Property.NodeScope, Setting.Property.Dynamic); public static final Setting ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN = Setting