From 16c3e44b661b5a874a87c4128e0d8fa11fbec4f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Grzegorz=20Kokosi=C5=84ski?= Date: Sat, 11 Mar 2023 08:02:20 +0100 Subject: [PATCH] Allow ViewExpression use session user explicitly Before the change if access control is not returning any dedicated user to evaluate the view expression they just pass the session user. However for the engine it is not clear that is a session user and so engine needs to retrieve groups for that user again and possibly some session roles are lost. After this change access control may decide to return empty identity. That would be in line of view SECURITY INVOKER. Then engine can simply reuse the session identity. --- .../trino/sql/analyzer/StatementAnalyzer.java | 28 +++-- .../security/TestAccessControlManager.java | 4 +- .../sql/planner/TestMaterializedViews.java | 2 +- .../io/trino/sql/query/TestColumnMask.java | 114 +++++++++--------- .../query/TestFilterInaccessibleColumns.java | 47 +++++++- .../io/trino/sql/query/TestRowFilter.java | 70 +++++------ .../io/trino/spi/security/ViewExpression.java | 18 ++- .../CatalogTableAccessControlRule.java | 8 +- .../base/security/FileBasedAccessControl.java | 4 +- .../FileBasedSystemAccessControl.java | 4 +- .../base/security/TableAccessControlRule.java | 8 +- ...seFileBasedConnectorAccessControlTest.java | 8 +- .../BaseFileBasedSystemAccessControlTest.java | 14 +-- .../io/trino/security/TestAccessControl.java | 48 +++++++- 14 files changed, 240 insertions(+), 137 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index b373ab0e5319..e38d1ced99c6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -2279,7 +2279,7 @@ private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, Relat private void analyzeCheckConstraints(Table table, QualifiedObjectName name, Scope accessControlScope, List constraints) { for (String constraint : constraints) { - ViewExpression expression = new ViewExpression(session.getIdentity().getUser(), Optional.of(name.getCatalogName()), Optional.of(name.getSchemaName()), constraint); + ViewExpression expression = new ViewExpression(Optional.empty(), Optional.of(name.getCatalogName()), Optional.of(name.getSchemaName()), constraint); analyzeCheckConstraint(table, name, accessControlScope, expression); } } @@ -4663,9 +4663,11 @@ private void analyzeRowFilter(String currentIdentity, Table table, QualifiedObje ExpressionAnalysis expressionAnalysis; try { - Identity filterIdentity = Identity.forUser(filter.getIdentity()) - .withGroups(groupProvider.getGroups(filter.getIdentity())) - .build(); + Identity filterIdentity = filter.getSecurityIdentity() + .map(filterUser -> Identity.forUser(filterUser) + .withGroups(groupProvider.getGroups(filterUser)) + .build()) + .orElseGet(session::getIdentity); expressionAnalysis = ExpressionAnalyzer.analyzeExpression( createViewSession(filter.getCatalog(), filter.getSchema(), filterIdentity, session.getPath()), // TODO: path should be included in row filter plannerContext, @@ -4714,11 +4716,13 @@ private void analyzeCheckConstraint(Table table, QualifiedObjectName name, Scope ExpressionAnalysis expressionAnalysis; try { - Identity filterIdentity = Identity.forUser(constraint.getIdentity()) - .withGroups(groupProvider.getGroups(constraint.getIdentity())) - .build(); + Identity constraintIdentity = constraint.getSecurityIdentity() + .map(user -> Identity.forUser(user) + .withGroups(groupProvider.getGroups(user)) + .build()) + .orElseGet(session::getIdentity); expressionAnalysis = ExpressionAnalyzer.analyzeExpression( - createViewSession(constraint.getCatalog(), constraint.getSchema(), filterIdentity, session.getPath()), + createViewSession(constraint.getCatalog(), constraint.getSchema(), constraintIdentity, session.getPath()), plannerContext, statementAnalyzerFactory, accessControl, @@ -4777,9 +4781,11 @@ private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObj verifyNoAggregateWindowOrGroupingFunctions(session, metadata, expression, format("Column mask for '%s.%s'", table.getName(), column)); try { - Identity maskIdentity = Identity.forUser(mask.getIdentity()) - .withGroups(groupProvider.getGroups(mask.getIdentity())) - .build(); + Identity maskIdentity = mask.getSecurityIdentity() + .map(maskUser -> Identity.forUser(maskUser) + .withGroups(groupProvider.getGroups(maskUser)) + .build()) + .orElseGet(session::getIdentity); expressionAnalysis = ExpressionAnalyzer.analyzeExpression( createViewSession(mask.getCatalog(), mask.getSchema(), maskIdentity, session.getPath()), // TODO: path should be included in row filter plannerContext, diff --git a/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java b/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java index 07ce6c40feca..af1c2534a755 100644 --- a/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java +++ b/core/trino-main/src/test/java/io/trino/security/TestAccessControlManager.java @@ -231,7 +231,7 @@ public SystemAccessControl create(Map config) @Override public List getColumnMasks(SystemSecurityContext context, CatalogSchemaTableName tableName, String column, Type type) { - return ImmutableList.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "system mask")); + return ImmutableList.of(new ViewExpression(Optional.of("user"), Optional.empty(), Optional.empty(), "system mask")); } @Override @@ -249,7 +249,7 @@ public void checkCanSetSystemSessionProperty(SystemSecurityContext context, Stri @Override public List getColumnMasks(ConnectorSecurityContext context, SchemaTableName tableName, String column, Type type) { - return ImmutableList.of(new ViewExpression("user", Optional.empty(), Optional.empty(), "connector mask")); + return ImmutableList.of(new ViewExpression(Optional.of("user"), Optional.empty(), Optional.empty(), "connector mask")); } @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java index 19fece81b2d5..81bd4d050432 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestMaterializedViews.java @@ -236,7 +236,7 @@ public void testMaterializedViewWithCasts() new QualifiedObjectName(TEST_CATALOG_NAME, SCHEMA, "materialized_view_with_casts"), "a", "user", - new ViewExpression("user", Optional.empty(), Optional.empty(), "a + 1")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "a + 1")); assertPlan("SELECT * FROM materialized_view_with_casts", anyTree( project( diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java index 3b268a6a4462..9c176068c678 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java @@ -194,7 +194,7 @@ public void testSimpleMask() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "custkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "-custkey")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "-custkey")); assertThat(assertions.query("SELECT custkey FROM orders WHERE orderkey = 1")).matches("VALUES BIGINT '-370'"); accessControl.reset(); @@ -202,7 +202,7 @@ public void testSimpleMask() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "custkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "NULL")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "NULL")); assertThat(assertions.query("SELECT custkey FROM orders WHERE orderkey = 1")).matches("VALUES CAST(NULL AS BIGINT)"); } @@ -214,7 +214,7 @@ public void testConditionalMask() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "custkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "IF (orderkey < 2, null, -custkey)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "IF (orderkey < 2, null, -custkey)")); assertThat(assertions.query("SELECT custkey FROM orders LIMIT 2")) .matches("VALUES (NULL), CAST('-781' AS BIGINT)"); } @@ -227,13 +227,13 @@ public void testMultipleMasksOnDifferentColumns() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "custkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "-custkey")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "-custkey")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderstatus", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "'X'")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "'X'")); assertThat(assertions.query("SELECT custkey, orderstatus FROM orders WHERE orderkey = 1")) .matches("VALUES (BIGINT '-370', 'X')"); @@ -247,13 +247,13 @@ public void testReferenceInUsingClause() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "IF(orderkey = 1, -orderkey)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "IF(orderkey = 1, -orderkey)")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "lineitem"), "orderkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "IF(orderkey = 1, -orderkey)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "IF(orderkey = 1, -orderkey)")); assertThat(assertions.query("SELECT count(*) FROM orders JOIN lineitem USING (orderkey)")).matches("VALUES BIGINT '6'"); } @@ -266,7 +266,7 @@ public void testCoercibleType() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "CAST(clerk AS VARCHAR(5))")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "CAST(clerk AS VARCHAR(5))")); assertThat(assertions.query("SELECT clerk FROM orders WHERE orderkey = 1")).matches("VALUES CAST('Clerk' AS VARCHAR(15))"); } @@ -279,7 +279,7 @@ public void testSubquery() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT cast(max(name) AS VARCHAR(15)) FROM nation)")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT cast(max(name) AS VARCHAR(15)) FROM nation)")); assertThat(assertions.query("SELECT clerk FROM orders WHERE orderkey = 1")).matches("VALUES CAST('VIETNAM' AS VARCHAR(15))"); // correlated @@ -288,7 +288,7 @@ public void testSubquery() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT cast(max(name) AS VARCHAR(15)) FROM nation WHERE nationkey = orderkey)")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT cast(max(name) AS VARCHAR(15)) FROM nation WHERE nationkey = orderkey)")); assertThat(assertions.query("SELECT clerk FROM orders WHERE orderkey = 1")).matches("VALUES CAST('ARGENTINA' AS VARCHAR(15))"); } @@ -301,17 +301,17 @@ public void testMaterializedView() new QualifiedObjectName(MOCK_CATALOG, "default", "nation_fresh_materialized_view"), "name", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "reverse(name)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "reverse(name)")); accessControl.columnMask( new QualifiedObjectName(MOCK_CATALOG, "default", "nation_materialized_view"), "name", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "reverse(name)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "reverse(name)")); accessControl.columnMask( new QualifiedObjectName(MOCK_CATALOG, "default", "materialized_view_with_casts"), "name", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "reverse(name)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "reverse(name)")); assertThat(assertions.query( Session.builder(SESSION) @@ -344,7 +344,7 @@ public void testView() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), "name", VIEW_OWNER, - new ViewExpression(VIEW_OWNER, Optional.empty(), Optional.empty(), "reverse(name)")); + new ViewExpression(Optional.of(VIEW_OWNER), Optional.empty(), Optional.empty(), "reverse(name)")); assertThat(assertions.query( Session.builder(SESSION) @@ -359,7 +359,7 @@ public void testView() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), "name", VIEW_OWNER, - new ViewExpression(VIEW_OWNER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "reverse(name)")); + new ViewExpression(Optional.of(VIEW_OWNER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "reverse(name)")); assertThat(assertions.query( Session.builder(SESSION) @@ -374,7 +374,7 @@ public void testView() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), "name", RUN_AS_USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "reverse(name)")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "reverse(name)")); assertThat(assertions.query( Session.builder(SESSION) @@ -389,7 +389,7 @@ public void testView() new QualifiedObjectName(MOCK_CATALOG, "default", "nation_view"), "name", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "reverse(name)")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "reverse(name)")); assertThat(assertions.query("SELECT name FROM mock.default.nation_view WHERE nationkey = 1")).matches("VALUES CAST('ANITNEGRA' AS VARCHAR(25))"); } @@ -401,7 +401,7 @@ public void testTableReferenceInWithClause() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "custkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "-custkey")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "-custkey")); assertThat(assertions.query("WITH t AS (SELECT custkey FROM orders WHERE orderkey = 1) SELECT * FROM t")).matches("VALUES BIGINT '-370'"); } @@ -413,7 +413,7 @@ public void testOtherSchema() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("sf1"), "(SELECT count(*) FROM customer)")); // count is 15000 only when evaluating against sf1 + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("sf1"), "(SELECT count(*) FROM customer)")); // count is 15000 only when evaluating against sf1 assertThat(assertions.query("SELECT max(orderkey) FROM orders")).matches("VALUES BIGINT '150000'"); } @@ -425,13 +425,13 @@ public void testDifferentIdentity() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", RUN_AS_USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "100")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "100")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT sum(orderkey) FROM orders)")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT sum(orderkey) FROM orders)")); assertThat(assertions.query("SELECT max(orderkey) FROM orders")).matches("VALUES BIGINT '1500000'"); } @@ -444,7 +444,7 @@ public void testRecursion() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessageMatching(".*\\QColumn mask for 'local.tiny.orders.orderkey' is recursive\\E.*"); @@ -455,7 +455,7 @@ public void testRecursion() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM local.tiny.orders)")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM local.tiny.orders)")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessageMatching(".*\\QColumn mask for 'local.tiny.orders.orderkey' is recursive\\E.*"); @@ -466,13 +466,13 @@ public void testRecursion() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", RUN_AS_USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT orderkey FROM orders)")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessageMatching(".*\\QColumn mask for 'local.tiny.orders.orderkey' is recursive\\E.*"); @@ -486,7 +486,7 @@ public void testLimitedScope() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "customer"), "custkey", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey")); assertThatThrownBy(() -> assertions.query( "SELECT (SELECT min(custkey) FROM customer WHERE customer.custkey = orders.custkey) FROM orders")) .hasMessage("line 1:34: Invalid column mask for 'local.tiny.customer.custkey': Column 'orderkey' cannot be resolved"); @@ -500,7 +500,7 @@ public void testSqlInjection() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), "name", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT name FROM region WHERE regionkey = 0)")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "(SELECT name FROM region WHERE regionkey = 0)")); assertThat(assertions.query( "WITH region(regionkey, name) AS (VALUES (0, 'ASIA'))" + "SELECT name FROM nation ORDER BY name LIMIT 1")) @@ -516,7 +516,7 @@ public void testInvalidMasks() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "$$$")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "$$$")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessage("line 1:22: Invalid column mask for 'local.tiny.orders.orderkey': mismatched input '$'. Expecting: "); @@ -527,7 +527,7 @@ public void testInvalidMasks() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "unknown_column")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "unknown_column")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessage("line 1:22: Invalid column mask for 'local.tiny.orders.orderkey': Column 'unknown_column' cannot be resolved"); @@ -538,7 +538,7 @@ public void testInvalidMasks() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "'foo'")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "'foo'")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessage("line 1:22: Expected column mask for 'local.tiny.orders.orderkey' to be of type bigint, but was varchar(3)"); @@ -549,7 +549,7 @@ public void testInvalidMasks() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "count(*) > 0")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "count(*) > 0")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessage("line 1:10: Column mask for 'orders.orderkey' cannot contain aggregations, window functions or grouping operations: [count(*)]"); @@ -560,7 +560,7 @@ public void testInvalidMasks() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "row_number() OVER () > 0")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "row_number() OVER () > 0")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessage("line 1:22: Column mask for 'orders.orderkey' cannot contain aggregations, window functions or grouping operations: [row_number() OVER ()]"); @@ -571,7 +571,7 @@ public void testInvalidMasks() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "grouping(orderkey) = 0")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "grouping(orderkey) = 0")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessage("line 1:20: Column mask for 'orders.orderkey' cannot contain aggregations, window functions or grouping operations: [GROUPING (orderkey)]"); @@ -585,7 +585,7 @@ public void testShowStats() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "7")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "7")); assertThat(assertions.query("SHOW STATS FOR (SELECT * FROM orders)")) .containsAll(""" @@ -616,7 +616,7 @@ public void testJoin() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey + 1")); + new ViewExpression(Optional.of(USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey + 1")); assertThat(assertions.query("SELECT count(*) FROM orders JOIN orders USING (orderkey)")).matches("VALUES BIGINT '15000'"); } @@ -629,7 +629,7 @@ public void testColumnMaskingUsingRestrictedColumn() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderkey", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "custkey")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "custkey")); assertThatThrownBy(() -> assertions.query("SELECT orderkey FROM orders")) .hasMessage("Access Denied: Cannot select from columns [orderkey, custkey] in table or view local.tiny.orders"); } @@ -642,7 +642,7 @@ public void testInsertWithColumnMasking() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "clerk")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "clerk")); assertThatThrownBy(() -> assertions.query("INSERT INTO orders SELECT * FROM orders")) .hasMessage("Insert into table with column masks is not supported"); } @@ -655,7 +655,7 @@ public void testDeleteWithColumnMasking() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "clerk")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "clerk")); assertThatThrownBy(() -> assertions.query("DELETE FROM orders")) .hasMessage("line 1:1: Delete from table with column mask"); } @@ -668,7 +668,7 @@ public void testUpdateWithColumnMasking() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "clerk")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "clerk")); assertThatThrownBy(() -> assertions.query("UPDATE orders SET clerk = 'X'")) .hasMessage("line 1:1: Updating a table with column masks is not supported"); assertThatThrownBy(() -> assertions.query("UPDATE orders SET orderkey = -orderkey")) @@ -687,7 +687,7 @@ public void testNotReferencedAndDeniedColumnMasking() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "clerk")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "clerk")); assertThat(assertions.query("SELECT orderkey FROM orders WHERE orderkey = 1")).matches("VALUES BIGINT '1'"); // mask on long column @@ -697,7 +697,7 @@ public void testNotReferencedAndDeniedColumnMasking() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "totalprice", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "totalprice")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "totalprice")); assertThat(assertions.query("SELECT orderkey FROM orders WHERE orderkey = 1")).matches("VALUES BIGINT '1'"); // mask on not used varchar column with subquery masking @@ -708,7 +708,7 @@ public void testNotReferencedAndDeniedColumnMasking() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "(SELECT orderstatus FROM local.tiny.orders)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "(SELECT orderstatus FROM local.tiny.orders)")); assertThat(assertions.query("SELECT orderkey FROM orders WHERE orderkey = 1")).matches("VALUES BIGINT '1'"); } @@ -720,7 +720,7 @@ public void testColumnMaskWithHiddenColumns() new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation_with_hidden_column"), "name", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "'POLAND'")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "'POLAND'")); assertions.query("SELECT * FROM mock.tiny.nation_with_hidden_column WHERE nationkey = 1") .assertThat() @@ -754,19 +754,19 @@ public void testMultipleMasksUsingOtherMaskedColumns() new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "comment", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(comment,'(password: [^ ]+)','password: ****') as varchar(79))")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "cast(regexp_replace(comment,'(password: [^ ]+)','password: ****') as varchar(79))")); accessControl.columnMask( new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "orderstatus", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(comment,'(country: [^ ]+)') IN ('country: 1'), '*', orderstatus)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(comment,'(country: [^ ]+)') IN ('country: 1'), '*', orderstatus)")); accessControl.columnMask( new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(comment,'(country: [^ ]+)') IN ('country: 1'), '***', clerk)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(comment,'(country: [^ ]+)') IN ('country: 1'), '***', clerk)")); assertThat(assertions.query(query)).matches(expected); @@ -777,13 +777,13 @@ public void testMultipleMasksUsingOtherMaskedColumns() new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(password: [^ ]+)','password: ****') as varchar(15))")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(password: [^ ]+)','password: ****') as varchar(15))")); accessControl.columnMask( new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "comment", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '***', comment)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '***', comment)")); assertThat(assertions.query(query)).matches(expected); @@ -794,19 +794,19 @@ public void testMultipleMasksUsingOtherMaskedColumns() new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(password: [^ ]+)','password: ****') as varchar(15))")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(password: [^ ]+)','password: ****') as varchar(15))")); accessControl.columnMask( new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "orderstatus", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '*', orderstatus)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '*', orderstatus)")); accessControl.columnMask( new QualifiedObjectName(TEST_CATALOG_NAME, "tiny", "orders"), "comment", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '***', comment)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'(country: [^ ]+)') IN ('country: 1'), '***', comment)")); assertThat(assertions.query(query)).matches(expected); @@ -817,13 +817,13 @@ public void testMultipleMasksUsingOtherMaskedColumns() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(Clerk#)','***#') as varchar(15))")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "cast(regexp_replace(clerk,'(Clerk#)','***#') as varchar(15))")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "comment", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '***', comment)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '***', comment)")); assertThat(assertions.query(query)) .matches("VALUES (CAST('***' as varchar(79)), 'O', CAST('***#000000951' as varchar(15)))"); @@ -835,19 +835,19 @@ public void testMultipleMasksUsingOtherMaskedColumns() new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "clerk", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "cast('###' as varchar(15))")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "cast('###' as varchar(15))")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "orderstatus", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '*', orderstatus)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '*', orderstatus)")); accessControl.columnMask( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), "comment", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '***', comment)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(regexp_extract(clerk,'([1-9]+)') IN ('951'), '***', comment)")); assertThat(assertions.query(query)) .matches("VALUES (CAST('***' as varchar(79)), '*', CAST('###' as varchar(15)))"); @@ -861,7 +861,7 @@ public void testColumnAliasing() new QualifiedObjectName(MOCK_CATALOG, "default", "view_with_nested"), "nested", USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "if(id = 0, nested)")); + new ViewExpression(Optional.of(USER), Optional.empty(), Optional.empty(), "if(id = 0, nested)")); assertThat(assertions.query("SELECT nested[1] FROM mock.default.view_with_nested")) .matches("VALUES 1, NULL"); diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestFilterInaccessibleColumns.java b/core/trino-main/src/test/java/io/trino/sql/query/TestFilterInaccessibleColumns.java index de57d9eabecc..a6df900756e0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestFilterInaccessibleColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestFilterInaccessibleColumns.java @@ -153,7 +153,7 @@ public void testRowFilterWithAccessToInaccessibleColumn() { accessControl.rowFilter(new QualifiedObjectName(TEST_CATALOG_NAME, TINY_SCHEMA_NAME, "nation"), USER, - new ViewExpression(ADMIN, Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "comment IS NOT null")); + new ViewExpression(Optional.of(ADMIN), Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "comment IS NOT null")); accessControl.deny(privilege(USER, "nation.comment", SELECT_COLUMN)); assertThat(assertions.query("SELECT * FROM nation WHERE name = 'FRANCE'")) .matches("VALUES (BIGINT '6', CAST('FRANCE' AS VARCHAR(25)), BIGINT '3')"); @@ -164,19 +164,34 @@ public void testRowFilterWithoutAccessToInaccessibleColumn() { accessControl.rowFilter(new QualifiedObjectName(TEST_CATALOG_NAME, TINY_SCHEMA_NAME, "nation"), USER, - new ViewExpression(USER, Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "comment IS NOT null")); + new ViewExpression(Optional.of(USER), Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "comment IS NOT null")); accessControl.deny(privilege(USER, "nation.comment", SELECT_COLUMN)); assertThatThrownBy(() -> assertions.query("SELECT * FROM nation WHERE name = 'FRANCE'")) .hasMessage("Access Denied: Cannot select from columns [nationkey, regionkey, name, comment] in table or view test-catalog.tiny.nation"); } + @Test + public void testRowFilterAsSessionUserOnInaccessibleColumn() + { + accessControl.deny(privilege(USER, "nation.comment", SELECT_COLUMN)); + QualifiedObjectName table = new QualifiedObjectName(TEST_CATALOG_NAME, TINY_SCHEMA_NAME, "nation"); + ViewExpression filter = new ViewExpression(Optional.empty(), Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "comment IS NOT null"); + accessControl.rowFilter(table, ADMIN, filter); + accessControl.rowFilter(table, USER, filter); + + assertThatThrownBy(() -> assertions.query(user(USER), "SELECT * FROM nation WHERE name = 'FRANCE'")) + .hasMessage("Access Denied: Cannot select from columns [nationkey, regionkey, name, comment] in table or view test-catalog.tiny.nation"); + assertThat(assertions.query(user(ADMIN), "SELECT * FROM nation WHERE name = 'FRANCE'")) + .matches("VALUES (BIGINT '6', CAST('FRANCE' AS VARCHAR(25)), BIGINT '3', CAST('refully final requests. regular, ironi' AS VARCHAR(152)))"); + } + @Test public void testMaskingOnAccessibleColumn() { accessControl.columnMask(new QualifiedObjectName(TEST_CATALOG_NAME, TINY_SCHEMA_NAME, "nation"), "nationkey", USER, - new ViewExpression(ADMIN, Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "-nationkey")); + new ViewExpression(Optional.of(ADMIN), Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "-nationkey")); assertThat(assertions.query("SELECT * FROM nation WHERE name = 'FRANCE'")) .matches("VALUES (BIGINT '-6',CAST('FRANCE' AS VARCHAR(25)), BIGINT '3', CAST('refully final requests. regular, ironi' AS VARCHAR(152)))"); } @@ -188,7 +203,7 @@ public void testMaskingWithoutAccessToInaccessibleColumn() accessControl.columnMask(new QualifiedObjectName(TEST_CATALOG_NAME, TINY_SCHEMA_NAME, "nation"), "comment", USER, - new ViewExpression(USER, Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "CASE nationkey WHEN 6 THEN 'masked-comment' ELSE comment END")); + new ViewExpression(Optional.of(USER), Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "CASE nationkey WHEN 6 THEN 'masked-comment' ELSE comment END")); assertThatThrownBy(() -> assertions.query("SELECT * FROM nation WHERE name = 'FRANCE'")) .hasMessage("Access Denied: Cannot select from columns [nationkey, regionkey, name, comment] in table or view test-catalog.tiny.nation"); @@ -201,7 +216,7 @@ public void testMaskingWithAccessToInaccessibleColumn() accessControl.columnMask(new QualifiedObjectName(TEST_CATALOG_NAME, TINY_SCHEMA_NAME, "nation"), "comment", USER, - new ViewExpression(ADMIN, Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "CASE nationkey WHEN 6 THEN 'masked-comment' ELSE comment END")); + new ViewExpression(Optional.of(ADMIN), Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "CASE nationkey WHEN 6 THEN 'masked-comment' ELSE comment END")); assertThat(assertions.query("SELECT * FROM nation WHERE name = 'FRANCE'")) .matches("VALUES (CAST('FRANCE' AS VARCHAR(25)), BIGINT '3', CAST('masked-comment' AS VARCHAR(152)))"); @@ -210,6 +225,21 @@ public void testMaskingWithAccessToInaccessibleColumn() .matches("VALUES (CAST('CANADA' AS VARCHAR(25)), BIGINT '1', CAST('eas hang ironic, silent packages. slyly regular packages are furiously over the tithes. fluffily bold' AS VARCHAR(152)))"); } + @Test + public void testMaskingAsSessionUserWithCaseOnInaccessibleColumn() + { + accessControl.deny(privilege(USER, "nation.nationkey", SELECT_COLUMN)); + QualifiedObjectName table = new QualifiedObjectName(TEST_CATALOG_NAME, TINY_SCHEMA_NAME, "nation"); + ViewExpression mask = new ViewExpression(Optional.empty(), Optional.of(TEST_CATALOG_NAME), Optional.of(TINY_SCHEMA_NAME), "CASE nationkey WHEN 3 THEN 'masked-comment' ELSE comment END"); + accessControl.columnMask(table, "comment", ADMIN, mask); + accessControl.columnMask(table, "comment", USER, mask); + + assertThatThrownBy(() -> assertions.query(user(USER), "SELECT * FROM nation WHERE name = 'FRANCE'")) + .hasMessage("Access Denied: Cannot select from columns [nationkey, regionkey, name, comment] in table or view test-catalog.tiny.nation"); + assertThat(assertions.query(user(ADMIN), "SELECT * FROM nation WHERE name = 'CANADA'")) + .matches("VALUES (BIGINT '3', CAST('CANADA' AS VARCHAR(25)), BIGINT '1', CAST('masked-comment' AS VARCHAR(152)))"); + } + @Test public void testPredicateOnInaccessibleColumn() { @@ -257,4 +287,11 @@ public void testFunctionOnInaccessibleColumn() assertThatThrownBy(() -> assertions.query("SELECT * FROM (SELECT concat(name,'-test') FROM nation WHERE name = 'FRANCE')")) .hasMessage("Access Denied: Cannot select from columns [name] in table or view test-catalog.tiny.nation"); } + + private Session user(String user) + { + return Session.builder(assertions.getDefaultSession()) + .setIdentity(Identity.ofUser(user)) + .build(); + } } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestRowFilter.java b/core/trino-main/src/test/java/io/trino/sql/query/TestRowFilter.java index aa8b388bf526..ed9053b140b0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestRowFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestRowFilter.java @@ -156,14 +156,14 @@ public void testSimpleFilter() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey < 10")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "orderkey < 10")); assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '7'"); accessControl.reset(); accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "NULL")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "NULL")); assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '0'"); } @@ -174,12 +174,12 @@ public void testMultipleFilters() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey < 10")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "orderkey < 10")); accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey > 5")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "orderkey > 5")); assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '2'"); } @@ -191,7 +191,7 @@ public void testCorrelatedSubquery() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "EXISTS (SELECT 1 FROM nation WHERE nationkey = orderkey)")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "EXISTS (SELECT 1 FROM nation WHERE nationkey = orderkey)")); assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '7'"); } @@ -203,7 +203,7 @@ public void testView() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), VIEW_OWNER, - new ViewExpression(VIEW_OWNER, Optional.empty(), Optional.empty(), "nationkey = 1")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey = 1")); assertThat(assertions.query( Session.builder(SESSION) @@ -217,7 +217,7 @@ public void testView() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), VIEW_OWNER, - new ViewExpression(VIEW_OWNER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "nationkey = 1")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "nationkey = 1")); assertThat(assertions.query( Session.builder(SESSION) @@ -231,7 +231,7 @@ public void testView() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), RUN_AS_USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "nationkey = 1")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "nationkey = 1")); Session session = Session.builder(SESSION) .setIdentity(Identity.forUser(RUN_AS_USER).build()) @@ -244,7 +244,7 @@ public void testView() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "default", "nation_view"), USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "nationkey = 1")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "nationkey = 1")); assertThat(assertions.query("SELECT name FROM mock.default.nation_view")).matches("VALUES CAST('ARGENTINA' AS VARCHAR(25))"); } @@ -255,7 +255,7 @@ public void testTableReferenceInWithClause() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey = 1")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "orderkey = 1")); assertThat(assertions.query("WITH t AS (SELECT count(*) FROM orders) SELECT * FROM t")).matches("VALUES BIGINT '1'"); } @@ -266,7 +266,7 @@ public void testOtherSchema() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("sf1"), "(SELECT count(*) FROM customer) = 150000")); // Filter is TRUE only if evaluating against sf1.customer + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("sf1"), "(SELECT count(*) FROM customer) = 150000")); // Filter is TRUE only if evaluating against sf1.customer assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '15000'"); } @@ -277,12 +277,12 @@ public void testDifferentIdentity() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), RUN_AS_USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey = 1")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey = 1")); accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '1'"); } @@ -294,7 +294,7 @@ public void testRecursion() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessageMatching(".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*"); @@ -304,7 +304,7 @@ public void testRecursion() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT local.tiny.orderkey FROM orders)")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT local.tiny.orderkey FROM orders)")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessageMatching(".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*"); @@ -313,12 +313,12 @@ public void testRecursion() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), RUN_AS_USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessageMatching(".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*"); @@ -331,7 +331,7 @@ public void testLimitedScope() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "customer"), USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey = 1")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey = 1")); assertThatThrownBy(() -> assertions.query( "SELECT (SELECT min(name) FROM customer WHERE customer.custkey = orders.custkey) FROM orders")) .hasMessage("line 1:31: Invalid row filter for 'local.tiny.customer': Column 'orderkey' cannot be resolved"); @@ -344,7 +344,7 @@ public void testSqlInjection() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "nation"), USER, - new ViewExpression(USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "regionkey IN (SELECT regionkey FROM region WHERE name = 'ASIA')")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "regionkey IN (SELECT regionkey FROM region WHERE name = 'ASIA')")); assertThat(assertions.query( "WITH region(regionkey, name) AS (VALUES (0, 'ASIA'), (1, 'ASIA'), (2, 'ASIA'), (3, 'ASIA'), (4, 'ASIA'))" + "SELECT name FROM nation ORDER BY name LIMIT 1")) @@ -359,7 +359,7 @@ public void testInvalidFilter() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "$$$")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "$$$")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessage("line 1:22: Invalid row filter for 'local.tiny.orders': mismatched input '$'. Expecting: "); @@ -369,7 +369,7 @@ public void testInvalidFilter() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "unknown_column")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "unknown_column")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessage("line 1:22: Invalid row filter for 'local.tiny.orders': Column 'unknown_column' cannot be resolved"); @@ -379,7 +379,7 @@ public void testInvalidFilter() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "1")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "1")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessage("line 1:22: Expected row filter for 'local.tiny.orders' to be of type BOOLEAN, but was integer"); @@ -389,7 +389,7 @@ public void testInvalidFilter() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "count(*) > 0")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "count(*) > 0")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessage("line 1:10: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [count(*)]"); @@ -399,7 +399,7 @@ public void testInvalidFilter() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "row_number() OVER () > 0")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "row_number() OVER () > 0")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessage("line 1:22: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [row_number() OVER ()]"); @@ -409,7 +409,7 @@ public void testInvalidFilter() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "grouping(orderkey) = 0")); + new ViewExpression(Optional.empty(), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "grouping(orderkey) = 0")); assertThatThrownBy(() -> assertions.query("SELECT count(*) FROM orders")) .hasMessage("line 1:20: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [GROUPING (orderkey)]"); @@ -422,7 +422,7 @@ public void testShowStats() accessControl.rowFilter( new QualifiedObjectName(LOCAL_CATALOG, "tiny", "orders"), USER, - new ViewExpression(RUN_AS_USER, Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey = 0")); + new ViewExpression(Optional.of(RUN_AS_USER), Optional.of(LOCAL_CATALOG), Optional.of("tiny"), "orderkey = 0")); assertThat(assertions.query("SHOW STATS FOR (SELECT * FROM tiny.orders)")) .containsAll( @@ -442,7 +442,7 @@ public void testDelete() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey < 10")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey < 10")); // Within allowed row filter assertions.query("DELETE FROM mock.tiny.nation WHERE nationkey < 3") @@ -474,7 +474,7 @@ public void testMergeDelete() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey < 10")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey < 10")); // Within allowed row filter assertThatThrownBy(() -> assertions.query(""" @@ -507,7 +507,7 @@ public void testUpdate() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey < 10")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey < 10")); // Within allowed row filter assertThatThrownBy(() -> assertions.query("UPDATE mock.tiny.nation SET regionkey = regionkey * 2 WHERE nationkey < 3")) @@ -547,7 +547,7 @@ public void testMergeUpdate() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey < 10")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey < 10")); // Within allowed row filter assertThatThrownBy(() -> assertions.query(""" @@ -604,7 +604,7 @@ public void testInsert() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey > 100")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey > 100")); // Within allowed row filter assertions.query("INSERT INTO mock.tiny.nation VALUES (101, 'POLAND', 0, 'No comment')") @@ -635,7 +635,7 @@ public void testMergeInsert() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey > 100")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey > 100")); // Within allowed row filter assertThatThrownBy(() -> assertions.query(""" @@ -670,7 +670,7 @@ public void testRowFilterWithHiddenColumns() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation_with_hidden_column"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "nationkey < 1")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey < 1")); assertions.query("SELECT * FROM mock.tiny.nation_with_hidden_column") .assertThat() @@ -703,7 +703,7 @@ public void testRowFilterOnHiddenColumn() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG, "tiny", "nation_with_hidden_column"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "\"$hidden\" < 1")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "\"$hidden\" < 1")); assertions.query("SELECT count(*) FROM mock.tiny.nation_with_hidden_column") .assertThat() @@ -730,7 +730,7 @@ public void testRowFilterOnOptionalColumn() accessControl.rowFilter( new QualifiedObjectName(MOCK_CATALOG_MISSING_COLUMNS, "tiny", "nation_with_optional_column"), USER, - new ViewExpression(USER, Optional.empty(), Optional.empty(), "length(optional) > 2")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "length(optional) > 2")); assertions.query("INSERT INTO mockmissingcolumns.tiny.nation_with_optional_column(nationkey, name, regionkey, comment, optional) VALUES (0, 'POLAND', 0, 'No comment', 'some string')") .assertThat() diff --git a/core/trino-spi/src/main/java/io/trino/spi/security/ViewExpression.java b/core/trino-spi/src/main/java/io/trino/spi/security/ViewExpression.java index 7348717c3b3b..7e280aa140e5 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/security/ViewExpression.java +++ b/core/trino-spi/src/main/java/io/trino/spi/security/ViewExpression.java @@ -19,12 +19,18 @@ public class ViewExpression { - private final String identity; + private final Optional identity; private final Optional catalog; private final Optional schema; private final String expression; + @Deprecated public ViewExpression(String identity, Optional catalog, Optional schema, String expression) + { + this(Optional.of(identity), catalog, schema, expression); + } + + public ViewExpression(Optional identity, Optional catalog, Optional schema, String expression) { this.identity = requireNonNull(identity, "identity is null"); this.catalog = requireNonNull(catalog, "catalog is null"); @@ -36,7 +42,17 @@ public ViewExpression(String identity, Optional catalog, Optional getSecurityIdentity() { return identity; } diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/CatalogTableAccessControlRule.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/CatalogTableAccessControlRule.java index 07ddf5b19040..71dc5e284a28 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/CatalogTableAccessControlRule.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/CatalogTableAccessControlRule.java @@ -79,14 +79,14 @@ public boolean canSelectColumns(Set columnNames) return tableAccessControlRule.canSelectColumns(columnNames); } - public Optional getColumnMask(String user, String catalog, String schema, String column) + public Optional getColumnMask(String catalog, String schema, String column) { - return tableAccessControlRule.getColumnMask(user, catalog, schema, column); + return tableAccessControlRule.getColumnMask(catalog, schema, column); } - public Optional getFilter(String user, String catalog, String schema) + public Optional getFilter(String catalog, String schema) { - return tableAccessControlRule.getFilter(user, catalog, schema); + return tableAccessControlRule.getFilter(catalog, schema); } Optional toAnyCatalogPermissionsRule() diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java index c4ce581c83cb..590744dd89cc 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java @@ -647,7 +647,7 @@ public List getRowFilters(ConnectorSecurityContext context, Sche ConnectorIdentity identity = context.getIdentity(); return tableRules.stream() .filter(rule -> rule.matches(identity.getUser(), identity.getEnabledSystemRoles(), identity.getGroups(), tableName)) - .map(rule -> rule.getFilter(identity.getUser(), catalogName, tableName.getSchemaName())) + .map(rule -> rule.getFilter(catalogName, tableName.getSchemaName())) // we return the first one we find .findFirst() .stream() @@ -665,7 +665,7 @@ public Optional getColumnMask(ConnectorSecurityContext context, ConnectorIdentity identity = context.getIdentity(); List masks = tableRules.stream() .filter(rule -> rule.matches(identity.getUser(), identity.getEnabledSystemRoles(), identity.getGroups(), tableName)) - .map(rule -> rule.getColumnMask(identity.getUser(), catalogName, tableName.getSchemaName(), columnName)) + .map(rule -> rule.getColumnMask(catalogName, tableName.getSchemaName(), columnName)) // we return the first one we find .findFirst() .stream() diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java index e1cfac1da06c..20c5d5d4a9bf 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java @@ -965,7 +965,7 @@ public List getRowFilters(SystemSecurityContext context, Catalog Identity identity = context.getIdentity(); return tableRules.stream() .filter(rule -> rule.matches(identity.getUser(), identity.getEnabledRoles(), identity.getGroups(), table)) - .map(rule -> rule.getFilter(identity.getUser(), table.getCatalogName(), tableName.getSchemaName())) + .map(rule -> rule.getFilter(table.getCatalogName(), tableName.getSchemaName())) // we return the first one we find .findFirst() .stream() @@ -984,7 +984,7 @@ public Optional getColumnMask(SystemSecurityContext context, Cat Identity identity = context.getIdentity(); List masks = tableRules.stream() .filter(rule -> rule.matches(identity.getUser(), identity.getEnabledRoles(), identity.getGroups(), table)) - .map(rule -> rule.getColumnMask(identity.getUser(), table.getCatalogName(), table.getSchemaTableName().getSchemaName(), columnName)) + .map(rule -> rule.getColumnMask(table.getCatalogName(), table.getSchemaTableName().getSchemaName(), columnName)) // we return the first one we find .findFirst() .stream() diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/TableAccessControlRule.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/TableAccessControlRule.java index e2d2e3f19ce5..13e6b5536e58 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/TableAccessControlRule.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/TableAccessControlRule.java @@ -102,20 +102,20 @@ public boolean canSelectColumns(Set columnNames) return (privileges.contains(SELECT) || privileges.contains(GRANT_SELECT)) && restrictedColumns.stream().noneMatch(columnNames::contains); } - public Optional getColumnMask(String user, String catalog, String schema, String column) + public Optional getColumnMask(String catalog, String schema, String column) { return Optional.ofNullable(columnConstraints.get(column)).flatMap(constraint -> constraint.getMask().map(mask -> new ViewExpression( - constraint.getMaskEnvironment().flatMap(ExpressionEnvironment::getUser).orElse(user), + constraint.getMaskEnvironment().flatMap(ExpressionEnvironment::getUser), Optional.of(catalog), Optional.of(schema), mask))); } - public Optional getFilter(String user, String catalog, String schema) + public Optional getFilter(String catalog, String schema) { return filter.map(filter -> new ViewExpression( - filterEnvironment.flatMap(ExpressionEnvironment::getUser).orElse(user), + filterEnvironment.flatMap(ExpressionEnvironment::getUser), Optional.of(catalog), Optional.of(schema), filter)); diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java index d155baaced98..9b5a568beb5b 100644 --- a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedConnectorAccessControlTest.java @@ -419,7 +419,7 @@ public void testTableRulesForMixedGroupUsers() accessControl.checkCanSelectFromColumns(userGroup2, myTable, ImmutableSet.of()); assertViewExpressionEquals( accessControl.getColumnMask(userGroup2, myTable, "col_a", VARCHAR).orElseThrow(), - new ViewExpression(userGroup2.getIdentity().getUser(), Optional.of("test_catalog"), Optional.of("my_schema"), "'mask_a'")); + new ViewExpression(Optional.empty(), Optional.of("test_catalog"), Optional.of("my_schema"), "'mask_a'")); assertEquals( accessControl.getRowFilters(userGroup2, myTable), ImmutableList.of()); @@ -443,18 +443,18 @@ public void testTableRulesForMixedGroupUsers() accessControl.checkCanSelectFromColumns(userGroup3, myTable, ImmutableSet.of()); assertViewExpressionEquals( accessControl.getColumnMask(userGroup3, myTable, "col_a", VARCHAR).orElseThrow(), - new ViewExpression(userGroup3.getIdentity().getUser(), Optional.of("test_catalog"), Optional.of("my_schema"), "'mask_a'")); + new ViewExpression(Optional.empty(), Optional.of("test_catalog"), Optional.of("my_schema"), "'mask_a'")); List rowFilters = accessControl.getRowFilters(userGroup3, myTable); assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( rowFilters.get(0), - new ViewExpression(userGroup3.getIdentity().getUser(), Optional.of("test_catalog"), Optional.of("my_schema"), "country='US'")); + new ViewExpression(Optional.empty(), Optional.of("test_catalog"), Optional.of("my_schema"), "country='US'")); } private static void assertViewExpressionEquals(ViewExpression actual, ViewExpression expected) { - assertEquals(actual.getIdentity(), expected.getIdentity(), "Identity"); + assertEquals(actual.getSecurityIdentity(), expected.getSecurityIdentity(), "Identity"); assertEquals(actual.getCatalog(), expected.getCatalog(), "Catalog"); assertEquals(actual.getSchema(), expected.getSchema(), "Schema"); assertEquals(actual.getExpression(), expected.getExpression(), "Expression"); diff --git a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java index 3b687a4db496..2fa79c6a4bf7 100644 --- a/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java +++ b/lib/trino-plugin-toolkit/src/test/java/io/trino/plugin/base/security/BaseFileBasedSystemAccessControlTest.java @@ -802,7 +802,7 @@ public void testTableRulesForMixedGroupUsers() new CatalogSchemaTableName("some-catalog", "my_schema", "my_table"), "col_a", VARCHAR).orElseThrow(), - new ViewExpression(userGroup2.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("my_schema"), "'mask_a'")); + new ViewExpression(Optional.empty(), Optional.of("some-catalog"), Optional.of("my_schema"), "'mask_a'")); SystemSecurityContext userGroup1Group3 = new SystemSecurityContext(Identity.forUser("user_1_3") .withGroups(ImmutableSet.of("group1", "group3")).build(), Optional.empty()); @@ -821,7 +821,7 @@ public void testTableRulesForMixedGroupUsers() assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( rowFilters.get(0), - new ViewExpression(userGroup3.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("my_schema"), "country='US'")); + new ViewExpression(Optional.empty(), Optional.of("some-catalog"), Optional.of("my_schema"), "country='US'")); } @Test @@ -1425,7 +1425,7 @@ public void testGetColumnMask() new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns"), "masked", VARCHAR).orElseThrow(), - new ViewExpression(CHARLIE.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("bobschema"), "'mask'")); + new ViewExpression(Optional.empty(), Optional.of("some-catalog"), Optional.of("bobschema"), "'mask'")); assertViewExpressionEquals( accessControl.getColumnMask( @@ -1433,7 +1433,7 @@ public void testGetColumnMask() new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns"), "masked_with_user", VARCHAR).orElseThrow(), - new ViewExpression("mask-user", Optional.of("some-catalog"), Optional.of("bobschema"), "'mask-with-user'")); + new ViewExpression(Optional.of("mask-user"), Optional.of("some-catalog"), Optional.of("bobschema"), "'mask-with-user'")); } @Test @@ -1449,18 +1449,18 @@ public void testGetRowFilter() assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( rowFilters.get(0), - new ViewExpression(CHARLIE.getIdentity().getUser(), Optional.of("some-catalog"), Optional.of("bobschema"), "starts_with(value, 'filter')")); + new ViewExpression(Optional.empty(), Optional.of("some-catalog"), Optional.of("bobschema"), "starts_with(value, 'filter')")); rowFilters = accessControl.getRowFilters(CHARLIE, new CatalogSchemaTableName("some-catalog", "bobschema", "bobcolumns_with_grant")); assertEquals(rowFilters.size(), 1); assertViewExpressionEquals( rowFilters.get(0), - new ViewExpression("filter-user", Optional.of("some-catalog"), Optional.of("bobschema"), "starts_with(value, 'filter-with-user')")); + new ViewExpression(Optional.of("filter-user"), Optional.of("some-catalog"), Optional.of("bobschema"), "starts_with(value, 'filter-with-user')")); } private static void assertViewExpressionEquals(ViewExpression actual, ViewExpression expected) { - assertEquals(actual.getIdentity(), expected.getIdentity(), "Identity"); + assertEquals(actual.getSecurityIdentity(), expected.getSecurityIdentity(), "Identity"); assertEquals(actual.getCatalog(), expected.getCatalog(), "Catalog"); assertEquals(actual.getSchema(), expected.getSchema(), "Schema"); assertEquals(actual.getExpression(), expected.getExpression(), "Expression"); diff --git a/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java b/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java index 3345d6fc63ad..9dde277210bb 100644 --- a/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java +++ b/testing/trino-tests/src/test/java/io/trino/security/TestAccessControl.java @@ -122,6 +122,7 @@ protected QueryRunner createQueryRunner() throws Exception { Session session = testSessionBuilder() + .setSource("test") .setCatalog("blackhole") .setSchema("default") .build(); @@ -969,7 +970,7 @@ public void testAccessControlWithGroupsAndColumnMask() new QualifiedObjectName("blackhole", "default", "orders"), "comment", getSession().getUser(), - new ViewExpression(getSession().getUser(), Optional.empty(), Optional.empty(), "substr(comment,1,3)")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "substr(comment,1,3)")); assertAccessAllowed("SELECT comment FROM orders"); } @@ -983,11 +984,54 @@ public void testAccessControlWithGroupsAndRowFilter() accessControlManager.rowFilter( new QualifiedObjectName("blackhole", "default", "nation"), getSession().getUser(), - new ViewExpression(getSession().getUser(), Optional.empty(), Optional.empty(), "nationkey % 2 = 0")); + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey % 2 = 0")); assertAccessAllowed("SELECT nationkey FROM nation"); } + @Test + public void testAccessControlWithRolesAndColumnMask() + { + String role = "role"; + String user = "user"; + Session session = Session.builder(getSession()) + .setIdentity(Identity.forUser(user) + .withEnabledRoles(ImmutableSet.of(role)) + .build()) + .build(); + systemSecurityMetadata.grantRoles(getSession(), Set.of(role), Set.of(new TrinoPrincipal(USER, user)), false, Optional.empty()); + TestingAccessControlManager accessControlManager = getQueryRunner().getAccessControl(); + accessControlManager.denyIdentityTable((identity, table) -> (identity.getEnabledRoles().contains(role) && "orders".equals(table))); + accessControlManager.columnMask( + new QualifiedObjectName("blackhole", "default", "orders"), + "comment", + getSession().getUser(), + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "substr(comment,1,3)")); + + assertAccessAllowed(session, "SELECT comment FROM orders"); + } + + @Test + public void testAccessControlWithRolesAndRowFilter() + { + String role = "role"; + String user = "user"; + Session session = Session.builder(getSession()) + .setIdentity(Identity.forUser(user) + .withEnabledRoles(ImmutableSet.of(role)) + .build()) + .build(); + systemSecurityMetadata.grantRoles(getSession(), Set.of(role), Set.of(new TrinoPrincipal(USER, user)), false, Optional.empty()); + TestingAccessControlManager accessControlManager = getQueryRunner().getAccessControl(); + accessControlManager.denyIdentityTable((identity, table) -> (identity.getEnabledRoles().contains(role) && "nation".equals(table))); + accessControlManager.rowFilter( + new QualifiedObjectName("blackhole", "default", "nation"), + getSession().getUser(), + new ViewExpression(Optional.empty(), Optional.empty(), Optional.empty(), "nationkey % 2 = 0")); + + assertAccessAllowed(session, "SELECT nationkey FROM nation"); + } + private static final class DenySetPropertiesSystemAccessControl extends AllowAllSystemAccessControl {