diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java b/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java index a9cfc0289d35..ecbdcf0bd599 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java @@ -34,11 +34,10 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; -import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.function.Predicate; import java.util.function.ToIntFunction; -import java.util.stream.Collectors; import java.util.stream.Stream; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -215,17 +214,34 @@ public EqualityPartition generateEqualitiesPartitionedBy(Set scope) .forEach(scopeComplementEqualities::add); } - // Compile the scope straddling equality expressions - List connectingExpressions = new ArrayList<>(); - connectingExpressions.add(matchingCanonical); - connectingExpressions.add(complementCanonical); - connectingExpressions.addAll(scopeStraddlingExpressions); - connectingExpressions = connectingExpressions.stream() - .filter(Objects::nonNull) - .collect(Collectors.toList()); - Expression connectingCanonical = getCanonical(connectingExpressions.stream()); + // Compile single equality between matching and complement scope. + // Only consider expressions that don't have derived expression in other scope. + // Otherwise, redundant equality would be generated. + Optional matchingConnecting = scopeExpressions.stream() + .filter(expression -> SymbolsExtractor.extractAll(expression).isEmpty() || rewrite(expression, symbol -> !scope.contains(symbol), false) == null) + .min(canonicalComparator); + Optional complementConnecting = scopeComplementExpressions.stream() + .filter(expression -> SymbolsExtractor.extractAll(expression).isEmpty() || rewrite(expression, scope::contains, false) == null) + .min(canonicalComparator); + if (matchingConnecting.isPresent() && complementConnecting.isPresent() && !matchingConnecting.equals(complementConnecting)) { + scopeStraddlingEqualities.add(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, matchingConnecting.get(), complementConnecting.get())); + } + + // Compile the scope straddling equality expressions. + // scopeStraddlingExpressions couldn't be pushed to either side, + // therefore there needs to be an equality generated with + // one of the scopes (either matching or complement). + List straddlingExpressions = new ArrayList<>(); + if (matchingCanonical != null) { + straddlingExpressions.add(matchingCanonical); + } + else if (complementCanonical != null) { + straddlingExpressions.add(complementCanonical); + } + straddlingExpressions.addAll(scopeStraddlingExpressions); + Expression connectingCanonical = getCanonical(straddlingExpressions.stream()); if (connectingCanonical != null) { - connectingExpressions.stream() + straddlingExpressions.stream() .filter(expression -> !expression.equals(connectingCanonical)) .map(expression -> new ComparisonExpression(ComparisonExpression.Operator.EQUAL, connectingCanonical, expression)) .forEach(scopeStraddlingEqualities::add); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java b/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java index a1803ba70ec2..5fb02b62ae8f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java @@ -13,6 +13,7 @@ */ package io.trino.sql.planner; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.Session; import io.trino.sql.planner.assertions.BasePlanTest; @@ -39,6 +40,7 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.semiJoin; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; +import static io.trino.sql.planner.plan.JoinNode.Type.INNER; import static io.trino.sql.planner.plan.JoinNode.Type.LEFT; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -465,6 +467,66 @@ WITH t(a) AS (VALUES 'a', 'b') output(values("field", "field_0"))); } + @Test + public void testSimplifyNonInferrableInheritedPredicate() + { + assertPlan("SELECT * FROM (SELECT * FROM nation WHERE nationkey = regionkey AND regionkey = 5) a, nation b WHERE a.nationkey = b.nationkey AND a.nationkey + 11 > 15", + output( + join(INNER, builder -> builder + .equiCriteria(ImmutableList.of()) + .left( + filter("((L_NATIONKEY = L_REGIONKEY) AND (L_REGIONKEY = BIGINT '5'))", + tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey", "L_REGIONKEY", "regionkey")))) + .right( + anyTree( + filter("R_NATIONKEY = BIGINT '5'", + tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey")))))))); + } + + @Test + public void testDoesNotCreatePredicateFromInferredPredicate() + { + assertPlan("SELECT * FROM (SELECT *, nationkey + 1 as nationkey2 FROM nation) a JOIN nation b ON a.nationkey = b.nationkey", + output( + join(INNER, builder -> builder + .equiCriteria("L_NATIONKEY", "R_NATIONKEY") + .left( + filter("true", // DF filter + tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey")))) + .right( + anyTree( + tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey"))))))); + + assertPlan("SELECT * FROM (SELECT * FROM nation WHERE nationkey = 5) a JOIN (SELECT * FROM nation WHERE nationkey = 5) b ON a.nationkey = b.nationkey", + output( + join(INNER, builder -> builder + .equiCriteria(ImmutableList.of()) + .left( + filter("L_NATIONKEY = BIGINT '5'", + tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey")))) + .right( + anyTree( + filter("R_NATIONKEY = BIGINT '5'", + tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey")))))))); + } + + @Test + public void testSimplifiesStraddlingPredicate() + { + assertPlan("SELECT * FROM (SELECT * FROM NATION WHERE nationkey = 5) a JOIN nation b ON a.nationkey = b.nationkey AND a.nationkey = a.regionkey + b.regionkey", + output( + filter("L_REGIONKEY + R_REGIONKEY = BIGINT '5'", + join(INNER, builder -> builder + .equiCriteria(ImmutableList.of()) + .left( + filter("L_NATIONKEY = BIGINT '5'", + tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey", "L_REGIONKEY", "regionkey")))) + .right( + anyTree( + filter("R_NATIONKEY = BIGINT '5'", + tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey", "R_REGIONKEY", "regionkey"))))))))); + } + protected Session noSemiJoinRewrite() { return Session.builder(getQueryRunner().getDefaultSession()) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java index b3174c7a48a7..2c03203834d5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java @@ -65,6 +65,29 @@ public class TestEqualityInference private final TestingFunctionResolution functionResolution = new TestingFunctionResolution(); private final Metadata metadata = functionResolution.getMetadata(); + @Test + public void testDoesNotInferRedundantStraddlingPredicates() + { + EqualityInference inference = new EqualityInference( + metadata, + equals("a1", "b1"), + equals(add(nameReference("a1"), number(1)), number(0)), + equals(nameReference("a2"), add(nameReference("a1"), number(2))), + equals(nameReference("a1"), add("a3", "b3")), + equals(nameReference("b2"), add("a4", "b4"))); + EqualityInference.EqualityPartition partition = inference.generateEqualitiesPartitionedBy(symbols("a1", "a2", "a3", "a4")); + assertThat(partition.getScopeEqualities()).containsExactly( + equals(number(0), add(nameReference("a1"), number(1))), + equals(nameReference("a2"), add(nameReference("a1"), number(2)))); + assertThat(partition.getScopeComplementEqualities()).containsExactly( + equals(number(0), add(nameReference("b1"), number(1)))); + // there shouldn't be equality a2 = b1 + 1 as it can be derived from a2 = a1 + 1, a1 = b1 + assertThat(partition.getScopeStraddlingEqualities()).containsExactly( + equals("a1", "b1"), + equals(nameReference("a1"), add("a3", "b3")), + equals(nameReference("b2"), add("a4", "b4"))); + } + @Test public void testTransitivity() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java index 1c4e627e4742..92fa340168ef 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java @@ -13,6 +13,7 @@ */ package io.trino.sql.planner; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.Session; import io.trino.sql.planner.plan.ExchangeNode; @@ -23,6 +24,7 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; +import static io.trino.sql.planner.assertions.PlanMatchPattern.output; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.semiJoin; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; @@ -173,4 +175,31 @@ public void testNonStraddlingJoinExpression() anyTree( tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey"))))))); } + + @Override + @Test + public void testDoesNotCreatePredicateFromInferredPredicate() + { + assertPlan("SELECT * FROM (SELECT *, nationkey + 1 as nationkey2 FROM nation) a JOIN nation b ON a.nationkey = b.nationkey", + output( + join(INNER, builder -> builder + .equiCriteria("L_NATIONKEY", "R_NATIONKEY") + .left( + tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey"))) + .right( + anyTree( + tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey"))))))); + + assertPlan("SELECT * FROM (SELECT * FROM nation WHERE nationkey = 5) a JOIN (SELECT * FROM nation WHERE nationkey = 5) b ON a.nationkey = b.nationkey", + output( + join(INNER, builder -> builder + .equiCriteria(ImmutableList.of()) + .left( + filter("L_NATIONKEY = BIGINT '5'", + tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey")))) + .right( + anyTree( + filter("R_NATIONKEY = BIGINT '5'", + tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey")))))))); + } }