Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic connector access control classes #1055

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import lombok.extern.log4j.Log4j2;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;

Expand All @@ -28,8 +29,8 @@ public class AwsConnector extends HttpConnector {
@Builder(builderMethodName = "awsConnectorBuilder")
public AwsConnector(String name, String description, String version, String protocol,
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
List<String> backendRoles, AccessMode accessMode) {
super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode);
List<String> backendRoles, AccessMode accessMode, User owner) {
super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode, owner);
validate();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public class HttpConnector extends AbstractConnector {
@Builder
public HttpConnector(String name, String description, String version, String protocol,
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
List<String> backendRoles, AccessMode accessMode) {
List<String> backendRoles, AccessMode accessMode, User owner) {
validateProtocol(protocol);
this.name = name;
this.description = description;
Expand All @@ -58,6 +58,7 @@ public HttpConnector(String name, String description, String version, String pro
this.actions = actions;
this.backendRoles = backendRoles;
this.access = accessMode;
this.owner = owner;
}

public HttpConnector(String protocol, XContentParser parser) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.helper;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED;

import org.apache.lucene.search.join.ScoreMode;
import org.opensearch.action.ActionListener;
import org.opensearch.action.get.GetRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.CollectionUtils;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.NestedQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.connector.AbstractConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.builder.SearchSourceBuilder;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class ConnectorAccessControlHelper {

private volatile Boolean connectorAccessControlEnabled;

public ConnectorAccessControlHelper(ClusterService clusterService, Settings settings) {
connectorAccessControlEnabled = ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED, it -> connectorAccessControlEnabled = it);
}

public boolean hasPermission(User user, Connector connector) {
return accessControlNotEnabled(user)
|| isAdmin(user)
|| previouslyPublicConnector(connector)
|| isPublicConnector(connector)
|| (isPrivateConnector(connector) && isOwner(user, connector.getOwner()))
|| (isRestrictedConnector(connector) && isUserHasBackendRole(user, connector));
}

public void validateConnectorAccess(Client client, String connectorId, ActionListener<Boolean> listener) {
User user = RestActionUtils.getUserContext(client);
if (isAdmin(user) || accessControlNotEnabled(user)) {
listener.onResponse(true);
return;
}
GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Boolean> wrappedListener = ActionListener.runBefore(listener, context::restore);
client.get(getRequest, ActionListener.wrap(r -> {
if (r != null && r.isExists()) {
try (
XContentParser parser = MLNodeUtils
.createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef())
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Connector connector = Connector.createConnector(parser);
boolean hasPermission = hasPermission(user, connector);
wrappedListener.onResponse(hasPermission);
} catch (Exception e) {
log.error("Failed to parse connector:" + connectorId);
wrappedListener.onFailure(e);
}
} else {
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find connector:" + connectorId));
}
}, e -> {
log.error("Fail to get connector", e);
wrappedListener.onFailure(new IllegalStateException("Fail to get connector:" + connectorId));
}));
} catch (Exception e) {
log.error("Failed to validate Access for connector:" + connectorId, e);
listener.onFailure(e);
}

}

public boolean skipConnectorAccessControl(User user) {
// Case 1: user == null when 1. Security is disabled. 2. When user is super-admin
// Case 2: If Security is enabled and filter is disabled, proceed with search as
// user is already authenticated to hit this API.
// case 3: user is admin which means we don't have to check backend role filtering
return user == null || !connectorAccessControlEnabled || isAdmin(user);
}

public boolean accessControlNotEnabled(User user) {
return user == null || !connectorAccessControlEnabled;
}

public boolean isAdmin(User user) {
if (user == null) {
return false;
}
if (CollectionUtils.isEmpty(user.getRoles())) {
return false;
}
return user.getRoles().contains("all_access");
}

private boolean isOwner(User owner, User user) {
if (user == null || owner == null) {
return false;
}
return owner.getName().equals(user.getName());
}

private boolean isUserHasBackendRole(User user, Connector connector) {
AccessMode modelAccessMode = connector.getAccess();
return AccessMode.RESTRICTED == modelAccessMode
&& (user.getBackendRoles() != null
&& connector.getBackendRoles() != null
&& connector.getBackendRoles().stream().anyMatch(x -> user.getBackendRoles().contains(x)));
}

private boolean previouslyPublicConnector(Connector connector) {
return connector.getOwner() == null;
}

private boolean isPublicConnector(Connector connector) {
return AccessMode.PUBLIC == connector.getAccess();
}

private boolean isPrivateConnector(Connector connector) {
return AccessMode.PRIVATE == connector.getAccess();
}

private boolean isRestrictedConnector(Connector connector) {
return AccessMode.RESTRICTED == connector.getAccess();
}

public SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSourceBuilder searchSourceBuilder) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(QueryBuilders.termQuery(AbstractConnector.ACCESS_FIELD, AccessMode.PUBLIC.getValue()));
boolQueryBuilder.should(QueryBuilders.termsQuery(AbstractConnector.BACKEND_ROLES_FIELD + ".keyword", user.getBackendRoles()));

BoolQueryBuilder privateBoolQuery = new BoolQueryBuilder();
String ownerName = "owner.name.keyword";
TermQueryBuilder ownerNameTermQuery = QueryBuilders.termQuery(ownerName, user.getName());
NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(AbstractConnector.OWNER_FIELD, ownerNameTermQuery, ScoreMode.None);
privateBoolQuery.must(nestedQueryBuilder);
privateBoolQuery.must(QueryBuilders.termQuery(AbstractConnector.ACCESS_FIELD, AccessMode.PRIVATE.getValue()));
boolQueryBuilder.should(privateBoolQuery);
QueryBuilder query = searchSourceBuilder.query();
if (query == null) {
searchSourceBuilder.query(boolQueryBuilder);
} else if (query instanceof BoolQueryBuilder) {
((BoolQueryBuilder) query).filter(boolQueryBuilder);
} else {
BoolQueryBuilder rewriteQuery = new BoolQueryBuilder();
rewriteQuery.must(query);
rewriteQuery.filter(boolQueryBuilder);
searchSourceBuilder.query(rewriteQuery);
}
return searchSourceBuilder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@

package org.opensearch.ml.settings;

import java.util.List;
import java.util.function.Function;

import org.opensearch.common.settings.Setting;

import com.google.common.collect.ImmutableList;

public final class MLCommonsSettings {

private MLCommonsSettings() {}
Expand Down Expand Up @@ -105,4 +110,21 @@ private MLCommonsSettings() {}

public static final Setting<String> ML_COMMONS_MASTER_SECRET_KEY = Setting
.simpleString("plugins.ml_commons.encryption.master_key", "0000000000000000", Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED = Setting
.boolSetting("plugins.ml_commons.connector_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<List<String>> ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX = Setting
.listSetting(
"plugins.ml_commons.trusted_connector_endpoints_regex",
ImmutableList
.of(
"^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$",
"^https://api\\.openai\\.com/.*$",
"^https://api\\.cohere\\.ai/.*$"
),
Function.identity(),
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
}
Loading