From 11b72bc0935372a865aa92d92219530c5cc5b3a7 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Mon, 10 Jul 2023 09:54:05 +0800 Subject: [PATCH] Add basic connector access control classes (#1055) * Add basic connector access control classes Signed-off-by: zane-neo * Add basic connector access control classes and fformat code Signed-off-by: zane-neo --------- Signed-off-by: zane-neo --- .../ml/common/connector/AwsConnector.java | 5 +- .../ml/common/connector/HttpConnector.java | 3 +- .../helper/ConnectorAccessControlHelper.java | 176 ++++++++++ .../ml/settings/MLCommonsSettings.java | 22 ++ .../ConnectorAccessControlHelperTests.java | 302 ++++++++++++++++++ 5 files changed, 505 insertions(+), 3 deletions(-) create mode 100644 plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java create mode 100644 plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java index 7d2c4478f0..d9f2205a60 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java @@ -11,6 +11,7 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.commons.authuser.User; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; @@ -30,8 +31,8 @@ public class AwsConnector extends HttpConnector { @Builder(builderMethodName = "awsConnectorBuilder") public AwsConnector(String name, String description, String version, String protocol, Map parameters, Map credential, List actions, - List backendRoles, AccessMode accessMode) { - super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode); + List backendRoles, AccessMode accessMode, User owner) { + super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode, owner); validate(); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 212858f074..e00c682cfe 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -49,7 +49,7 @@ public class HttpConnector extends AbstractConnector { @Builder public HttpConnector(String name, String description, String version, String protocol, Map parameters, Map credential, List actions, - List backendRoles, AccessMode accessMode) { + List backendRoles, AccessMode accessMode, User owner) { validateProtocol(protocol); this.name = name; this.description = description; @@ -60,6 +60,7 @@ public HttpConnector(String name, String description, String version, String pro this.actions = actions; this.backendRoles = backendRoles; this.access = accessMode; + this.owner = owner; } public HttpConnector(String protocol, XContentParser parser) throws IOException { diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java new file mode 100644 index 0000000000..fc5f9e95ee --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java @@ -0,0 +1,176 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.helper; + +import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; + +import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.CollectionUtils; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.connector.AbstractConnector; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.utils.MLNodeUtils; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.search.builder.SearchSourceBuilder; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class ConnectorAccessControlHelper { + + private volatile Boolean connectorAccessControlEnabled; + + public ConnectorAccessControlHelper(ClusterService clusterService, Settings settings) { + connectorAccessControlEnabled = ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED, it -> connectorAccessControlEnabled = it); + } + + public boolean hasPermission(User user, Connector connector) { + return accessControlNotEnabled(user) + || isAdmin(user) + || previouslyPublicConnector(connector) + || isPublicConnector(connector) + || (isPrivateConnector(connector) && isOwner(user, connector.getOwner())) + || (isRestrictedConnector(connector) && isUserHasBackendRole(user, connector)); + } + + public void validateConnectorAccess(Client client, String connectorId, ActionListener listener) { + User user = RestActionUtils.getUserContext(client); + if (isAdmin(user) || accessControlNotEnabled(user)) { + listener.onResponse(true); + return; + } + GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try ( + XContentParser parser = MLNodeUtils + .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector connector = Connector.createConnector(parser); + boolean hasPermission = hasPermission(user, connector); + wrappedListener.onResponse(hasPermission); + } catch (Exception e) { + log.error("Failed to parse connector:" + connectorId); + wrappedListener.onFailure(e); + } + } else { + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find connector:" + connectorId)); + } + }, e -> { + log.error("Fail to get connector", e); + wrappedListener.onFailure(new IllegalStateException("Fail to get connector:" + connectorId)); + })); + } catch (Exception e) { + log.error("Failed to validate Access for connector:" + connectorId, e); + listener.onFailure(e); + } + + } + + public boolean skipConnectorAccessControl(User user) { + // Case 1: user == null when 1. Security is disabled. 2. When user is super-admin + // Case 2: If Security is enabled and filter is disabled, proceed with search as + // user is already authenticated to hit this API. + // case 3: user is admin which means we don't have to check backend role filtering + return user == null || !connectorAccessControlEnabled || isAdmin(user); + } + + public boolean accessControlNotEnabled(User user) { + return user == null || !connectorAccessControlEnabled; + } + + public boolean isAdmin(User user) { + if (user == null) { + return false; + } + if (CollectionUtils.isEmpty(user.getRoles())) { + return false; + } + return user.getRoles().contains("all_access"); + } + + private boolean isOwner(User owner, User user) { + if (user == null || owner == null) { + return false; + } + return owner.getName().equals(user.getName()); + } + + private boolean isUserHasBackendRole(User user, Connector connector) { + AccessMode modelAccessMode = connector.getAccess(); + return AccessMode.RESTRICTED == modelAccessMode + && (user.getBackendRoles() != null + && connector.getBackendRoles() != null + && connector.getBackendRoles().stream().anyMatch(x -> user.getBackendRoles().contains(x))); + } + + private boolean previouslyPublicConnector(Connector connector) { + return connector.getOwner() == null; + } + + private boolean isPublicConnector(Connector connector) { + return AccessMode.PUBLIC == connector.getAccess(); + } + + private boolean isPrivateConnector(Connector connector) { + return AccessMode.PRIVATE == connector.getAccess(); + } + + private boolean isRestrictedConnector(Connector connector) { + return AccessMode.RESTRICTED == connector.getAccess(); + } + + public SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSourceBuilder searchSourceBuilder) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.should(QueryBuilders.termQuery(AbstractConnector.ACCESS_FIELD, AccessMode.PUBLIC.getValue())); + boolQueryBuilder.should(QueryBuilders.termsQuery(AbstractConnector.BACKEND_ROLES_FIELD + ".keyword", user.getBackendRoles())); + + BoolQueryBuilder privateBoolQuery = new BoolQueryBuilder(); + String ownerName = "owner.name.keyword"; + TermQueryBuilder ownerNameTermQuery = QueryBuilders.termQuery(ownerName, user.getName()); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(AbstractConnector.OWNER_FIELD, ownerNameTermQuery, ScoreMode.None); + privateBoolQuery.must(nestedQueryBuilder); + privateBoolQuery.must(QueryBuilders.termQuery(AbstractConnector.ACCESS_FIELD, AccessMode.PRIVATE.getValue())); + boolQueryBuilder.should(privateBoolQuery); + QueryBuilder query = searchSourceBuilder.query(); + if (query == null) { + searchSourceBuilder.query(boolQueryBuilder); + } else if (query instanceof BoolQueryBuilder) { + ((BoolQueryBuilder) query).filter(boolQueryBuilder); + } else { + BoolQueryBuilder rewriteQuery = new BoolQueryBuilder(); + rewriteQuery.must(query); + rewriteQuery.filter(boolQueryBuilder); + searchSourceBuilder.query(rewriteQuery); + } + return searchSourceBuilder; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 6cb8234cc9..ea122aea77 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -5,8 +5,13 @@ package org.opensearch.ml.settings; +import java.util.List; +import java.util.function.Function; + import org.opensearch.common.settings.Setting; +import com.google.common.collect.ImmutableList; + public final class MLCommonsSettings { private MLCommonsSettings() {} @@ -105,4 +110,21 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_MASTER_SECRET_KEY = Setting .simpleString("plugins.ml_commons.encryption.master_key", "0000000000000000", Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED = Setting + .boolSetting("plugins.ml_commons.connector_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting> ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX = Setting + .listSetting( + "plugins.ml_commons.trusted_connector_endpoints_regex", + ImmutableList + .of( + "^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", + "^https://api\\.openai\\.com/.*$", + "^https://api\\.cohere\\.ai/.*$" + ), + Function.identity(), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); } diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java new file mode 100644 index 0000000000..933cd4d825 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java @@ -0,0 +1,302 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.helper; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; +import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.get.GetResult; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.connector.ConnectorProtocols; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; + +import com.google.common.collect.ImmutableList; + +public class ConnectorAccessControlHelperTests extends OpenSearchTestCase { + + @Mock + ClusterService clusterService; + + @Mock + Client client; + + @Mock + private ActionListener actionListener; + + @Mock + private ThreadPool threadPool; + + ThreadContext threadContext; + + private ConnectorAccessControlHelper connectorAccessControlHelper; + + private GetResponse getResponse; + + private User user; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), true).build(); + threadContext = new ThreadContext(settings); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings); + user = User.parse("mockUser|role-1,role-2|null"); + + getResponse = createGetResponse(null); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + } + + public void test_hasPermission_user_null_return_true() { + HttpConnector httpConnector = mock(HttpConnector.class); + boolean hasPermission = connectorAccessControlHelper.hasPermission(null, httpConnector); + assertTrue(hasPermission); + } + + public void test_hasPermission_connectorAccessControl_not_enabled_return_true() { + HttpConnector httpConnector = mock(HttpConnector.class); + Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), false).build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + ConnectorAccessControlHelper connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user, httpConnector); + assertTrue(hasPermission); + } + + public void test_hasPermission_connectorOwner_is_null_return_true() { + HttpConnector httpConnector = mock(HttpConnector.class); + when(httpConnector.getOwner()).thenReturn(null); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user, httpConnector); + assertTrue(hasPermission); + } + + public void test_hasPermission_user_is_admin_return_true() { + User user = User.parse("admin|role-1|all_access"); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user, mock(HttpConnector.class)); + assertTrue(hasPermission); + } + + public void test_hasPermission_connector_isPublic_return_true() { + HttpConnector httpConnector = mock(HttpConnector.class); + when(httpConnector.getAccess()).thenReturn(AccessMode.PUBLIC); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user, httpConnector); + assertTrue(hasPermission); + } + + public void test_hasPermission_connector_isPrivate_userIsOwner_return_true() { + HttpConnector httpConnector = mock(HttpConnector.class); + when(httpConnector.getAccess()).thenReturn(AccessMode.PRIVATE); + when(httpConnector.getOwner()).thenReturn(user); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user, httpConnector); + assertTrue(hasPermission); + } + + public void test_hasPermission_connector_isPrivate_userIsNotOwner_return_false() { + HttpConnector httpConnector = mock(HttpConnector.class); + when(httpConnector.getAccess()).thenReturn(AccessMode.PRIVATE); + User user1 = User.parse(USER_STRING); + when(httpConnector.getOwner()).thenReturn(user); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user1, httpConnector); + assertFalse(hasPermission); + } + + public void test_hasPermission_connector_isRestricted_userHasBackendRole_return_true() { + HttpConnector httpConnector = mock(HttpConnector.class); + when(httpConnector.getAccess()).thenReturn(AccessMode.RESTRICTED); + when(httpConnector.getBackendRoles()).thenReturn(ImmutableList.of("role-1")); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user, httpConnector); + assertTrue(hasPermission); + } + + public void test_hasPermission_connector_isRestricted_userNotHasBackendRole_return_false() { + HttpConnector httpConnector = mock(HttpConnector.class); + when(httpConnector.getAccess()).thenReturn(AccessMode.RESTRICTED); + when(httpConnector.getBackendRoles()).thenReturn(ImmutableList.of("role-3")); + when(httpConnector.getOwner()).thenReturn(user); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user, httpConnector); + assertFalse(hasPermission); + } + + public void test_validateConnectorAccess_user_isAdmin_return_true() { + String userString = "admin|role-1|all_access"; + Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), true).build(); + ThreadContext threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userString); + + connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); + verify(actionListener).onResponse(true); + } + + public void test_validateConnectorAccess_user_isNotAdmin_hasNoBackendRole_return_false() { + GetResponse getResponse = createGetResponse(ImmutableList.of("role-3")); + Client client = mock(Client.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + + connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); + verify(actionListener).onResponse(false); + } + + public void test_validateConnectorAccess_user_isNotAdmin_hasBackendRole_return_true() { + connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); + verify(actionListener).onResponse(true); + } + + public void test_validateConnectorAccess_connectorNotFound_return_false() { + Client client = mock(Client.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(), any()); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + + connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); + verify(actionListener, times(1)).onFailure(any(MLResourceNotFoundException.class)); + } + + public void test_validateConnectorAccess_searchConnectorException_return_false() { + Client client = mock(Client.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("runtime exception")); + return null; + }).when(client).get(any(), any()); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + + connectorAccessControlHelper.validateConnectorAccess(client, "anyId", actionListener); + verify(actionListener).onFailure(any(IllegalStateException.class)); + } + + public void test_skipConnectorAccessControl_userIsNull_return_true() { + boolean skip = connectorAccessControlHelper.skipConnectorAccessControl(null); + assertTrue(skip); + } + + public void test_skipConnectorAccessControl_connectorAccessControl_notEnabled_return_true() { + Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), false).build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + ConnectorAccessControlHelper connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings); + boolean skip = connectorAccessControlHelper.skipConnectorAccessControl(user); + assertTrue(skip); + } + + public void test_skipConnectorAccessControl_userIsAdmin_return_true() { + User user = User.parse("admin|role-1|all_access"); + boolean skip = connectorAccessControlHelper.skipConnectorAccessControl(user); + assertTrue(skip); + } + + public void test_accessControlNotEnabled_connectorAccessControl_notEnabled_return_true() { + Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), false).build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + ConnectorAccessControlHelper connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings); + boolean skip = connectorAccessControlHelper.accessControlNotEnabled(user); + assertTrue(skip); + } + + public void test_accessControlNotEnabled_userIsNull_return_true() { + boolean notEnabled = connectorAccessControlHelper.accessControlNotEnabled(null); + assertTrue(notEnabled); + } + + public void test_addUserBackendRolesFilter_nullQuery() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + SearchSourceBuilder result = connectorAccessControlHelper.addUserBackendRolesFilter(user, searchSourceBuilder); + assertNotNull(result); + } + + public void test_addUserBackendRolesFilter_boolQuery() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(new BoolQueryBuilder()); + SearchSourceBuilder result = connectorAccessControlHelper.addUserBackendRolesFilter(user, searchSourceBuilder); + assertEquals("bool", result.query().getName()); + } + + public void test_addUserBackendRolesFilter_nonBoolQuery() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(new MatchAllQueryBuilder()); + SearchSourceBuilder result = connectorAccessControlHelper.addUserBackendRolesFilter(user, searchSourceBuilder); + assertEquals("bool", result.query().getName()); + } + + private GetResponse createGetResponse(List backendRoles) { + HttpConnector httpConnector = HttpConnector + .builder() + .name("testConnector") + .protocol(ConnectorProtocols.HTTP) + .owner(user) + .description("This is test connector") + .backendRoles(Optional.ofNullable(backendRoles).orElse(ImmutableList.of("role-1"))) + .accessMode(AccessMode.RESTRICTED) + .build(); + XContentBuilder content = null; + try { + content = httpConnector.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + } catch (IOException e) { + throw new RuntimeException(e); + } + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult(CommonValue.ML_MODEL_GROUP_INDEX, "111", 111l, 111l, 111l, true, bytesReference, null, null); + return new GetResponse(getResult); + } +}