From 7100a58ae2988152e0dc382999dcbfe59a97363d Mon Sep 17 00:00:00 2001 From: Mateusz Gajewski Date: Wed, 16 Mar 2022 13:14:19 +0100 Subject: [PATCH 1/2] Translate NOT, IS NULL, NOT IS NULL to connector expression(s) --- .../ConnectorExpressionTranslator.java | 82 +++++++++++++++++++ .../TestConnectorExpressionTranslator.java | 39 +++++++++ .../spi/expression/StandardFunctions.java | 7 ++ 3 files changed, 128 insertions(+) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index b9fee5704757..9314ce444d82 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -49,11 +49,14 @@ import io.trino.sql.tree.Expression; import io.trino.sql.tree.FunctionCall; import io.trino.sql.tree.GenericLiteral; +import io.trino.sql.tree.IsNotNullPredicate; +import io.trino.sql.tree.IsNullPredicate; import io.trino.sql.tree.LikePredicate; import io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.LongLiteral; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; +import io.trino.sql.tree.NotExpression; import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.StringLiteral; @@ -79,6 +82,7 @@ import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.LIKE_PATTERN_FUNCTION_NAME; @@ -86,6 +90,7 @@ import static io.trino.spi.expression.StandardFunctions.MULTIPLY_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.NEGATE_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.NOT_EQUAL_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.NOT_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.OR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.SUBTRACT_FUNCTION_NAME; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -202,6 +207,21 @@ protected Optional translateCall(Call call) if (OR_FUNCTION_NAME.equals(call.getFunctionName())) { return translateLogicalExpression(LogicalExpression.Operator.OR, call.getArguments()); } + if (NOT_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 1) { + ConnectorExpression expression = getOnlyElement(call.getArguments()); + + if (expression instanceof Call) { + Call innerCall = (Call) expression; + if (innerCall.getFunctionName().equals(IS_NULL_FUNCTION_NAME) && innerCall.getArguments().size() == 1) { + return translateIsNotNull(innerCall.getArguments().get(0)); + } + } + + return translateNot(expression); + } + if (IS_NULL_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 1) { + return translateIsNull(call.getArguments().get(0)); + } // comparisons if (call.getArguments().size() == 2) { @@ -251,6 +271,36 @@ protected Optional translateCall(Call call) return Optional.of(builder.build()); } + private Optional translateIsNotNull(ConnectorExpression argument) + { + Optional translatedArgument = translate(argument); + if (translatedArgument.isPresent()) { + return Optional.of(new IsNotNullPredicate(translatedArgument.get())); + } + + return Optional.empty(); + } + + private Optional translateIsNull(ConnectorExpression argument) + { + Optional translatedArgument = translate(argument); + if (translatedArgument.isPresent()) { + return Optional.of(new IsNullPredicate(translatedArgument.get())); + } + + return Optional.empty(); + } + + private Optional translateNot(ConnectorExpression argument) + { + Optional translatedArgument = translate(argument); + if (argument.getType().equals(BOOLEAN) && translatedArgument.isPresent()) { + return Optional.of(new NotExpression(translatedArgument.get())); + } + + return Optional.empty(); + } + private Optional translateLogicalExpression(LogicalExpression.Operator operator, List arguments) { ImmutableList.Builder translatedArguments = ImmutableList.builderWithExpectedSize(arguments.size()); @@ -538,6 +588,38 @@ protected Optional visitFunctionCall(FunctionCall node, Voi return Optional.of(new Call(typeOf(node), name, arguments.build())); } + @Override + protected Optional visitIsNullPredicate(IsNullPredicate node, Void context) + { + Optional translatedValue = process(node.getValue()); + if (translatedValue.isPresent()) { + return Optional.of(new Call(BOOLEAN, IS_NULL_FUNCTION_NAME, ImmutableList.of(translatedValue.get()))); + } + return Optional.empty(); + } + + @Override + protected Optional visitIsNotNullPredicate(IsNotNullPredicate node, Void context) + { + // IS NOT NULL is translated to $not($is_null(..)) + Optional translatedValue = process(node.getValue()); + if (translatedValue.isPresent()) { + Call isNullCall = new Call(typeOf(node), IS_NULL_FUNCTION_NAME, List.of(translatedValue.get())); + return Optional.of(new Call(BOOLEAN, NOT_FUNCTION_NAME, List.of(isNullCall))); + } + return Optional.empty(); + } + + @Override + protected Optional visitNotExpression(NotExpression node, Void context) + { + Optional translatedValue = process(node.getValue()); + if (translatedValue.isPresent()) { + return Optional.of(new Call(BOOLEAN, NOT_FUNCTION_NAME, List.of(translatedValue.get()))); + } + return Optional.empty(); + } + private ConnectorExpression constantFor(Expression node) { Type type = typeOf(node); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java index ff562b40fd3f..42ad22ac9a47 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java @@ -29,9 +29,12 @@ import io.trino.sql.tree.ArithmeticUnaryExpression; import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.IsNotNullPredicate; +import io.trino.sql.tree.IsNullPredicate; import io.trino.sql.tree.LikePredicate; import io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.LongLiteral; +import io.trino.sql.tree.NotExpression; import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.SubscriptExpression; @@ -48,8 +51,10 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.LIKE_PATTERN_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.NEGATE_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.NOT_FUNCTION_NAME; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DecimalType.createDecimalType; @@ -82,6 +87,7 @@ public class TestConnectorExpressionTranslator .put(new Symbol("double_symbol_2"), DOUBLE) .put(new Symbol("row_symbol_1"), ROW_TYPE) .put(new Symbol("varchar_symbol_1"), VARCHAR_TYPE) + .put(new Symbol("boolean_symbol_1"), BOOLEAN) .buildOrThrow(); private static final TypeProvider TYPE_PROVIDER = TypeProvider.copyOf(symbols); @@ -243,6 +249,39 @@ public void testTranslateLike() new Constant(Slices.wrappedBuffer(escape.getBytes(UTF_8)), createVarcharType(escape.length()))))); } + @Test + public void testTranslateIsNull() + { + assertTranslationRoundTrips( + new IsNullPredicate(new SymbolReference("varchar_symbol_1")), + new Call( + BOOLEAN, + IS_NULL_FUNCTION_NAME, + List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE)))); + } + + @Test + public void testTranslateNotExpression() + { + assertTranslationRoundTrips( + new NotExpression(new SymbolReference("boolean_symbol_1")), + new Call( + BOOLEAN, + NOT_FUNCTION_NAME, + List.of(new Variable("boolean_symbol_1", BOOLEAN)))); + } + + @Test + public void testTranslateIsNotNull() + { + assertTranslationRoundTrips( + new IsNotNullPredicate(new SymbolReference("varchar_symbol_1")), + new Call( + BOOLEAN, + NOT_FUNCTION_NAME, + List.of(new Call(BOOLEAN, IS_NULL_FUNCTION_NAME, List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE)))))); + } + @Test public void testTranslateResolvedFunction() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java b/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java index a7ae39b861de..16fd90102878 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java +++ b/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java @@ -27,6 +27,13 @@ private StandardFunctions() {} */ public static final FunctionName OR_FUNCTION_NAME = new FunctionName("$or"); + /** + * $not is a function accepting boolean argument + */ + public static final FunctionName NOT_FUNCTION_NAME = new FunctionName("$not"); + + public static final FunctionName IS_NULL_FUNCTION_NAME = new FunctionName("$is_null"); + public static final FunctionName EQUAL_OPERATOR_FUNCTION_NAME = new FunctionName("$equal"); public static final FunctionName NOT_EQUAL_OPERATOR_FUNCTION_NAME = new FunctionName("$not_equal"); public static final FunctionName LESS_THAN_OPERATOR_FUNCTION_NAME = new FunctionName("$less_than"); From 3a20817dcfa4f14bf4710db71f6c31aceaa2fd51 Mon Sep 17 00:00:00 2001 From: Mateusz Gajewski Date: Wed, 16 Mar 2022 14:07:20 +0100 Subject: [PATCH 2/2] Implement NOT, IS NULL, NOT IS NULL pushdown in PostgreSQL connector --- .../plugin/postgresql/PostgreSqlClient.java | 3 + .../postgresql/TestPostgreSqlClient.java | 43 ++++++++++++++ .../TestPostgreSqlConnectorTest.java | 58 ++++++++++++++++++- 3 files changed, 103 insertions(+), 1 deletion(-) diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 410134c357e4..08916fb6e28d 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -286,6 +286,9 @@ public PostgreSqlClient( .add(new RewriteComparison(RewriteComparison.ComparisonOperator.EQUAL, RewriteComparison.ComparisonOperator.NOT_EQUAL)) .map("$like_pattern(value: varchar, pattern: varchar): boolean").to("value LIKE pattern") .map("$like_pattern(value: varchar, pattern: varchar, escape: varchar(1)): boolean").to("value LIKE pattern ESCAPE escape") + .map("$not($is_null(value))").to("value IS NOT NULL") + .map("$not(value: boolean)").to("NOT value") + .map("$is_null(value)").to("value IS NULL") .build(); JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index ebfb2352decd..284000e5b4a7 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -32,8 +32,11 @@ import io.trino.sql.planner.TypeProvider; import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.Expression; +import io.trino.sql.tree.IsNotNullPredicate; +import io.trino.sql.tree.IsNullPredicate; import io.trino.sql.tree.LikePredicate; import io.trino.sql.tree.LogicalExpression; +import io.trino.sql.tree.NotExpression; import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.SymbolReference; import org.testng.annotations.DataProvider; @@ -276,6 +279,46 @@ public void testConvertLike() .hasValue("(\"c_varchar\") LIKE ('%pattern\\%') ESCAPE ('\\')"); } + @Test + public void testConvertIsNull() + { + // c_varchar IS NULL + assertThat(JDBC_CLIENT.convertPredicate(SESSION, + translateToConnectorExpression( + new IsNullPredicate( + new SymbolReference("c_varchar_symbol")), + Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())), + Map.of("c_varchar_symbol", VARCHAR_COLUMN))) + .hasValue("(\"c_varchar\") IS NULL"); + } + + @Test + public void testConvertIsNotNull() + { + // c_varchar IS NOT NULL + assertThat(JDBC_CLIENT.convertPredicate(SESSION, + translateToConnectorExpression( + new IsNotNullPredicate( + new SymbolReference("c_varchar_symbol")), + Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())), + Map.of("c_varchar_symbol", VARCHAR_COLUMN))) + .hasValue("(\"c_varchar\") IS NOT NULL"); + } + + @Test + public void testConvertNotExpression() + { + // NOT(expression) + assertThat(JDBC_CLIENT.convertPredicate(SESSION, + translateToConnectorExpression( + new NotExpression( + new IsNotNullPredicate( + new SymbolReference("c_varchar_symbol"))), + Map.of("c_varchar_symbol", VARCHAR_COLUMN.getColumnType())), + Map.of("c_varchar_symbol", VARCHAR_COLUMN))) + .hasValue("NOT ((\"c_varchar\") IS NOT NULL)"); + } + private ConnectorExpression translateToConnectorExpression(Expression expression, Map symbolTypes) { return ConnectorExpressionTranslator.translate( diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index 5e3f9a737cfd..8b65c1da3969 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -701,7 +701,7 @@ public void testOrPredicatePushdown() assertThat(query("SELECT * FROM nation WHERE nationkey != 3 OR regionkey = 4")).isFullyPushedDown(); assertThat(query("SELECT * FROM nation WHERE nationkey != 3 OR regionkey != 4")).isFullyPushedDown(); assertThat(query("SELECT * FROM nation WHERE name = 'ALGERIA' OR regionkey = 4")).isFullyPushedDown(); - assertThat(query("SELECT * FROM nation WHERE name IS NULL OR regionkey = 4")).isNotFullyPushedDown(FilterNode.class); // TODO `name IS NULL` is not pushed down + assertThat(query("SELECT * FROM nation WHERE name IS NULL OR regionkey = 4")).isFullyPushedDown(); assertThat(query("SELECT * FROM nation WHERE name = NULL OR regionkey = 4")).isNotFullyPushedDown(FilterNode.class); // TODO `name = NULL` should be eliminated by the engine } @@ -750,6 +750,62 @@ public void testLikeWithEscapePredicatePushdown() } } + @Test + public void testIsNullPredicatePushdown() + { + assertThat(query("SELECT nationkey FROM nation WHERE name IS NULL")).isFullyPushedDown(); + assertThat(query("SELECT nationkey FROM nation WHERE name IS NULL OR regionkey = 4")).isFullyPushedDown(); + + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_is_null_predicate_pushdown", + "(a_int integer, a_varchar varchar(1))", + List.of( + "1, 'A'", + "2, 'B'", + "1, NULL", + "2, NULL"))) { + assertThat(query("SELECT a_int FROM " + table.getName() + " WHERE a_varchar IS NULL OR a_int = 1")).isFullyPushedDown(); + } + } + + @Test + public void testIsNotNullPredicatePushdown() + { + assertThat(query("SELECT nationkey FROM nation WHERE name IS NOT NULL OR regionkey = 4")).isFullyPushedDown(); + + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_is_not_null_predicate_pushdown", + "(a_int integer, a_varchar varchar(1))", + List.of( + "1, 'A'", + "2, 'B'", + "1, NULL", + "2, NULL"))) { + assertThat(query("SELECT a_int FROM " + table.getName() + " WHERE a_varchar IS NOT NULL OR a_int = 1")).isFullyPushedDown(); + } + } + + @Test + public void testNotExpressionPushdown() + { + assertThat(query("SELECT nationkey FROM nation WHERE NOT(name LIKE '%A%' ESCAPE '\\')")).isFullyPushedDown(); + + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_is_not_predicate_pushdown", + "(a_int integer, a_varchar varchar(2))", + List.of( + "1, 'Aa'", + "2, 'Bb'", + "1, NULL", + "2, NULL"))) { + assertThat(query("SELECT a_int FROM " + table.getName() + " WHERE NOT(a_varchar LIKE 'A%') OR a_int = 2")).isFullyPushedDown(); + assertThat(query("SELECT a_int FROM " + table.getName() + " WHERE NOT(a_varchar LIKE 'A%' OR a_int = 2)")).isFullyPushedDown(); + } + } + @Override protected String errorMessageForInsertIntoNotNullColumn(String columnName) {