diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index f55642cef168..465ab6bc24b3 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -1243,7 +1243,7 @@ protected void assertConditionallyOrderedPushedDown( } } - private boolean expectJoinPushdown(String operator) + protected boolean expectJoinPushdown(String operator) { if ("IS NOT DISTINCT FROM".equals(operator)) { // TODO (https://github.com/trinodb/trino/issues/6967) support join pushdown for IS NOT DISTINCT FROM diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java new file mode 100644 index 000000000000..52cbfb9604cb --- /dev/null +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/CollationAwareQueryBuilder.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.postgresql; + +import io.trino.plugin.jdbc.DefaultQueryBuilder; +import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcJoinCondition; +import io.trino.spi.type.CharType; +import io.trino.spi.type.Type; +import io.trino.spi.type.VarcharType; + +import java.util.stream.Stream; + +import static java.lang.String.format; + +public class CollationAwareQueryBuilder + extends DefaultQueryBuilder +{ + @Override + protected String formatJoinCondition(JdbcClient client, String leftRelationAlias, String rightRelationAlias, JdbcJoinCondition condition) + { + boolean needsCollation = Stream.of(condition.getLeftColumn(), condition.getRightColumn()) + .map(JdbcColumnHandle::getColumnType) + .anyMatch(CollationAwareQueryBuilder::isCharType); + + if (needsCollation) { + return format( + "%s.%s COLLATE \"C\" %s %s.%s COLLATE \"C\"", + leftRelationAlias, + buildJoinColumn(client, condition.getLeftColumn()), + condition.getOperator().getValue(), + rightRelationAlias, + buildJoinColumn(client, condition.getRightColumn())); + } + + return super.formatJoinCondition(client, leftRelationAlias, rightRelationAlias, condition); + } + + private static boolean isCharType(Type type) + { + return type instanceof CharType || type instanceof VarcharType; + } +} diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index e0f778dc124c..191bd672e2a9 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -830,7 +830,7 @@ protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCon case LESS_THAN_OR_EQUAL: case GREATER_THAN: case GREATER_THAN_OR_EQUAL: - return false; + return isEnableStringPushdownWithCollate(session); case EQUAL: case NOT_EQUAL: case IS_DISTINCT_FROM: diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClientModule.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClientModule.java index 17e21438f193..e75fbea4d7f2 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClientModule.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClientModule.java @@ -24,10 +24,12 @@ import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.RemoteQueryCancellationModule; import io.trino.plugin.jdbc.credential.CredentialProvider; import org.postgresql.Driver; +import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; import static io.airlift.configuration.ConfigBinder.configBinder; import static io.trino.plugin.jdbc.JdbcModule.bindSessionPropertiesProvider; @@ -40,6 +42,7 @@ public void setup(Binder binder) binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(PostgreSqlClient.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(PostgreSqlConfig.class); bindSessionPropertiesProvider(binder, PostgreSqlSessionProperties.class); + newOptionalBinder(binder, QueryBuilder.class).setBinding().to(CollationAwareQueryBuilder.class).in(Scopes.SINGLETON); install(new DecimalModule()); install(new RemoteQueryCancellationModule()); } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index 360f15589219..6ac29b10bce1 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -22,8 +22,11 @@ import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.RemoteDatabaseEvent; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.JoinCondition; import io.trino.spi.predicate.Range; import io.trino.spi.predicate.TupleDomain; +import io.trino.sql.planner.assertions.PlanMatchPattern; +import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.TableScanNode; @@ -44,14 +47,21 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.stream.Stream; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.MoreCollectors.onlyElement; import static io.airlift.slice.Slices.utf8Slice; import static io.trino.plugin.postgresql.PostgreSqlQueryRunner.createPostgreSqlQueryRunner; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_LIMIT_PUSHDOWN; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_TOPN_PUSHDOWN; import static io.trino.testing.sql.TestTable.randomTableSuffix; import static java.lang.Math.round; import static java.lang.String.format; @@ -88,7 +98,6 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) { switch (connectorBehavior) { case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY: - case SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_INEQUALITY: return false; case SUPPORTS_TOPN_PUSHDOWN: @@ -460,6 +469,162 @@ public void testStringPushdownWithCollate() anyTree(node(TableScanNode.class)))); } + @Test + public void testStringJoinPushdownWithCollate() + { + PlanMatchPattern joinOverTableScans = + node(JoinNode.class, + anyTree(node(TableScanNode.class)), + anyTree(node(TableScanNode.class))); + + PlanMatchPattern broadcastJoinOverTableScans = + node(JoinNode.class, + node(TableScanNode.class), + exchange(ExchangeNode.Scope.LOCAL, + exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.REPLICATE, + node(TableScanNode.class)))); + + Session sessionWithCollatePushdown = Session.builder(getSession()) + .setCatalogSessionProperty("postgresql", "enable_string_pushdown_with_collate", "true") + .build(); + + Session session = joinPushdownEnabled(sessionWithCollatePushdown); + + // Disable DF here for the sake of negative test cases' expected plan. With DF enabled, some operators return in DF's FilterNode and some do not. + Session withoutDynamicFiltering = Session.builder(getSession()) + .setSystemProperty("enable_dynamic_filtering", "false") + .setCatalogSessionProperty("postgresql", "enable_string_pushdown_with_collate", "true") + .build(); + + String notDistinctOperator = "IS NOT DISTINCT FROM"; + List nonEqualities = Stream.concat( + Stream.of(JoinCondition.Operator.values()) + .filter(operator -> operator != JoinCondition.Operator.EQUAL) + .map(JoinCondition.Operator::getValue), + Stream.of(notDistinctOperator)) + .collect(toImmutableList()); + + try (TestTable nationLowercaseTable = new TestTable( + // If a connector supports Join pushdown, but does not allow CTAS, we need to make the table creation here overridable. + getQueryRunner()::execute, + "nation_lowercase", + "AS SELECT nationkey, lower(name) name, regionkey FROM nation")) { + // basic case + assertThat(query(session, "SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey = r.regionkey")).isFullyPushedDown(); + + // join over different columns + assertThat(query(session, "SELECT r.name, n.name FROM nation n JOIN region r ON n.nationkey = r.regionkey")).isFullyPushedDown(); + + // pushdown when using USING + assertThat(query(session, "SELECT r.name, n.name FROM nation n JOIN region r USING(regionkey)")).isFullyPushedDown(); + + // varchar equality predicate + assertConditionallyPushedDown( + session, + "SELECT n.name, n2.regionkey FROM nation n JOIN nation n2 ON n.name = n2.name", + true, + joinOverTableScans); + assertConditionallyPushedDown( + session, + format("SELECT n.name, nl.regionkey FROM nation n JOIN %s nl ON n.name = nl.name", nationLowercaseTable.getName()), + true, + joinOverTableScans); + + // multiple bigint predicates + assertThat(query(session, "SELECT n.name, c.name FROM nation n JOIN customer c ON n.nationkey = c.nationkey and n.regionkey = c.custkey")) + .isFullyPushedDown(); + + // inequality + for (String operator : nonEqualities) { + // bigint inequality predicate + assertThat(query(withoutDynamicFiltering, format("SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey %s r.regionkey", operator))) + // Currently no pushdown as inequality predicate is removed from Join to maintain Cross Join and Filter as separate nodes + .isNotFullyPushedDown(broadcastJoinOverTableScans); + + // varchar inequality predicate + assertThat(query(withoutDynamicFiltering, format("SELECT n.name, nl.name FROM nation n JOIN %s nl ON n.name %s nl.name", nationLowercaseTable.getName(), operator))) + // Currently no pushdown as inequality predicate is removed from Join to maintain Cross Join and Filter as separate nodes + .isNotFullyPushedDown(broadcastJoinOverTableScans); + } + + // inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join + for (String operator : nonEqualities) { + assertConditionallyPushedDown( + session, + format("SELECT n.name, c.name FROM nation n JOIN customer c ON n.nationkey = c.nationkey AND n.regionkey %s c.custkey", operator), + expectJoinPushdown(operator), + joinOverTableScans); + } + + // varchar inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join + for (String operator : nonEqualities) { + assertConditionallyPushedDown( + session, + format("SELECT n.name, nl.name FROM nation n JOIN %s nl ON n.regionkey = nl.regionkey AND n.name %s nl.name", nationLowercaseTable.getName(), operator), + expectJoinPushdown(operator), + joinOverTableScans); + } + + // LEFT JOIN + assertThat(query(session, "SELECT r.name, n.name FROM nation n LEFT JOIN region r ON n.nationkey = r.regionkey")).isFullyPushedDown(); + assertThat(query(session, "SELECT r.name, n.name FROM region r LEFT JOIN nation n ON n.nationkey = r.regionkey")).isFullyPushedDown(); + + // RIGHT JOIN + assertThat(query(session, "SELECT r.name, n.name FROM nation n RIGHT JOIN region r ON n.nationkey = r.regionkey")).isFullyPushedDown(); + assertThat(query(session, "SELECT r.name, n.name FROM region r RIGHT JOIN nation n ON n.nationkey = r.regionkey")).isFullyPushedDown(); + + // FULL JOIN + assertConditionallyPushedDown( + session, + "SELECT r.name, n.name FROM nation n FULL JOIN region r ON n.nationkey = r.regionkey", + true, + joinOverTableScans); + + // Join over a (double) predicate + assertThat(query(session, "" + + "SELECT c.name, n.name " + + "FROM (SELECT * FROM customer WHERE acctbal > 8000) c " + + "JOIN nation n ON c.custkey = n.nationkey")) + .isFullyPushedDown(); + + // Join over a varchar equality predicate + assertConditionallyPushedDown( + session, + "SELECT c.name, n.name FROM (SELECT * FROM customer WHERE address = 'TcGe5gaZNgVePxU5kRrvXBfkasDTea') c " + + "JOIN nation n ON c.custkey = n.nationkey", + hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY), + joinOverTableScans); + + // join over aggregation + assertConditionallyPushedDown( + session, + "SELECT * FROM (SELECT regionkey rk, count(nationkey) c FROM nation GROUP BY regionkey) n " + + "JOIN region r ON n.rk = r.regionkey", + hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN), + joinOverTableScans); + + // join over LIMIT + assertConditionallyPushedDown( + session, + "SELECT * FROM (SELECT nationkey FROM nation LIMIT 30) n " + + "JOIN region r ON n.nationkey = r.regionkey", + hasBehavior(SUPPORTS_LIMIT_PUSHDOWN), + joinOverTableScans); + + // join over TopN + assertConditionallyPushedDown( + session, + "SELECT * FROM (SELECT nationkey FROM nation ORDER BY regionkey LIMIT 5) n " + + "JOIN region r ON n.nationkey = r.regionkey", + hasBehavior(SUPPORTS_TOPN_PUSHDOWN), + joinOverTableScans); + + // join over join + assertThat(query(session, "SELECT * FROM nation n, region r, customer c WHERE n.regionkey = r.regionkey AND r.regionkey = c.custkey")) + .isFullyPushedDown(); + } + } + @Test public void testDecimalPredicatePushdown() throws Exception