Skip to content

Commit

Permalink
Add basic connector access control classes (#1055)
Browse files Browse the repository at this point in the history
* Add basic connector access control classes

Signed-off-by: zane-neo <[email protected]>

* Add basic connector access control classes and fformat code

Signed-off-by: zane-neo <[email protected]>

---------

Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo authored Jul 10, 2023
1 parent 50e0729 commit 11b72bc
Show file tree
Hide file tree
Showing 5 changed files with 505 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import lombok.extern.log4j.Log4j2;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;

Expand All @@ -30,8 +31,8 @@ public class AwsConnector extends HttpConnector {
@Builder(builderMethodName = "awsConnectorBuilder")
public AwsConnector(String name, String description, String version, String protocol,
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
List<String> backendRoles, AccessMode accessMode) {
super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode);
List<String> backendRoles, AccessMode accessMode, User owner) {
super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode, owner);
validate();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public class HttpConnector extends AbstractConnector {
@Builder
public HttpConnector(String name, String description, String version, String protocol,
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
List<String> backendRoles, AccessMode accessMode) {
List<String> backendRoles, AccessMode accessMode, User owner) {
validateProtocol(protocol);
this.name = name;
this.description = description;
Expand All @@ -60,6 +60,7 @@ public HttpConnector(String name, String description, String version, String pro
this.actions = actions;
this.backendRoles = backendRoles;
this.access = accessMode;
this.owner = owner;
}

public HttpConnector(String protocol, XContentParser parser) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.helper;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED;

import org.apache.lucene.search.join.ScoreMode;
import org.opensearch.action.ActionListener;
import org.opensearch.action.get.GetRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.CollectionUtils;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.NestedQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.connector.AbstractConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.builder.SearchSourceBuilder;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class ConnectorAccessControlHelper {

private volatile Boolean connectorAccessControlEnabled;

public ConnectorAccessControlHelper(ClusterService clusterService, Settings settings) {
connectorAccessControlEnabled = ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED, it -> connectorAccessControlEnabled = it);
}

public boolean hasPermission(User user, Connector connector) {
return accessControlNotEnabled(user)
|| isAdmin(user)
|| previouslyPublicConnector(connector)
|| isPublicConnector(connector)
|| (isPrivateConnector(connector) && isOwner(user, connector.getOwner()))
|| (isRestrictedConnector(connector) && isUserHasBackendRole(user, connector));
}

public void validateConnectorAccess(Client client, String connectorId, ActionListener<Boolean> listener) {
User user = RestActionUtils.getUserContext(client);
if (isAdmin(user) || accessControlNotEnabled(user)) {
listener.onResponse(true);
return;
}
GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Boolean> wrappedListener = ActionListener.runBefore(listener, context::restore);
client.get(getRequest, ActionListener.wrap(r -> {
if (r != null && r.isExists()) {
try (
XContentParser parser = MLNodeUtils
.createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef())
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Connector connector = Connector.createConnector(parser);
boolean hasPermission = hasPermission(user, connector);
wrappedListener.onResponse(hasPermission);
} catch (Exception e) {
log.error("Failed to parse connector:" + connectorId);
wrappedListener.onFailure(e);
}
} else {
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find connector:" + connectorId));
}
}, e -> {
log.error("Fail to get connector", e);
wrappedListener.onFailure(new IllegalStateException("Fail to get connector:" + connectorId));
}));
} catch (Exception e) {
log.error("Failed to validate Access for connector:" + connectorId, e);
listener.onFailure(e);
}

}

public boolean skipConnectorAccessControl(User user) {
// Case 1: user == null when 1. Security is disabled. 2. When user is super-admin
// Case 2: If Security is enabled and filter is disabled, proceed with search as
// user is already authenticated to hit this API.
// case 3: user is admin which means we don't have to check backend role filtering
return user == null || !connectorAccessControlEnabled || isAdmin(user);
}

public boolean accessControlNotEnabled(User user) {
return user == null || !connectorAccessControlEnabled;
}

public boolean isAdmin(User user) {
if (user == null) {
return false;
}
if (CollectionUtils.isEmpty(user.getRoles())) {
return false;
}
return user.getRoles().contains("all_access");
}

private boolean isOwner(User owner, User user) {
if (user == null || owner == null) {
return false;
}
return owner.getName().equals(user.getName());
}

private boolean isUserHasBackendRole(User user, Connector connector) {
AccessMode modelAccessMode = connector.getAccess();
return AccessMode.RESTRICTED == modelAccessMode
&& (user.getBackendRoles() != null
&& connector.getBackendRoles() != null
&& connector.getBackendRoles().stream().anyMatch(x -> user.getBackendRoles().contains(x)));
}

private boolean previouslyPublicConnector(Connector connector) {
return connector.getOwner() == null;
}

private boolean isPublicConnector(Connector connector) {
return AccessMode.PUBLIC == connector.getAccess();
}

private boolean isPrivateConnector(Connector connector) {
return AccessMode.PRIVATE == connector.getAccess();
}

private boolean isRestrictedConnector(Connector connector) {
return AccessMode.RESTRICTED == connector.getAccess();
}

public SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSourceBuilder searchSourceBuilder) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(QueryBuilders.termQuery(AbstractConnector.ACCESS_FIELD, AccessMode.PUBLIC.getValue()));
boolQueryBuilder.should(QueryBuilders.termsQuery(AbstractConnector.BACKEND_ROLES_FIELD + ".keyword", user.getBackendRoles()));

BoolQueryBuilder privateBoolQuery = new BoolQueryBuilder();
String ownerName = "owner.name.keyword";
TermQueryBuilder ownerNameTermQuery = QueryBuilders.termQuery(ownerName, user.getName());
NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(AbstractConnector.OWNER_FIELD, ownerNameTermQuery, ScoreMode.None);
privateBoolQuery.must(nestedQueryBuilder);
privateBoolQuery.must(QueryBuilders.termQuery(AbstractConnector.ACCESS_FIELD, AccessMode.PRIVATE.getValue()));
boolQueryBuilder.should(privateBoolQuery);
QueryBuilder query = searchSourceBuilder.query();
if (query == null) {
searchSourceBuilder.query(boolQueryBuilder);
} else if (query instanceof BoolQueryBuilder) {
((BoolQueryBuilder) query).filter(boolQueryBuilder);
} else {
BoolQueryBuilder rewriteQuery = new BoolQueryBuilder();
rewriteQuery.must(query);
rewriteQuery.filter(boolQueryBuilder);
searchSourceBuilder.query(rewriteQuery);
}
return searchSourceBuilder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@

package org.opensearch.ml.settings;

import java.util.List;
import java.util.function.Function;

import org.opensearch.common.settings.Setting;

import com.google.common.collect.ImmutableList;

public final class MLCommonsSettings {

private MLCommonsSettings() {}
Expand Down Expand Up @@ -105,4 +110,21 @@ private MLCommonsSettings() {}

public static final Setting<String> ML_COMMONS_MASTER_SECRET_KEY = Setting
.simpleString("plugins.ml_commons.encryption.master_key", "0000000000000000", Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED = Setting
.boolSetting("plugins.ml_commons.connector_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<List<String>> ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX = Setting
.listSetting(
"plugins.ml_commons.trusted_connector_endpoints_regex",
ImmutableList
.of(
"^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://api\\.openai\\.com/.*$",
"^https://api\\.cohere\\.ai/.*$"
),
Function.identity(),
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
}
Loading

0 comments on commit 11b72bc

Please sign in to comment.