From 4bb9dee4dfa4aace50737261bf1b7a13f2a52dd5 Mon Sep 17 00:00:00 2001 From: Sasha Sheikin Date: Fri, 10 May 2024 18:19:39 +0200 Subject: [PATCH] Cast pushdown --- .../plugin/jdbc/expression/RewriteCast.java | 30 +++++++++++-------- .../plugin/postgresql/PostgreSqlClient.java | 2 +- .../TestPostgreSqlConnectorTest.java | 14 +++++++++ 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCast.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCast.java index 84d46f1579b4..4fd0c66e3667 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCast.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCast.java @@ -13,61 +13,65 @@ */ package io.trino.plugin.jdbc.expression; -import com.google.common.collect.ImmutableList; import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.plugin.base.expression.ConnectorExpressionRule; -import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.WriteMapping; +import io.trino.spi.connector.ConnectorSession; import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.type.Type; -import java.util.List; import java.util.Optional; +import java.util.function.BiFunction; import static io.trino.matching.Capture.newCapture; import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount; import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression; import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME; import static java.lang.String.format; +import static java.util.Objects.requireNonNull; public class RewriteCast implements ConnectorExpressionRule { private static final Capture VALUE = newCapture(); - private static final Capture> EXPRESSIONS = newCapture(); + private final BiFunction toWriteMapping; + + public RewriteCast(BiFunction toWriteMapping) + { + this.toWriteMapping = requireNonNull(toWriteMapping, "toWriteMapping is null"); + } @Override public Pattern getPattern() { return call() .with(functionName().equalTo(CAST_FUNCTION_NAME)) + .with(argumentCount().equalTo(1)) .with(argument(0).matching(expression().capturedAs(VALUE))); } @Override public Optional rewrite(Call call, Captures captures, RewriteContext context) { - Type targetType = call.getType(); + Type trinoType = call.getType(); ConnectorExpression capturedValue = captures.get(VALUE); Optional value = context.defaultRewrite(capturedValue); if (value.isEmpty()) { - return Optional.empty(); - } - - ImmutableList.Builder parameters = ImmutableList.builder(); - Optional rewritten = context.defaultRewrite(capturedValue); - if (rewritten.isEmpty()) { // if argument is a call chain that can't be rewritten, then we can't push it down return Optional.empty(); } + String targetType = toWriteMapping.apply(context.getSession(), trinoType).getDataType(); + return Optional.of(new ParameterizedExpression( - format("CAST(%s AS %s)", value.get().expression(), targetType.getDisplayName()), - parameters.build())); + format("CAST(%s AS %s)", value.get().expression(), targetType), + value.get().parameters())); } } diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 229ee2d026f6..4622dc1dc652 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -308,7 +308,7 @@ public PostgreSqlClient( this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) .add(new RewriteIn()) - .add(new RewriteCast()) + .add(new RewriteCast(this::toWriteMapping)) .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) .withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "real", "double")) .withTypeClass("string_type", ImmutableSet.of("char", "varchar")) diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index f5139e7f8bbb..5889dd5c6b2e 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -62,6 +62,7 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_LIMIT_PUSHDOWN; import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY; @@ -1013,6 +1014,19 @@ public void testInPredicatePushdown() } } + @Test + public void testJoinWithCastInCriteriaPushdown() + { + if (!hasBehavior(SUPPORTS_JOIN_PUSHDOWN)) { + return; + } + + Session session = joinPushdownEnabled(getSession()); + + assertThat(query(session, "SELECT c.name, o.orderdate FROM customer c JOIN orders o ON CAST(c.custkey AS varchar(20)) = CAST(o.custkey AS varchar(21))")).isFullyPushedDown(); + assertThat(query(session, "SELECT c.name, o.orderdate FROM customer c JOIN orders o ON CAST((c.custkey + 123) AS varchar(20)) = CAST(o.custkey AS varchar(21))")).isFullyPushedDown(); + } + @Override protected String errorMessageForInsertIntoNotNullColumn(String columnName) {