Skip to content

Commit

Permalink
Allow returning multiple filters and masks in ConnectorAccessControl
Browse files Browse the repository at this point in the history
Now that the `SystemAccessControl` can provide multiple filtering and
masking expressions, there's no reason for the `ConnectorAccessControl`
not to follow suit.
  • Loading branch information
ksobolew authored and kokosing committed Apr 3, 2022
1 parent ae66a8b commit 827de57
Show file tree
Hide file tree
Showing 17 changed files with 148 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1165,8 +1165,8 @@ public List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObje
CatalogAccessControlEntry entry = getConnectorAccessControl(context.getTransactionId(), tableName.getCatalogName());

if (entry != null) {
entry.getAccessControl().getRowFilter(entry.toConnectorSecurityContext(context), tableName.asSchemaTableName())
.ifPresent(filters::add);
entry.getAccessControl().getRowFilters(entry.toConnectorSecurityContext(context), tableName.asSchemaTableName())
.forEach(filters::add);
}

for (SystemAccessControl systemAccessControl : getSystemAccessControls()) {
Expand All @@ -1188,8 +1188,8 @@ public List<ViewExpression> getColumnMasks(SecurityContext context, QualifiedObj
// connector-provided masks take precedence over global masks
CatalogAccessControlEntry entry = getConnectorAccessControl(context.getTransactionId(), tableName.getCatalogName());
if (entry != null) {
entry.getAccessControl().getColumnMask(entry.toConnectorSecurityContext(context), tableName.asSchemaTableName(), columnName, type)
.ifPresent(masks::add);
entry.getAccessControl().getColumnMasks(entry.toConnectorSecurityContext(context), tableName.asSchemaTableName(), columnName, type)
.forEach(masks::add);
}

for (SystemAccessControl systemAccessControl : getSystemAccessControls()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.security;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.QualifiedObjectName;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.CatalogSchemaName;
Expand All @@ -26,6 +27,7 @@
import io.trino.spi.security.ViewExpression;
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -445,21 +447,21 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
checkArgument(context == null, "context must be null");
if (accessControl.getRowFilters(securityContext, new QualifiedObjectName(catalogName, tableName.getSchemaName(), tableName.getTableName())).isEmpty()) {
return Optional.empty();
return ImmutableList.of();
}
throw new TrinoException(NOT_SUPPORTED, "Row filtering not supported");
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
checkArgument(context == null, "context must be null");
if (accessControl.getColumnMasks(securityContext, new QualifiedObjectName(catalogName, tableName.getSchemaName(), tableName.getTableName()), columnName, type).isEmpty()) {
return Optional.empty();
return ImmutableList.of();
}
throw new TrinoException(NOT_SUPPORTED, "Column masking not supported");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.connector;

import com.google.common.collect.ImmutableList;
import io.trino.plugin.base.security.AllowAllAccessControl;
import io.trino.spi.connector.ConnectorSecurityContext;
import io.trino.spi.connector.SchemaTableName;
Expand All @@ -23,6 +24,7 @@
import io.trino.spi.type.Type;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
Expand Down Expand Up @@ -120,15 +122,19 @@ public void checkCanRevokeTablePrivilege(ConnectorSecurityContext context, Privi
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
return Optional.ofNullable(rowFilters.apply(tableName));
return Optional.ofNullable(rowFilters.apply(tableName))
.map(ImmutableList::of)
.orElseGet(ImmutableList::of);
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return Optional.ofNullable(columnMasks.apply(tableName, columnName));
return Optional.ofNullable(columnMasks.apply(tableName, columnName))
.map(ImmutableList::of)
.orElseGet(ImmutableList::of);
}

public void grantSchemaPrivileges(String schemaName, Set<Privilege> privileges, TrinoPrincipal grantee, boolean grantOption)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,9 @@ public void checkCanSetSystemSessionProperty(SystemSecurityContext context, Stri
accessControlManager.addCatalogAccessControl(new CatalogName("catalog"), new ConnectorAccessControl()
{
@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String column, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String column, Type type)
{
return Optional.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "connector mask"));
return ImmutableList.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "connector mask"));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
*/
package io.trino.security;

import com.google.common.collect.ImmutableSet;
import io.trino.spi.connector.ConnectorAccessControl;
import io.trino.spi.connector.ConnectorSecurityContext;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.type.Type;
import org.testng.annotations.Test;

import static io.trino.spi.testing.InterfaceTestUtils.assertAllMethodsOverridden;
Expand All @@ -22,7 +26,10 @@ public class TestInjectedConnectorAccessControl
{
@Test
public void testEverythingImplemented()
throws NoSuchMethodException
{
assertAllMethodsOverridden(ConnectorAccessControl.class, InjectedConnectorAccessControl.class);
assertAllMethodsOverridden(ConnectorAccessControl.class, InjectedConnectorAccessControl.class, ImmutableSet.of(
InjectedConnectorAccessControl.class.getMethod("getRowFilter", ConnectorSecurityContext.class, SchemaTableName.class),
InjectedConnectorAccessControl.class.getMethod("getColumnMask", ConnectorSecurityContext.class, SchemaTableName.class, String.class, Type.class)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.trino.spi.security.ViewExpression;
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -74,6 +75,7 @@
import static io.trino.spi.security.AccessDeniedException.denyShowTables;
import static io.trino.spi.security.AccessDeniedException.denyTruncateTable;
import static io.trino.spi.security.AccessDeniedException.denyUpdateTableColumns;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptySet;

public interface ConnectorAccessControl
Expand Down Expand Up @@ -600,22 +602,51 @@ default void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sch
* The filter must be a scalar SQL expression of boolean type over the columns in the table.
*
* @return the filter, or {@link Optional#empty()} if not applicable
* @deprecated use {@link #getRowFilters(ConnectorSecurityContext, SchemaTableName)} instead
*/
@Deprecated
default Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
{
return Optional.empty();
}

/**
* Get row filters associated with the given table and identity.
* <p>
* Each filter must be a scalar SQL expression of boolean type over the columns in the table.
*
* @return the list of filters, or empty list if not applicable
*/
default List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
return emptyList();
}

/**
* Get a column mask associated with the given table, column and identity.
* <p>
* The mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression
* must be written in terms of columns in the table.
*
* @return the mask, or {@link Optional#empty()} if not applicable
* @deprecated use {@link #getColumnMasks(ConnectorSecurityContext, SchemaTableName, String, Type)} instead
*/
@Deprecated
default Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return Optional.empty();
}

/**
* Get column masks associated with the given table, column and identity.
* <p>
* Each mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression
* must be written in terms of columns in the table.
*
* @return the list of masks, or empty list if not applicable
*/
default List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return emptyList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import javax.inject.Inject;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -493,18 +494,18 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
return delegate.getRowFilter(context, tableName);
return delegate.getRowFilters(context, tableName);
}
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
return delegate.getColumnMask(context, tableName, columnName, type);
return delegate.getColumnMasks(context, tableName, columnName, type);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.base.security;

import com.google.common.collect.ImmutableList;
import io.trino.spi.connector.ConnectorAccessControl;
import io.trino.spi.connector.ConnectorSecurityContext;
import io.trino.spi.connector.SchemaRoutineName;
Expand All @@ -22,6 +23,7 @@
import io.trino.spi.security.ViewExpression;
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -316,14 +318,14 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
return Optional.empty();
return ImmutableList.of();
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return Optional.empty();
return ImmutableList.of();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.base.security;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.plugin.base.CatalogName;
import io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege;
Expand All @@ -31,10 +32,10 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.DELETE;
import static io.trino.plugin.base.security.TableAccessControlRule.TablePrivilege.GRANT_SELECT;
Expand Down Expand Up @@ -591,33 +592,37 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
if (INFORMATION_SCHEMA_NAME.equals(tableName.getSchemaName())) {
return Optional.empty();
return ImmutableList.of();
}

ConnectorIdentity identity = context.getIdentity();
return tableRules.stream()
.filter(rule -> rule.matches(identity.getUser(), identity.getEnabledSystemRoles(), identity.getGroups(), tableName))
.map(rule -> rule.getFilter(identity.getUser(), catalogName, tableName.getSchemaName()))
.findFirst()
.flatMap(Function.identity());
.flatMap(Optional::stream)
// we return the first one we find
.limit(1)
.collect(toImmutableList());
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
if (INFORMATION_SCHEMA_NAME.equals(tableName.getSchemaName())) {
return Optional.empty();
return ImmutableList.of();
}

ConnectorIdentity identity = context.getIdentity();
return tableRules.stream()
.filter(rule -> rule.matches(identity.getUser(), identity.getEnabledSystemRoles(), identity.getGroups(), tableName))
.map(rule -> rule.getColumnMask(identity.getUser(), catalogName, tableName.getSchemaName(), columnName))
.findFirst()
.flatMap(Function.identity());
.flatMap(Optional::stream)
// we return the first one we find
.limit(1)
.collect(toImmutableList());
}

private boolean canSetSessionProperty(ConnectorSecurityContext context, String property)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.trino.spi.security.ViewExpression;
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -391,9 +392,21 @@ public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, S
return delegate().getRowFilter(context, tableName);
}

@Override
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
return delegate().getRowFilters(context, tableName);
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return delegate().getColumnMask(context, tableName, columnName, type);
}

@Override
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return delegate().getColumnMasks(context, tableName, columnName, type);
}
}
Loading

0 comments on commit 827de57

Please sign in to comment.