Skip to content

Commit

Permalink
Allow returning multiple filters and masks in SystemAccessControl
Browse files Browse the repository at this point in the history
The engine is perfectly capable of processing multiple row filter and
column mask expressions, given that it supports running multiple system
access controls and each of them can provide an expression. See:
`io.trino.security.AccessControlManager#getColumnMasks` and
`io.trino.security.AccessControlManager#getRowFilters`. So inability to
provide more than one expression per access control looks like an
artificial restriction. Case in point: the file-based access control,
which is also already capable of providing multiple expressions for both
row filters and column masks (it just picked the first one and discarded
the rest).
  • Loading branch information
ksobolew committed Apr 1, 2022
1 parent 8018c92 commit 212b55d
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1170,8 +1170,8 @@ public List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObje
}

for (SystemAccessControl systemAccessControl : getSystemAccessControls()) {
systemAccessControl.getRowFilter(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName())
.ifPresent(filters::add);
systemAccessControl.getRowFilters(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName())
.forEach(filters::add);
}

return filters.build();
Expand All @@ -1193,8 +1193,8 @@ public List<ViewExpression> getColumnMasks(SecurityContext context, QualifiedObj
}

for (SystemAccessControl systemAccessControl : getSystemAccessControls()) {
systemAccessControl.getColumnMask(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName(), columnName, type)
.ifPresent(masks::add);
systemAccessControl.getColumnMasks(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName(), columnName, type)
.forEach(masks::add);
}

return masks.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ public SystemAccessControl create(Map<String, String> config)
return new SystemAccessControl()
{
@Override
public Optional<ViewExpression> getColumnMask(SystemSecurityContext context, CatalogSchemaTableName tableName, String column, Type type)
public List<ViewExpression> getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName tableName, String column, Type type)
{
return Optional.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "system mask"));
return ImmutableList.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "system mask"));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import java.security.Principal;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -823,25 +824,54 @@ default void checkCanExecuteTableProcedure(SystemSecurityContext systemSecurityC
* The filter must be a scalar SQL expression of boolean type over the columns in the table.
*
* @return the filter, or {@link Optional#empty()} if not applicable
* @deprecated use {@link #getRowFilters(SystemSecurityContext, CatalogSchemaTableName)} instead
*/
@Deprecated
default Optional<ViewExpression> getRowFilter(SystemSecurityContext context, CatalogSchemaTableName tableName)
{
return Optional.empty();
}

/**
* Get row filters associated with the given table and identity.
* <p>
* Each filter must be a scalar SQL expression of boolean type over the columns in the table.
*
* @return the list of filters, or empty list if not applicable
*/
default List<ViewExpression> getRowFilters(SystemSecurityContext context, CatalogSchemaTableName tableName)
{
return getRowFilter(context, tableName).map(List::of).orElseGet(List::of);
}

/**
* Get a column mask associated with the given table, column and identity.
* <p>
* The mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression
* must be written in terms of columns in the table.
*
* @return the mask, or {@link Optional#empty()} if not applicable
* @deprecated use {@link #getColumnMasks(SystemSecurityContext, CatalogSchemaTableName, String, Type)} instead
*/
@Deprecated
default Optional<ViewExpression> getColumnMask(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type)
{
return Optional.empty();
}

/**
* Get column masks associated with the given table, column and identity.
* <p>
* Each mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression
* must be written in terms of columns in the table.
*
* @return the list of masks, or empty list if not applicable
*/
default List<ViewExpression> getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type)
{
return getColumnMask(context, tableName, columnName, type).map(List::of).orElseGet(List::of);
}

/**
* @return the event listeners provided by this system access control
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@

import java.security.Principal;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;

public class AllowAllSystemAccessControl
Expand Down Expand Up @@ -435,14 +437,14 @@ public Iterable<EventListener> getEventListeners()
}

@Override
public Optional<ViewExpression> getRowFilter(SystemSecurityContext context, CatalogSchemaTableName tableName)
public List<ViewExpression> getRowFilters(SystemSecurityContext context, CatalogSchemaTableName tableName)
{
return Optional.empty();
return emptyList();
}

@Override
public Optional<ViewExpression> getColumnMask(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type)
{
return Optional.empty();
return emptyList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.regex.Pattern;

import static com.google.common.base.Suppliers.memoizeWithExpiration;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.airlift.configuration.ConfigBinder.configBinder;
import static io.trino.plugin.base.security.CatalogAccessControlRule.AccessMode.ALL;
Expand Down Expand Up @@ -950,35 +950,39 @@ public Iterable<EventListener> getEventListeners()
}

@Override
public Optional<ViewExpression> getRowFilter(SystemSecurityContext context, CatalogSchemaTableName table)
public List<ViewExpression> getRowFilters(SystemSecurityContext context, CatalogSchemaTableName table)
{
SchemaTableName tableName = table.getSchemaTableName();
if (INFORMATION_SCHEMA_NAME.equals(tableName.getSchemaName())) {
return Optional.empty();
return ImmutableList.of();
}

Identity identity = context.getIdentity();
return tableRules.stream()
.filter(rule -> rule.matches(identity.getUser(), identity.getEnabledRoles(), identity.getGroups(), table))
.map(rule -> rule.getFilter(identity.getUser(), table.getCatalogName(), tableName.getSchemaName()))
.findFirst()
.flatMap(Function.identity());
.flatMap(Optional::stream)
// we return the first one we find
.limit(1)
.collect(toImmutableList());
}

@Override
public Optional<ViewExpression> getColumnMask(SystemSecurityContext context, CatalogSchemaTableName table, String columnName, Type type)
public List<ViewExpression> getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName table, String columnName, Type type)
{
SchemaTableName tableName = table.getSchemaTableName();
if (INFORMATION_SCHEMA_NAME.equals(tableName.getSchemaName())) {
return Optional.empty();
return ImmutableList.of();
}

Identity identity = context.getIdentity();
return tableRules.stream()
.filter(rule -> rule.matches(identity.getUser(), identity.getEnabledRoles(), identity.getGroups(), table))
.map(rule -> rule.getColumnMask(identity.getUser(), table.getCatalogName(), table.getSchemaTableName().getSchemaName(), columnName))
.findFirst()
.flatMap(Function.identity());
.flatMap(Optional::stream)
// we return the first one we find
.limit(1)
.collect(toImmutableList());
}

private boolean checkAnyCatalogAccess(SystemSecurityContext context, String catalogName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import java.security.Principal;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -485,9 +486,21 @@ public Optional<ViewExpression> getRowFilter(SystemSecurityContext context, Cata
return delegate().getRowFilter(context, tableName);
}

@Override
public List<ViewExpression> getRowFilters(SystemSecurityContext context, CatalogSchemaTableName tableName)
{
return delegate().getRowFilters(context, tableName);
}

@Override
public Optional<ViewExpression> getColumnMask(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type)
{
return delegate().getColumnMask(context, tableName, columnName, type);
}

@Override
public List<ViewExpression> getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type)
{
return delegate().getColumnMasks(context, tableName, columnName, type);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
*/
package io.trino.plugin.base.security;

import com.google.common.collect.ImmutableSet;
import io.trino.spi.connector.CatalogSchemaTableName;
import io.trino.spi.security.SystemAccessControl;
import io.trino.spi.security.SystemSecurityContext;
import io.trino.spi.type.Type;
import org.testng.annotations.Test;

import static io.trino.spi.testing.InterfaceTestUtils.assertAllMethodsOverridden;
Expand All @@ -24,6 +28,8 @@ public class TestAllowAllSystemAccessControl
public void testEverythingImplemented()
throws ReflectiveOperationException
{
assertAllMethodsOverridden(SystemAccessControl.class, AllowAllSystemAccessControl.class);
assertAllMethodsOverridden(SystemAccessControl.class, AllowAllSystemAccessControl.class, ImmutableSet.of(
AllowAllSystemAccessControl.class.getMethod("getRowFilter", SystemSecurityContext.class, CatalogSchemaTableName.class),
AllowAllSystemAccessControl.class.getMethod("getColumnMask", SystemSecurityContext.class, CatalogSchemaTableName.class, String.class, Type.class)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.base.security;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
Expand All @@ -28,6 +29,7 @@
import io.trino.spi.security.SystemSecurityContext;
import io.trino.spi.security.TrinoPrincipal;
import io.trino.spi.security.ViewExpression;
import io.trino.spi.type.Type;
import org.assertj.core.api.ThrowableAssert.ThrowingCallable;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
Expand All @@ -37,6 +39,7 @@
import java.io.File;
import java.util.Collection;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand All @@ -54,7 +57,6 @@
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.util.Files.newTemporaryFile;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;

public class TestFileBasedSystemAccessControl
{
Expand Down Expand Up @@ -1267,23 +1269,23 @@ public void testGetColumnMask()
SystemAccessControl accessControl = newFileBasedSystemAccessControl("file-based-system-access-table.json");

assertEquals(
accessControl.getColumnMask(
accessControl.getColumnMasks(
ALICE,
new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns"),
"masked",
VARCHAR),
Optional.empty());
ImmutableList.of());

assertViewExpressionEquals(
accessControl.getColumnMask(
accessControl.getColumnMasks(
CHARLIE,
new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns"),
"masked",
VARCHAR),
new ViewExpression(CHARLIE.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("bobschema"), "'mask'"));

assertViewExpressionEquals(
accessControl.getColumnMask(
accessControl.getColumnMasks(
CHARLIE,
new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns"),
"masked_with_user",
Expand All @@ -1297,22 +1299,22 @@ public void testGetRowFilter()
SystemAccessControl accessControl = newFileBasedSystemAccessControl("file-based-system-access-table.json");

assertEquals(
accessControl.getRowFilter(ALICE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns")),
Optional.empty());
accessControl.getRowFilters(ALICE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns")),
ImmutableList.of());

assertViewExpressionEquals(
accessControl.getRowFilter(CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns")),
accessControl.getRowFilters(CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns")),
new ViewExpression(CHARLIE.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("bobschema"), "starts_with(value, 'filter')"));

assertViewExpressionEquals(
accessControl.getRowFilter(CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns_with_grant")),
accessControl.getRowFilters(CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns_with_grant")),
new ViewExpression("filter-user", Optional.of("some-catalog"), Optional.of("bobschema"), "starts_with(value, 'filter-with-user')"));
}

private static void assertViewExpressionEquals(Optional<ViewExpression> result, ViewExpression expected)
private static void assertViewExpressionEquals(List<ViewExpression> result, ViewExpression expected)
{
assertTrue(result.isPresent());
ViewExpression actual = result.get();
assertEquals(result.size(), 1);
ViewExpression actual = result.get(0);
assertEquals(actual.getIdentity(), expected.getIdentity(), "Identity");
assertEquals(actual.getCatalog(), expected.getCatalog(), "Catalog");
assertEquals(actual.getSchema(), expected.getSchema(), "Schema");
Expand All @@ -1324,6 +1326,8 @@ public void testEverythingImplemented()
throws NoSuchMethodException
{
assertAllMethodsOverridden(SystemAccessControl.class, FileBasedSystemAccessControl.class, ImmutableSet.of(
FileBasedSystemAccessControl.class.getMethod("getRowFilter", SystemSecurityContext.class, CatalogSchemaTableName.class),
FileBasedSystemAccessControl.class.getMethod("getColumnMask", SystemSecurityContext.class, CatalogSchemaTableName.class, String.class, Type.class),
FileBasedSystemAccessControl.class.getMethod("checkCanViewQueryOwnedBy", SystemSecurityContext.class, Identity.class),
FileBasedSystemAccessControl.class.getMethod("filterViewQueryOwnedBy", SystemSecurityContext.class, Collection.class),
FileBasedSystemAccessControl.class.getMethod("checkCanKillQueryOwnedBy", SystemSecurityContext.class, Identity.class)));
Expand Down

0 comments on commit 212b55d

Please sign in to comment.