From 6c89d7c752d8839e94dab3ff2ec534fd55b6b904 Mon Sep 17 00:00:00 2001 From: Mateusz Gajewski Date: Fri, 22 Apr 2022 13:29:01 +0200 Subject: [PATCH 1/5] Fix varchar rewrite for NULL literal --- .../trino/plugin/jdbc/expression/RewriteVarcharConstant.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java index f63d1b399d68..94ada1e9ecd4 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java @@ -40,6 +40,9 @@ public Pattern getPattern() public Optional rewrite(Constant constant, Captures captures, RewriteContext context) { Slice slice = (Slice) constant.getValue(); + if (slice == null) { + return Optional.empty(); + } return Optional.of("'" + slice.toStringUtf8().replace("'", "''") + "'"); } } From c568122e1dbfac7d60f9c3b331303f2d2c04c644 Mon Sep 17 00:00:00 2001 From: Mateusz Gajewski Date: Wed, 9 Mar 2022 15:49:15 +0100 Subject: [PATCH 2/5] Translate IN predicate to connector expression --- .../ConnectorExpressionTranslator.java | 85 +++++++++++++++++-- .../TestConnectorExpressionTranslator.java | 22 +++++ .../spi/expression/StandardFunctions.java | 12 +++ 3 files changed, 110 insertions(+), 9 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index 2084e9606fd1..21b78d55e7a0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -28,6 +28,7 @@ import io.trino.spi.expression.FieldDereference; import io.trino.spi.expression.FunctionName; import io.trino.spi.expression.Variable; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.Decimals; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; @@ -50,6 +51,8 @@ import io.trino.sql.tree.Expression; import io.trino.sql.tree.FunctionCall; import io.trino.sql.tree.GenericLiteral; +import io.trino.sql.tree.InListExpression; +import io.trino.sql.tree.InPredicate; import io.trino.sql.tree.IsNotNullPredicate; import io.trino.sql.tree.IsNullPredicate; import io.trino.sql.tree.LikePredicate; @@ -80,11 +83,13 @@ import static io.trino.SystemSessionProperties.isComplexExpressionPushdown; import static io.trino.spi.expression.StandardFunctions.ADD_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.DIVIDE_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.IN_PREDICATE_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME; @@ -272,6 +277,10 @@ protected Optional translateCall(Call call) } } + if (IN_PREDICATE_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 2) { + return translateInPredicate(call.getArguments().get(0), call.getArguments().get(1)); + } + QualifiedName name = QualifiedName.of(call.getFunctionName().getName()); List argumentTypes = call.getArguments().stream() .map(argument -> argument.getType().getTypeSignature()) @@ -344,15 +353,8 @@ private Optional translateCast(Type type, ConnectorExpression expres private Optional translateLogicalExpression(LogicalExpression.Operator operator, List arguments) { - ImmutableList.Builder translatedArguments = ImmutableList.builderWithExpectedSize(arguments.size()); - for (ConnectorExpression argument : arguments) { - Optional translated = translate(argument); - if (translated.isEmpty()) { - return Optional.empty(); - } - translatedArguments.add(translated.get()); - } - return Optional.of(new LogicalExpression(operator, translatedArguments.build())); + Optional> translatedArguments = translateExpressions(arguments); + return translatedArguments.map(expressions -> new LogicalExpression(operator, expressions)); } private Optional translateComparison(ComparisonExpression.Operator operator, ConnectorExpression left, ConnectorExpression right) @@ -446,6 +448,46 @@ protected Optional translateLike(ConnectorExpression value, Connecto return Optional.empty(); } + + protected Optional translateInPredicate(ConnectorExpression value, ConnectorExpression values) + { + Optional translatedValue = translate(value); + Optional> translatedValues = extractExpressionsFromArrayCall(values); + + if (translatedValue.isPresent() && translatedValues.isPresent()) { + return Optional.of(new InPredicate(translatedValue.get(), new InListExpression(translatedValues.get()))); + } + + return Optional.empty(); + } + + protected Optional> extractExpressionsFromArrayCall(ConnectorExpression expression) + { + if (!(expression instanceof Call)) { + return Optional.empty(); + } + + Call call = (Call) expression; + if (!call.getFunctionName().equals(ARRAY_CONSTRUCTOR_FUNCTION_NAME)) { + return Optional.empty(); + } + + return translateExpressions(call.getArguments()); + } + + protected Optional> translateExpressions(List expressions) + { + ImmutableList.Builder translatedExpressions = ImmutableList.builderWithExpectedSize(expressions.size()); + for (ConnectorExpression expression : expressions) { + Optional translated = translate(expression); + if (translated.isEmpty()) { + return Optional.empty(); + } + translatedExpressions.add(translated.get()); + } + + return Optional.of(translatedExpressions.build()); + } } public static class SqlToConnectorExpressionTranslator @@ -760,6 +802,31 @@ protected Optional visitSubscriptExpression(SubscriptExpres return Optional.of(new FieldDereference(typeOf(node), translatedBase.get(), toIntExact(((LongLiteral) node.getIndex()).getValue() - 1))); } + @Override + protected Optional visitInPredicate(InPredicate node, Void context) + { + InListExpression valueList = (InListExpression) node.getValueList(); + Optional valueExpression = process(node.getValue()); + + if (valueExpression.isEmpty()) { + return Optional.empty(); + } + + ImmutableList.Builder values = ImmutableList.builderWithExpectedSize(valueList.getValues().size()); + for (Expression value : valueList.getValues()) { + Optional processedValue = process(value); + + if (processedValue.isEmpty()) { + return Optional.empty(); + } + + values.add(processedValue.get()); + } + + ConnectorExpression arrayExpression = new Call(new ArrayType(typeOf(node.getValueList())), ARRAY_CONSTRUCTOR_FUNCTION_NAME, values.build()); + return Optional.of(new Call(typeOf(node), IN_PREDICATE_FUNCTION_NAME, List.of(valueExpression.get(), arrayExpression))); + } + @Override protected Optional visitExpression(Expression node, Void context) { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java index f19d68a7441e..675282c4981a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java @@ -24,6 +24,7 @@ import io.trino.spi.expression.FunctionName; import io.trino.spi.expression.StandardFunctions; import io.trino.spi.expression.Variable; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import io.trino.sql.tree.ArithmeticBinaryExpression; @@ -34,6 +35,8 @@ import io.trino.sql.tree.DoubleLiteral; import io.trino.sql.tree.Expression; import io.trino.sql.tree.FunctionCall; +import io.trino.sql.tree.InListExpression; +import io.trino.sql.tree.InPredicate; import io.trino.sql.tree.IsNotNullPredicate; import io.trino.sql.tree.IsNullPredicate; import io.trino.sql.tree.LikePredicate; @@ -59,6 +62,7 @@ import static io.airlift.slice.Slices.utf8Slice; import static io.trino.operator.scalar.JoniRegexpCasts.joniRegexp; import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME; @@ -94,6 +98,8 @@ public class TestConnectorExpressionTranslator private static final TypeAnalyzer TYPE_ANALYZER = createTestingTypeAnalyzer(PLANNER_CONTEXT); private static final Type ROW_TYPE = rowType(field("int_symbol_1", INTEGER), field("varchar_symbol_1", createVarcharType(5))); private static final VarcharType VARCHAR_TYPE = createVarcharType(25); + private static final ArrayType VARCHAR_ARRAY_TYPE = new ArrayType(VARCHAR_TYPE); + private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(PLANNER_CONTEXT); private static final Map symbols = ImmutableMap.builder() @@ -418,6 +424,22 @@ public void testTranslateRegularExpression() }); } + @Test + public void testTranslateIn() + { + String value = "value_1"; + assertTranslationRoundTrips( + new InPredicate( + new SymbolReference("varchar_symbol_1"), + new InListExpression(List.of(new SymbolReference("varchar_symbol_1"), new StringLiteral(value)))), + new Call( + BOOLEAN, + StandardFunctions.IN_PREDICATE_FUNCTION_NAME, + List.of( + new Variable("varchar_symbol_1", VARCHAR_TYPE), + new Call(VARCHAR_ARRAY_TYPE, ARRAY_CONSTRUCTOR_FUNCTION_NAME, List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE), new Constant(Slices.wrappedBuffer(value.getBytes(UTF_8)), createVarcharType(value.length()))))))); + } + private void assertTranslationRoundTrips(Expression expression, ConnectorExpression connectorExpression) { assertTranslationRoundTrips(TEST_SESSION, expression, connectorExpression); diff --git a/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java b/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java index cd7694efb7c1..2784697f6575 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java +++ b/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java @@ -82,4 +82,16 @@ private StandardFunctions() {} public static final FunctionName NEGATE_FUNCTION_NAME = new FunctionName("$negate"); public static final FunctionName LIKE_PATTERN_FUNCTION_NAME = new FunctionName("$like_pattern"); + + /** + * {@code $in(value, array)} returns {@code true} when value is equal to an element of the array, + * otherwise returns {@code NULL} when comparing value to an element of the array returns an + * indeterminate result, otherwise returns {@code false} + */ + public static final FunctionName IN_PREDICATE_FUNCTION_NAME = new FunctionName("$in"); + + /** + * $array creates instance of {@link Array Type} + */ + public static final FunctionName ARRAY_CONSTRUCTOR_FUNCTION_NAME = new FunctionName("$array"); } From 4ad0dfd1129ab638584a495a1933c194323e6e6d Mon Sep 17 00:00:00 2001 From: Mateusz Gajewski Date: Thu, 10 Mar 2022 13:27:24 +0100 Subject: [PATCH 3/5] Rewrite connector IN expression in PostgreSQL connector --- .../ConnectorExpressionPatterns.java | 5 ++ .../plugin/jdbc/expression/RewriteIn.java | 89 +++++++++++++++++++ .../plugin/postgresql/PostgreSqlClient.java | 2 + .../postgresql/TestPostgreSqlClient.java | 33 ++++++- .../TestPostgreSqlConnectorTest.java | 33 +++++++ 5 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressionPatterns.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressionPatterns.java index ecfe3b28e2ee..7666673300f2 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressionPatterns.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressionPatterns.java @@ -79,6 +79,11 @@ public static Pattern call() return Property.property("argumentCount", call -> call.getArguments().size()); } + public static Property> arguments() + { + return Property.property("arguments", Call::getArguments); + } + public static Property argument(int argument) { checkArgument(0 <= argument, "Invalid argument index: %s", argument); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java new file mode 100644 index 000000000000..0b81f689b502 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.arguments; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; +import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.getDomainCompactionThreshold; +import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.IN_PREDICATE_FUNCTION_NAME; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static java.lang.String.format; + +public class RewriteIn + implements ConnectorExpressionRule +{ + private static final Capture VALUE = newCapture(); + private static final Capture> EXPRESSIONS = newCapture(); + + private static final Pattern PATTERN = call() + .with(functionName().equalTo(IN_PREDICATE_FUNCTION_NAME)) + .with(type().equalTo(BOOLEAN)) + .with(argumentCount().equalTo(2)) + .with(argument(0).matching(expression().capturedAs(VALUE))) + .with(argument(1).matching(call().with(functionName().equalTo(ARRAY_CONSTRUCTOR_FUNCTION_NAME)).with(arguments().capturedAs(EXPRESSIONS)))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional rewrite(Call call, Captures captures, RewriteContext context) + { + Optional value = context.defaultRewrite(captures.get(VALUE)); + if (value.isEmpty()) { + return Optional.empty(); + } + + List expressions = captures.get(EXPRESSIONS); + if (expressions.size() > getDomainCompactionThreshold(context.getSession())) { + // We don't want to push down too long IN query text + return Optional.empty(); + } + + ImmutableList.Builder rewrittenValues = ImmutableList.builderWithExpectedSize(expressions.size()); + for (ConnectorExpression expression : expressions) { + Optional rewrittenExpression = context.defaultRewrite(expression); + if (rewrittenExpression.isEmpty()) { + return Optional.empty(); + } + rewrittenValues.add(rewrittenExpression.get()); + } + + List values = rewrittenValues.build(); + verify(!values.isEmpty(), "Empty values"); + return Optional.of(format("(%s) IN (%s)", value.get(), Joiner.on(", ").join(values))); + } +} diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 4dbf44840cd6..67dd6668e0e6 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -65,6 +65,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.RewriteComparison; +import io.trino.plugin.jdbc.expression.RewriteIn; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping; import io.trino.spi.TrinoException; @@ -300,6 +301,7 @@ public PostgreSqlClient( .addStandardRules(this::quoted) // TODO allow all comparison operators for numeric types .add(new RewriteComparison(ImmutableSet.of(RewriteComparison.ComparisonOperator.EQUAL, RewriteComparison.ComparisonOperator.NOT_EQUAL))) + .add(new RewriteIn()) .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) .map("$add(left: integer_type, right: integer_type)").to("left + right") .map("$subtract(left: integer_type, right: integer_type)").to("left - right") diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index a9d6bca510bd..7aae0329b414 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -20,11 +20,14 @@ import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcMetadataConfig; +import io.trino.plugin.jdbc.JdbcMetadataSessionProperties; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.mapping.DefaultIdentifierMapping; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorSession; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Variable; import io.trino.spi.type.Type; @@ -36,6 +39,8 @@ import io.trino.sql.tree.ArithmeticUnaryExpression; import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.InListExpression; +import io.trino.sql.tree.InPredicate; import io.trino.sql.tree.IsNotNullPredicate; import io.trino.sql.tree.IsNullPredicate; import io.trino.sql.tree.LikePredicate; @@ -44,6 +49,7 @@ import io.trino.sql.tree.NullIfExpression; import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.SymbolReference; +import io.trino.testing.TestingConnectorSession; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; @@ -63,7 +69,6 @@ import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; import static io.trino.testing.DataProviders.toDataProvider; -import static io.trino.testing.TestingConnectorSession.SESSION; import static io.trino.testing.assertions.Assert.assertEquals; import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; import static java.lang.String.format; @@ -93,6 +98,13 @@ public class TestPostgreSqlClient .setJdbcTypeHandle(new JdbcTypeHandle(Types.VARCHAR, Optional.of("varchar"), Optional.of(10), Optional.empty(), Optional.empty(), Optional.empty())) .build(); + private static final JdbcColumnHandle VARCHAR_COLUMN2 = + JdbcColumnHandle.builder() + .setColumnName("c_varchar2") + .setColumnType(createVarcharType(10)) + .setJdbcTypeHandle(new JdbcTypeHandle(Types.VARCHAR, Optional.of("varchar"), Optional.of(10), Optional.empty(), Optional.empty(), Optional.empty())) + .build(); + private static final JdbcClient JDBC_CLIENT = new PostgreSqlClient( new BaseJdbcConfig(), new PostgreSqlConfig(), @@ -104,6 +116,11 @@ public class TestPostgreSqlClient private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(PLANNER_CONTEXT); + private static final ConnectorSession SESSION = TestingConnectorSession + .builder() + .setPropertyMetadata(new JdbcMetadataSessionProperties(new JdbcMetadataConfig(), Optional.empty()).getSessionProperties()) + .build(); + @Test public void testImplementCount() { @@ -410,6 +427,20 @@ public void testConvertNotExpression() .hasValue("NOT ((\"c_varchar\") IS NOT NULL)"); } + @Test + public void testConvertIn() + { + assertThat(JDBC_CLIENT.convertPredicate( + SESSION, + translateToConnectorExpression( + new InPredicate( + new SymbolReference("c_varchar"), + new InListExpression(List.of(new StringLiteral("value1"), new StringLiteral("value2"), new SymbolReference("c_varchar2")))), + Map.of("c_varchar", VARCHAR_COLUMN.getColumnType(), "c_varchar2", VARCHAR_COLUMN2.getColumnType())), + Map.of(VARCHAR_COLUMN.getColumnName(), VARCHAR_COLUMN, VARCHAR_COLUMN2.getColumnName(), VARCHAR_COLUMN2))) + .hasValue("(\"c_varchar\") IN ('value1', 'value2', \"c_varchar2\")"); + } + private ConnectorExpression translateToConnectorExpression(Expression expression, Map symbolTypes) { return ConnectorExpressionTranslator.translate( diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index cae377cbb433..8d6b5680dedf 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -811,6 +811,39 @@ public void testNotExpressionPushdown() } } + @Test + public void testInPredicatePushdown() + { + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_in_predicate_pushdown", + "(id varchar(1), id2 varchar(1))", + List.of( + "'a', 'b'", + "'b', 'c'", + "'c', 'c'", + "'d', 'd'", + "'a', 'f'"))) { + // IN values cannot be represented as a domain + assertThat(query("SELECT id FROM " + table.getName() + " WHERE id IN ('a', id2)")) + .isFullyPushedDown(); + + assertThat(query("SELECT id FROM " + table.getName() + " WHERE id IN ('a', 'b') OR id2 IN ('c', 'd')")) + .isFullyPushedDown(); + + assertThat(query("SELECT id FROM " + table.getName() + " WHERE id IN ('a', 'B') OR id2 IN ('c', 'D')")) + .isFullyPushedDown(); + + assertThat(query("SELECT id FROM " + table.getName() + " WHERE id IN ('a', 'B', NULL) OR id2 IN ('C', 'd')")) + // NULL constant value is currently not pushed down + .isNotFullyPushedDown(FilterNode.class); + + assertThat(query("SELECT id FROM " + table.getName() + " WHERE id IN ('a', 'B', CAST(NULL AS varchar(1))) OR id2 IN ('C', 'd')")) + // NULL constant value is currently not pushed down + .isNotFullyPushedDown(FilterNode.class); + } + } + @Override protected String errorMessageForInsertIntoNotNullColumn(String columnName) { From 94432eeb7391565808b9b7f3446fd3549ca71e5c Mon Sep 17 00:00:00 2001 From: Mateusz Gajewski Date: Fri, 6 May 2022 16:25:30 +0200 Subject: [PATCH 4/5] Rewrite typed NULL char/varchar constants --- .../io/trino/cost/FilterStatsCalculator.java | 5 +++++ ...JdbcConnectorExpressionRewriterBuilder.java | 9 ++++++++- .../expression/RewriteVarcharConstant.java | 12 +++++++++++- .../plugin/postgresql/PostgreSqlClient.java | 17 ++++++++++++++++- .../postgresql/TestPostgreSqlClient.java | 18 ++++++++++++++++++ .../TestPostgreSqlConnectorTest.java | 6 ++---- 6 files changed, 60 insertions(+), 7 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java index f80055075ea2..2b88c63ae870 100644 --- a/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/FilterStatsCalculator.java @@ -476,6 +476,11 @@ private OptionalDouble doubleValueFromLiteral(Type type, Expression literal) session, new AllowAllAccessControl(), ImmutableMap.of()); + + if (literalValue == null) { + return OptionalDouble.empty(); + } + return toStatsRepresentation(type, literalValue); } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/JdbcConnectorExpressionRewriterBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/JdbcConnectorExpressionRewriterBuilder.java index 3e41269befa1..c6d22ebf12a2 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/JdbcConnectorExpressionRewriterBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/JdbcConnectorExpressionRewriterBuilder.java @@ -16,9 +16,11 @@ import com.google.common.collect.ImmutableSet; import io.trino.plugin.base.expression.ConnectorExpressionRewriter; import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.spi.type.Type; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -39,9 +41,14 @@ public static JdbcConnectorExpressionRewriterBuilder newBuilder() private JdbcConnectorExpressionRewriterBuilder() {} public JdbcConnectorExpressionRewriterBuilder addStandardRules(Function identifierQuote) + { + return addStandardRules(identifierQuote, type -> Optional.empty()); + } + + public JdbcConnectorExpressionRewriterBuilder addStandardRules(Function identifierQuote, Function> typeMapping) { add(new RewriteVariable(identifierQuote)); - add(new RewriteVarcharConstant()); + add(new RewriteVarcharConstant(typeMapping)); add(new RewriteExactNumericConstant()); add(new RewriteAnd()); add(new RewriteOr()); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java index 94ada1e9ecd4..c225225158fe 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteVarcharConstant.java @@ -18,17 +18,27 @@ import io.trino.matching.Pattern; import io.trino.plugin.base.expression.ConnectorExpressionRule; import io.trino.spi.expression.Constant; +import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import java.util.Optional; +import java.util.function.Function; import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.constant; import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; public class RewriteVarcharConstant implements ConnectorExpressionRule { private static final Pattern PATTERN = constant().with(type().matching(VarcharType.class::isInstance)); + private final Function> typeMapping; + + public RewriteVarcharConstant(Function> typeMapping) + { + this.typeMapping = requireNonNull(typeMapping, "typeMapping is null"); + } @Override public Pattern getPattern() @@ -41,7 +51,7 @@ public Optional rewrite(Constant constant, Captures captures, RewriteCon { Slice slice = (Slice) constant.getValue(); if (slice == null) { - return Optional.empty(); + return typeMapping.apply(constant.getType()).map(typedCast -> format("CAST(NULL AS %s)", typedCast)); } return Optional.of("'" + slice.toStringUtf8().replace("'", "''") + "'"); } diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 67dd6668e0e6..f5aedf05223c 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -298,7 +298,7 @@ public PostgreSqlClient( this.statisticsEnabled = requireNonNull(statisticsConfig, "statisticsConfig is null").isEnabled(); this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() - .addStandardRules(this::quoted) + .addStandardRules(this::quoted, this::typedCast) // TODO allow all comparison operators for numeric types .add(new RewriteComparison(ImmutableSet.of(RewriteComparison.ComparisonOperator.EQUAL, RewriteComparison.ComparisonOperator.NOT_EQUAL))) .add(new RewriteIn()) @@ -341,6 +341,21 @@ public PostgreSqlClient( .build()); } + private Optional typedCast(Type type) + { + if (type instanceof VarcharType) { + VarcharType varcharType = (VarcharType) type; + return varcharType.getLength().map(length -> format("varchar(%d)", length)).or(() -> Optional.of("varchar")); + } + + if (type instanceof CharType) { + CharType charType = (CharType) type; + return Optional.of(format("char(%d)", charType.getLength())); + } + + return Optional.empty(); + } + @Override public void createTable(ConnectorSession session, ConnectorTableMetadata tableMetadata) { diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index 7aae0329b414..80788a7c70e2 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -37,7 +37,9 @@ import io.trino.sql.planner.TypeProvider; import io.trino.sql.tree.ArithmeticBinaryExpression; import io.trino.sql.tree.ArithmeticUnaryExpression; +import io.trino.sql.tree.Cast; import io.trino.sql.tree.ComparisonExpression; +import io.trino.sql.tree.DataType; import io.trino.sql.tree.Expression; import io.trino.sql.tree.InListExpression; import io.trino.sql.tree.InPredicate; @@ -47,6 +49,7 @@ import io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.NotExpression; import io.trino.sql.tree.NullIfExpression; +import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.SymbolReference; import io.trino.testing.TestingConnectorSession; @@ -66,6 +69,7 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; import static io.trino.testing.DataProviders.toDataProvider; @@ -441,6 +445,20 @@ public void testConvertIn() .hasValue("(\"c_varchar\") IN ('value1', 'value2', \"c_varchar2\")"); } + @Test + public void testConvertInWithNulls() + { + assertThat(JDBC_CLIENT.convertPredicate( + SESSION, + translateToConnectorExpression( + new InPredicate( + new SymbolReference("c_varchar"), + new InListExpression(List.of(new StringLiteral("value1"), new StringLiteral("value2"), new Cast(new NullLiteral(), toSqlType(VARCHAR_COLUMN.getColumnType())), new SymbolReference("c_varchar2")))), + Map.of("c_varchar", VARCHAR_COLUMN.getColumnType(), "c_varchar2", VARCHAR_COLUMN2.getColumnType())), + Map.of(VARCHAR_COLUMN.getColumnName(), VARCHAR_COLUMN, VARCHAR_COLUMN2.getColumnName(), VARCHAR_COLUMN2))) + .hasValue("(\"c_varchar\") IN ('value1', 'value2', CAST(NULL AS varchar(10)), \"c_varchar2\")"); + } + private ConnectorExpression translateToConnectorExpression(Expression expression, Map symbolTypes) { return ConnectorExpressionTranslator.translate( diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index 8d6b5680dedf..b1a3e7970daa 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -835,12 +835,10 @@ public void testInPredicatePushdown() .isFullyPushedDown(); assertThat(query("SELECT id FROM " + table.getName() + " WHERE id IN ('a', 'B', NULL) OR id2 IN ('C', 'd')")) - // NULL constant value is currently not pushed down - .isNotFullyPushedDown(FilterNode.class); + .isFullyPushedDown(); assertThat(query("SELECT id FROM " + table.getName() + " WHERE id IN ('a', 'B', CAST(NULL AS varchar(1))) OR id2 IN ('C', 'd')")) - // NULL constant value is currently not pushed down - .isNotFullyPushedDown(FilterNode.class); + .isFullyPushedDown(); } } From 7390a2540af60e895fa67b0d33cbcd8742d6177c Mon Sep 17 00:00:00 2001 From: Mateusz Gajewski Date: Fri, 6 May 2022 16:48:46 +0200 Subject: [PATCH 5/5] Add explicit cast pushdown support to PostgreSQL --- .../plugin/jdbc/expression/RewriteCast.java | 74 +++++++++++++++++++ .../plugin/postgresql/PostgreSqlClient.java | 2 + .../postgresql/TestPostgreSqlClient.java | 25 ++++++- .../TestPostgreSqlConnectorTest.java | 18 +++++ 4 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCast.java diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCast.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCast.java new file mode 100644 index 000000000000..0850479119f3 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCast.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.type.Type; + +import java.util.Optional; +import java.util.function.Function; + +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; +import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class RewriteCast + implements ConnectorExpressionRule +{ + private static final Capture ARGUMENT = newCapture(); + + private static final Pattern PATTERN = call() + .with(functionName().equalTo(CAST_FUNCTION_NAME)) + .with(argument(0).capturedAs(ARGUMENT)); + + private final Function> typeMapping; + + public RewriteCast(Function> typeMapping) + { + this.typeMapping = requireNonNull(typeMapping, "typeMapping is null"); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional rewrite(Call expression, Captures captures, RewriteContext context) + { + ConnectorExpression argument = captures.get(ARGUMENT); + Optional typeCast = typeMapping.apply(expression.getType()); + + if (typeCast.isEmpty()) { + return Optional.empty(); + } + + Optional translatedArgument = context.defaultRewrite(argument); + if (translatedArgument.isEmpty()) { + return Optional.empty(); + } + + return Optional.of(format("CAST(%s AS %s)", translatedArgument.get(), typeCast.get())); + } +} diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index f5aedf05223c..5db27d291c91 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -64,6 +64,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.expression.RewriteCast; import io.trino.plugin.jdbc.expression.RewriteComparison; import io.trino.plugin.jdbc.expression.RewriteIn; import io.trino.plugin.jdbc.mapping.IdentifierMapping; @@ -300,6 +301,7 @@ public PostgreSqlClient( this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted, this::typedCast) // TODO allow all comparison operators for numeric types + .add(new RewriteCast(this::typedCast)) .add(new RewriteComparison(ImmutableSet.of(RewriteComparison.ComparisonOperator.EQUAL, RewriteComparison.ComparisonOperator.NOT_EQUAL))) .add(new RewriteIn()) .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index 80788a7c70e2..8308c52c6c78 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -39,7 +39,6 @@ import io.trino.sql.tree.ArithmeticUnaryExpression; import io.trino.sql.tree.Cast; import io.trino.sql.tree.ComparisonExpression; -import io.trino.sql.tree.DataType; import io.trino.sql.tree.Expression; import io.trino.sql.tree.InListExpression; import io.trino.sql.tree.InPredicate; @@ -459,6 +458,30 @@ public void testConvertInWithNulls() .hasValue("(\"c_varchar\") IN ('value1', 'value2', CAST(NULL AS varchar(10)), \"c_varchar2\")"); } + @Test + public void testConvertCast() + { + assertThat(JDBC_CLIENT.convertPredicate( + SESSION, + translateToConnectorExpression( + new Cast( + new SymbolReference("c_varchar"), + toSqlType(BIGINT_COLUMN.getColumnType())), + Map.of("c_varchar", VARCHAR_COLUMN.getColumnType())), + Map.of(VARCHAR_COLUMN.getColumnName(), VARCHAR_COLUMN))) + .isEmpty(); + + assertThat(JDBC_CLIENT.convertPredicate( + SESSION, + translateToConnectorExpression( + new Cast( + new SymbolReference("c_bigint"), + toSqlType(VARCHAR_COLUMN.getColumnType())), + Map.of("c_bigint", BIGINT_COLUMN.getColumnType())), + Map.of(BIGINT_COLUMN.getColumnName(), BIGINT_COLUMN))) + .hasValue("CAST(\"c_bigint\" AS varchar(10))"); + } + private ConnectorExpression translateToConnectorExpression(Expression expression, Map symbolTypes) { return ConnectorExpressionTranslator.translate( diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index b1a3e7970daa..716872b1f416 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -842,6 +842,24 @@ public void testInPredicatePushdown() } } + @Test + public void testCastPushdown() + { + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_cast_pushdown", + "(id bigint, id2 varchar(1))", + List.of( + "1, 'b'", + "2, 'c'", + "3, 'c'", + "4, 'd'", + "5, 'f'"))) { + assertThat(query("SELECT id FROM " + table.getName() + " WHERE CAST(id AS VARCHAR(1)) = '2' OR id2 = 'd'")) + .isFullyPushedDown(); + } + } + @Override protected String errorMessageForInsertIntoNotNullColumn(String columnName) {