Skip to content

Commit

Permalink
Implement join pushdown for char types in PostgreSQL
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Feb 23, 2022
1 parent 845d71a commit 9c3c6a7
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ protected void assertConditionallyOrderedPushedDown(
}
}

private boolean expectJoinPushdown(String operator)
protected boolean expectJoinPushdown(String operator)
{
if ("IS NOT DISTINCT FROM".equals(operator)) {
// TODO (https://github.com/trinodb/trino/issues/6967) support join pushdown for IS NOT DISTINCT FROM
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.postgresql;

import io.trino.plugin.jdbc.DefaultQueryBuilder;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcJoinCondition;
import io.trino.spi.type.CharType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;

import java.util.stream.Stream;

import static java.lang.String.format;

public class CollationAwareQueryBuilder
extends DefaultQueryBuilder
{
@Override
protected String formatJoinCondition(JdbcClient client, String leftRelationAlias, String rightRelationAlias, JdbcJoinCondition condition)
{
boolean needsCollation = Stream.of(condition.getLeftColumn(), condition.getRightColumn())
.map(JdbcColumnHandle::getColumnType)
.anyMatch(CollationAwareQueryBuilder::isCharType);

if (needsCollation) {
return format(
"%s.%s COLLATE \"C\" %s %s.%s COLLATE \"C\"",
leftRelationAlias,
buildJoinColumn(client, condition.getLeftColumn()),
condition.getOperator().getValue(),
rightRelationAlias,
buildJoinColumn(client, condition.getRightColumn()));
}

return super.formatJoinCondition(client, leftRelationAlias, rightRelationAlias, condition);
}

private static boolean isCharType(Type type)
{
return type instanceof CharType || type instanceof VarcharType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,7 @@ protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCon
case LESS_THAN_OR_EQUAL:
case GREATER_THAN:
case GREATER_THAN_OR_EQUAL:
return false;
return isEnableStringPushdownWithCollate(session);
case EQUAL:
case NOT_EQUAL:
case IS_DISTINCT_FROM:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
import io.trino.plugin.jdbc.DriverConnectionFactory;
import io.trino.plugin.jdbc.ForBaseJdbc;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.QueryBuilder;
import io.trino.plugin.jdbc.RemoteQueryCancellationModule;
import io.trino.plugin.jdbc.credential.CredentialProvider;
import org.postgresql.Driver;

import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder;
import static io.airlift.configuration.ConfigBinder.configBinder;
import static io.trino.plugin.jdbc.JdbcModule.bindSessionPropertiesProvider;

Expand All @@ -40,6 +42,7 @@ public void setup(Binder binder)
binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(PostgreSqlClient.class).in(Scopes.SINGLETON);
configBinder(binder).bindConfig(PostgreSqlConfig.class);
bindSessionPropertiesProvider(binder, PostgreSqlSessionProperties.class);
newOptionalBinder(binder, QueryBuilder.class).setBinding().to(CollationAwareQueryBuilder.class).in(Scopes.SINGLETON);
install(new DecimalModule());
install(new RemoteQueryCancellationModule());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
import io.trino.plugin.jdbc.JdbcTableHandle;
import io.trino.plugin.jdbc.RemoteDatabaseEvent;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.predicate.Range;
import io.trino.spi.predicate.TupleDomain;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.TableScanNode;
Expand All @@ -44,14 +47,21 @@
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Stream;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.MoreCollectors.onlyElement;
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.plugin.postgresql.PostgreSqlQueryRunner.createPostgreSqlQueryRunner;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree;
import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange;
import static io.trino.sql.planner.assertions.PlanMatchPattern.node;
import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_LIMIT_PUSHDOWN;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY;
import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_TOPN_PUSHDOWN;
import static io.trino.testing.sql.TestTable.randomTableSuffix;
import static java.lang.Math.round;
import static java.lang.String.format;
Expand Down Expand Up @@ -88,7 +98,6 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
{
switch (connectorBehavior) {
case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY:
case SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_INEQUALITY:
return false;

case SUPPORTS_TOPN_PUSHDOWN:
Expand Down Expand Up @@ -460,6 +469,162 @@ public void testStringPushdownWithCollate()
anyTree(node(TableScanNode.class))));
}

@Test
public void testStringJoinPushdownWithCollate()
{
PlanMatchPattern joinOverTableScans =
node(JoinNode.class,
anyTree(node(TableScanNode.class)),
anyTree(node(TableScanNode.class)));

PlanMatchPattern broadcastJoinOverTableScans =
node(JoinNode.class,
node(TableScanNode.class),
exchange(ExchangeNode.Scope.LOCAL,
exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.REPLICATE,
node(TableScanNode.class))));

Session sessionWithCollatePushdown = Session.builder(getSession())
.setCatalogSessionProperty("postgresql", "enable_string_pushdown_with_collate", "true")
.build();

Session session = joinPushdownEnabled(sessionWithCollatePushdown);

// Disable DF here for the sake of negative test cases' expected plan. With DF enabled, some operators return in DF's FilterNode and some do not.
Session withoutDynamicFiltering = Session.builder(getSession())
.setSystemProperty("enable_dynamic_filtering", "false")
.setCatalogSessionProperty("postgresql", "enable_string_pushdown_with_collate", "true")
.build();

String notDistinctOperator = "IS NOT DISTINCT FROM";
List<String> nonEqualities = Stream.concat(
Stream.of(JoinCondition.Operator.values())
.filter(operator -> operator != JoinCondition.Operator.EQUAL)
.map(JoinCondition.Operator::getValue),
Stream.of(notDistinctOperator))
.collect(toImmutableList());

try (TestTable nationLowercaseTable = new TestTable(
// If a connector supports Join pushdown, but does not allow CTAS, we need to make the table creation here overridable.
getQueryRunner()::execute,
"nation_lowercase",
"AS SELECT nationkey, lower(name) name, regionkey FROM nation")) {
// basic case
assertThat(query(session, "SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey = r.regionkey")).isFullyPushedDown();

// join over different columns
assertThat(query(session, "SELECT r.name, n.name FROM nation n JOIN region r ON n.nationkey = r.regionkey")).isFullyPushedDown();

// pushdown when using USING
assertThat(query(session, "SELECT r.name, n.name FROM nation n JOIN region r USING(regionkey)")).isFullyPushedDown();

// varchar equality predicate
assertConditionallyPushedDown(
session,
"SELECT n.name, n2.regionkey FROM nation n JOIN nation n2 ON n.name = n2.name",
true,
joinOverTableScans);
assertConditionallyPushedDown(
session,
format("SELECT n.name, nl.regionkey FROM nation n JOIN %s nl ON n.name = nl.name", nationLowercaseTable.getName()),
true,
joinOverTableScans);

// multiple bigint predicates
assertThat(query(session, "SELECT n.name, c.name FROM nation n JOIN customer c ON n.nationkey = c.nationkey and n.regionkey = c.custkey"))
.isFullyPushedDown();

// inequality
for (String operator : nonEqualities) {
// bigint inequality predicate
assertThat(query(withoutDynamicFiltering, format("SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey %s r.regionkey", operator)))
// Currently no pushdown as inequality predicate is removed from Join to maintain Cross Join and Filter as separate nodes
.isNotFullyPushedDown(broadcastJoinOverTableScans);

// varchar inequality predicate
assertThat(query(withoutDynamicFiltering, format("SELECT n.name, nl.name FROM nation n JOIN %s nl ON n.name %s nl.name", nationLowercaseTable.getName(), operator)))
// Currently no pushdown as inequality predicate is removed from Join to maintain Cross Join and Filter as separate nodes
.isNotFullyPushedDown(broadcastJoinOverTableScans);
}

// inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join
for (String operator : nonEqualities) {
assertConditionallyPushedDown(
session,
format("SELECT n.name, c.name FROM nation n JOIN customer c ON n.nationkey = c.nationkey AND n.regionkey %s c.custkey", operator),
expectJoinPushdown(operator),
joinOverTableScans);
}

// varchar inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join
for (String operator : nonEqualities) {
assertConditionallyPushedDown(
session,
format("SELECT n.name, nl.name FROM nation n JOIN %s nl ON n.regionkey = nl.regionkey AND n.name %s nl.name", nationLowercaseTable.getName(), operator),
expectJoinPushdown(operator),
joinOverTableScans);
}

// LEFT JOIN
assertThat(query(session, "SELECT r.name, n.name FROM nation n LEFT JOIN region r ON n.nationkey = r.regionkey")).isFullyPushedDown();
assertThat(query(session, "SELECT r.name, n.name FROM region r LEFT JOIN nation n ON n.nationkey = r.regionkey")).isFullyPushedDown();

// RIGHT JOIN
assertThat(query(session, "SELECT r.name, n.name FROM nation n RIGHT JOIN region r ON n.nationkey = r.regionkey")).isFullyPushedDown();
assertThat(query(session, "SELECT r.name, n.name FROM region r RIGHT JOIN nation n ON n.nationkey = r.regionkey")).isFullyPushedDown();

// FULL JOIN
assertConditionallyPushedDown(
session,
"SELECT r.name, n.name FROM nation n FULL JOIN region r ON n.nationkey = r.regionkey",
true,
joinOverTableScans);

// Join over a (double) predicate
assertThat(query(session, "" +
"SELECT c.name, n.name " +
"FROM (SELECT * FROM customer WHERE acctbal > 8000) c " +
"JOIN nation n ON c.custkey = n.nationkey"))
.isFullyPushedDown();

// Join over a varchar equality predicate
assertConditionallyPushedDown(
session,
"SELECT c.name, n.name FROM (SELECT * FROM customer WHERE address = 'TcGe5gaZNgVePxU5kRrvXBfkasDTea') c " +
"JOIN nation n ON c.custkey = n.nationkey",
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY),
joinOverTableScans);

// join over aggregation
assertConditionallyPushedDown(
session,
"SELECT * FROM (SELECT regionkey rk, count(nationkey) c FROM nation GROUP BY regionkey) n " +
"JOIN region r ON n.rk = r.regionkey",
hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN),
joinOverTableScans);

// join over LIMIT
assertConditionallyPushedDown(
session,
"SELECT * FROM (SELECT nationkey FROM nation LIMIT 30) n " +
"JOIN region r ON n.nationkey = r.regionkey",
hasBehavior(SUPPORTS_LIMIT_PUSHDOWN),
joinOverTableScans);

// join over TopN
assertConditionallyPushedDown(
session,
"SELECT * FROM (SELECT nationkey FROM nation ORDER BY regionkey LIMIT 5) n " +
"JOIN region r ON n.nationkey = r.regionkey",
hasBehavior(SUPPORTS_TOPN_PUSHDOWN),
joinOverTableScans);

// join over join
assertThat(query(session, "SELECT * FROM nation n, region r, customer c WHERE n.regionkey = r.regionkey AND r.regionkey = c.custkey"))
.isFullyPushedDown();
}
}

@Test
public void testDecimalPredicatePushdown()
throws Exception
Expand Down

0 comments on commit 9c3c6a7

Please sign in to comment.