Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow multiple filters and masks from access control #11654

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1165,13 +1165,13 @@ public List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObje
CatalogAccessControlEntry entry = getConnectorAccessControl(context.getTransactionId(), tableName.getCatalogName());

if (entry != null) {
entry.getAccessControl().getRowFilter(entry.toConnectorSecurityContext(context), tableName.asSchemaTableName())
.ifPresent(filters::add);
entry.getAccessControl().getRowFilters(entry.toConnectorSecurityContext(context), tableName.asSchemaTableName())
.forEach(filters::add);
}

for (SystemAccessControl systemAccessControl : getSystemAccessControls()) {
systemAccessControl.getRowFilter(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName())
.ifPresent(filters::add);
systemAccessControl.getRowFilters(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName())
.forEach(filters::add);
}

return filters.build();
Expand All @@ -1188,13 +1188,13 @@ public List<ViewExpression> getColumnMasks(SecurityContext context, QualifiedObj
// connector-provided masks take precedence over global masks
CatalogAccessControlEntry entry = getConnectorAccessControl(context.getTransactionId(), tableName.getCatalogName());
if (entry != null) {
entry.getAccessControl().getColumnMask(entry.toConnectorSecurityContext(context), tableName.asSchemaTableName(), columnName, type)
.ifPresent(masks::add);
entry.getAccessControl().getColumnMasks(entry.toConnectorSecurityContext(context), tableName.asSchemaTableName(), columnName, type)
.forEach(masks::add);
}

for (SystemAccessControl systemAccessControl : getSystemAccessControls()) {
systemAccessControl.getColumnMask(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName(), columnName, type)
.ifPresent(masks::add);
systemAccessControl.getColumnMasks(context.toSystemSecurityContext(), tableName.asCatalogSchemaTableName(), columnName, type)
.forEach(masks::add);
}

return masks.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.security;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.QualifiedObjectName;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.CatalogSchemaName;
Expand All @@ -26,6 +27,7 @@
import io.trino.spi.security.ViewExpression;
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -445,21 +447,21 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
checkArgument(context == null, "context must be null");
if (accessControl.getRowFilters(securityContext, new QualifiedObjectName(catalogName, tableName.getSchemaName(), tableName.getTableName())).isEmpty()) {
return Optional.empty();
return ImmutableList.of();
}
throw new TrinoException(NOT_SUPPORTED, "Row filtering not supported");
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
checkArgument(context == null, "context must be null");
if (accessControl.getColumnMasks(securityContext, new QualifiedObjectName(catalogName, tableName.getSchemaName(), tableName.getTableName()), columnName, type).isEmpty()) {
return Optional.empty();
return ImmutableList.of();
}
throw new TrinoException(NOT_SUPPORTED, "Column masking not supported");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ public static BytecodeNode generateLambda(
compiledLambda.getParameterTypes().stream()
.skip(captureExpressions.size() + 1) // skip capture variables and ConnectorSession
.map(ParameterizedType::getAsmType)
.collect(toImmutableList()).toArray(new Type[0]));
.toArray(Type[]::new));

block.append(
invokeDynamic(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.connector;

import com.google.common.collect.ImmutableList;
import io.trino.plugin.base.security.AllowAllAccessControl;
import io.trino.spi.connector.ConnectorSecurityContext;
import io.trino.spi.connector.SchemaTableName;
Expand All @@ -23,6 +24,7 @@
import io.trino.spi.type.Type;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
Expand Down Expand Up @@ -120,15 +122,19 @@ public void checkCanRevokeTablePrivilege(ConnectorSecurityContext context, Privi
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
return Optional.ofNullable(rowFilters.apply(tableName));
return Optional.ofNullable(rowFilters.apply(tableName))
.map(ImmutableList::of)
.orElseGet(ImmutableList::of);
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return Optional.ofNullable(columnMasks.apply(tableName, columnName));
return Optional.ofNullable(columnMasks.apply(tableName, columnName))
.map(ImmutableList::of)
.orElseGet(ImmutableList::of);
}

public void grantSchemaPrivileges(String schemaName, Set<Privilege> privileges, TrinoPrincipal grantee, boolean grantOption)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ public SystemAccessControl create(Map<String, String> config)
return new SystemAccessControl()
{
@Override
public Optional<ViewExpression> getColumnMask(SystemSecurityContext context, CatalogSchemaTableName tableName, String column, Type type)
public List<ViewExpression> getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName tableName, String column, Type type)
{
return Optional.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "system mask"));
return ImmutableList.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "system mask"));
}

@Override
Expand All @@ -224,9 +224,9 @@ public void checkCanSetSystemSessionProperty(SystemSecurityContext context, Stri
accessControlManager.addCatalogAccessControl(new CatalogName("catalog"), new ConnectorAccessControl()
{
@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String column, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String column, Type type)
{
return Optional.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "connector mask"));
return ImmutableList.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "connector mask"));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
*/
package io.trino.security;

import com.google.common.collect.ImmutableSet;
import io.trino.spi.connector.ConnectorAccessControl;
import io.trino.spi.connector.ConnectorSecurityContext;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.type.Type;
import org.testng.annotations.Test;

import static io.trino.spi.testing.InterfaceTestUtils.assertAllMethodsOverridden;
Expand All @@ -22,7 +26,10 @@ public class TestInjectedConnectorAccessControl
{
@Test
public void testEverythingImplemented()
throws NoSuchMethodException
{
assertAllMethodsOverridden(ConnectorAccessControl.class, InjectedConnectorAccessControl.class);
assertAllMethodsOverridden(ConnectorAccessControl.class, InjectedConnectorAccessControl.class, ImmutableSet.of(
InjectedConnectorAccessControl.class.getMethod("getRowFilter", ConnectorSecurityContext.class, SchemaTableName.class),
InjectedConnectorAccessControl.class.getMethod("getColumnMask", ConnectorSecurityContext.class, SchemaTableName.class, String.class, Type.class)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.trino.spi.security.ViewExpression;
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -74,6 +75,7 @@
import static io.trino.spi.security.AccessDeniedException.denyShowTables;
import static io.trino.spi.security.AccessDeniedException.denyTruncateTable;
import static io.trino.spi.security.AccessDeniedException.denyUpdateTableColumns;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptySet;

public interface ConnectorAccessControl
Expand Down Expand Up @@ -600,22 +602,51 @@ default void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sch
* The filter must be a scalar SQL expression of boolean type over the columns in the table.
*
* @return the filter, or {@link Optional#empty()} if not applicable
* @deprecated use {@link #getRowFilters(ConnectorSecurityContext, SchemaTableName)} instead
*/
@Deprecated
default Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
{
return Optional.empty();
}

/**
* Get row filters associated with the given table and identity.
* <p>
* Each filter must be a scalar SQL expression of boolean type over the columns in the table.
*
* @return the list of filters, or empty list if not applicable
*/
default List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
return emptyList();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new method should delegate to the old method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. As I was removing them I realized the same, not sure how it happened

}

/**
* Get a column mask associated with the given table, column and identity.
* <p>
* The mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression
* must be written in terms of columns in the table.
*
* @return the mask, or {@link Optional#empty()} if not applicable
* @deprecated use {@link #getColumnMasks(ConnectorSecurityContext, SchemaTableName, String, Type)} instead
*/
@Deprecated
default Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return Optional.empty();
}

/**
* Get column masks associated with the given table, column and identity.
* <p>
* Each mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression
* must be written in terms of columns in the table.
*
* @return the list of masks, or empty list if not applicable
*/
default List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return emptyList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import java.security.Principal;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -823,25 +824,54 @@ default void checkCanExecuteTableProcedure(SystemSecurityContext systemSecurityC
* The filter must be a scalar SQL expression of boolean type over the columns in the table.
*
* @return the filter, or {@link Optional#empty()} if not applicable
* @deprecated use {@link #getRowFilters(SystemSecurityContext, CatalogSchemaTableName)} instead
*/
@Deprecated
default Optional<ViewExpression> getRowFilter(SystemSecurityContext context, CatalogSchemaTableName tableName)
{
return Optional.empty();
}

/**
* Get row filters associated with the given table and identity.
* <p>
* Each filter must be a scalar SQL expression of boolean type over the columns in the table.
*
* @return the list of filters, or empty list if not applicable
*/
default List<ViewExpression> getRowFilters(SystemSecurityContext context, CatalogSchemaTableName tableName)
{
return getRowFilter(context, tableName).map(List::of).orElseGet(List::of);
}

/**
* Get a column mask associated with the given table, column and identity.
* <p>
* The mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression
* must be written in terms of columns in the table.
*
* @return the mask, or {@link Optional#empty()} if not applicable
* @deprecated use {@link #getColumnMasks(SystemSecurityContext, CatalogSchemaTableName, String, Type)} instead
*/
@Deprecated
default Optional<ViewExpression> getColumnMask(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type)
{
return Optional.empty();
}

/**
* Get column masks associated with the given table, column and identity.
* <p>
* Each mask must be a scalar SQL expression of a type coercible to the type of the column being masked. The expression
* must be written in terms of columns in the table.
*
* @return the list of masks, or empty list if not applicable
*/
default List<ViewExpression> getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type)
{
return getColumnMask(context, tableName, columnName, type).map(List::of).orElseGet(List::of);
}

/**
* @return the event listeners provided by this system access control
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import javax.inject.Inject;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -493,18 +494,18 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
return delegate.getRowFilter(context, tableName);
return delegate.getRowFilters(context, tableName);
}
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
return delegate.getColumnMask(context, tableName, columnName, type);
return delegate.getColumnMasks(context, tableName, columnName, type);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.base.security;

import com.google.common.collect.ImmutableList;
import io.trino.spi.connector.ConnectorAccessControl;
import io.trino.spi.connector.ConnectorSecurityContext;
import io.trino.spi.connector.SchemaRoutineName;
Expand All @@ -22,6 +23,7 @@
import io.trino.spi.security.ViewExpression;
import io.trino.spi.type.Type;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -316,14 +318,14 @@ public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, Sche
}

@Override
public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName)
public List<ViewExpression> getRowFilters(ConnectorSecurityContext context, SchemaTableName tableName)
{
return Optional.empty();
return ImmutableList.of();
}

@Override
public Optional<ViewExpression> getColumnMask(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
public List<ViewExpression> getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String columnName, Type type)
{
return Optional.empty();
return ImmutableList.of();
}
}
Loading