From fb190ef58775252ec9065e52c7b56760719529e1 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Fri, 10 Mar 2023 13:22:54 +0100 Subject: [PATCH 1/2] Remove unnecessary simplification step This step is not needed anymore as corresponding tests do not fail. --- .../planner/optimizations/PredicatePushDown.java | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java index e4bb453baf72..2321e456fa1e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java @@ -1008,16 +1008,10 @@ private InnerJoinPushDownResult processInnerJoin( ImmutableSet leftScope = ImmutableSet.copyOf(leftSymbols); ImmutableSet rightScope = ImmutableSet.copyOf(rightSymbols); - // Attempt to simplify the effective left/right predicates with the predicate we're pushing down - // This, effectively, inlines any constants derived from such predicate - EqualityInference predicateInference = new EqualityInference(metadata, inheritedPredicate); - Expression simplifiedLeftEffectivePredicate = predicateInference.rewrite(leftEffectivePredicate, leftScope); - Expression simplifiedRightEffectivePredicate = predicateInference.rewrite(rightEffectivePredicate, rightScope); - // Generate equality inferences - EqualityInference allInference = new EqualityInference(metadata, inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate, simplifiedLeftEffectivePredicate, simplifiedRightEffectivePredicate); - EqualityInference allInferenceWithoutLeftInferred = new EqualityInference(metadata, inheritedPredicate, rightEffectivePredicate, joinPredicate, simplifiedRightEffectivePredicate); - EqualityInference allInferenceWithoutRightInferred = new EqualityInference(metadata, inheritedPredicate, leftEffectivePredicate, joinPredicate, simplifiedLeftEffectivePredicate); + EqualityInference allInference = new EqualityInference(metadata, inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate); + EqualityInference allInferenceWithoutLeftInferred = new EqualityInference(metadata, inheritedPredicate, rightEffectivePredicate, joinPredicate); + EqualityInference allInferenceWithoutRightInferred = new EqualityInference(metadata, inheritedPredicate, leftEffectivePredicate, joinPredicate); // Add equalities from the inference back in leftPushDownConjuncts.addAll(allInferenceWithoutLeftInferred.generateEqualitiesPartitionedBy(leftScope).getScopeEqualities()); @@ -1043,13 +1037,13 @@ private InnerJoinPushDownResult processInnerJoin( }); // See if we can push the right effective predicate to the left side - EqualityInference.nonInferrableConjuncts(metadata, simplifiedRightEffectivePredicate) + EqualityInference.nonInferrableConjuncts(metadata, rightEffectivePredicate) .map(conjunct -> allInference.rewrite(conjunct, leftScope)) .filter(Objects::nonNull) .forEach(leftPushDownConjuncts::add); // See if we can push the left effective predicate to the right side - EqualityInference.nonInferrableConjuncts(metadata, simplifiedLeftEffectivePredicate) + EqualityInference.nonInferrableConjuncts(metadata, leftEffectivePredicate) .map(conjunct -> allInference.rewrite(conjunct, rightScope)) .filter(Objects::nonNull) .forEach(rightPushDownConjuncts::add); From d3da58971878403cae57fd32b9d6caa2f4c4fda4 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Fri, 10 Mar 2023 13:05:13 +0100 Subject: [PATCH 2/2] Do not generate redundant straddling predicates by equality inference When there are predicates like a1 = b1 and a2 = a1 + 1, then equality inference would derive staddling predicate for a2 = b1 + 1, which is redundant to a1 = b1, a2 = a1 + 1. This commit makes sure that redundant straddling predicates are not generated. --- .../trino/sql/planner/EqualityInference.java | 40 ++++++++---- .../AbstractPredicatePushdownTest.java | 62 +++++++++++++++++++ .../sql/planner/TestEqualityInference.java | 23 +++++++ ...PredicatePushdownWithoutDynamicFilter.java | 29 +++++++++ 4 files changed, 142 insertions(+), 12 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java b/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java index a9cfc0289d35..ecbdcf0bd599 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/EqualityInference.java @@ -34,11 +34,10 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; -import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.function.Predicate; import java.util.function.ToIntFunction; -import java.util.stream.Collectors; import java.util.stream.Stream; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -215,17 +214,34 @@ public EqualityPartition generateEqualitiesPartitionedBy(Set scope) .forEach(scopeComplementEqualities::add); } - // Compile the scope straddling equality expressions - List connectingExpressions = new ArrayList<>(); - connectingExpressions.add(matchingCanonical); - connectingExpressions.add(complementCanonical); - connectingExpressions.addAll(scopeStraddlingExpressions); - connectingExpressions = connectingExpressions.stream() - .filter(Objects::nonNull) - .collect(Collectors.toList()); - Expression connectingCanonical = getCanonical(connectingExpressions.stream()); + // Compile single equality between matching and complement scope. + // Only consider expressions that don't have derived expression in other scope. + // Otherwise, redundant equality would be generated. + Optional matchingConnecting = scopeExpressions.stream() + .filter(expression -> SymbolsExtractor.extractAll(expression).isEmpty() || rewrite(expression, symbol -> !scope.contains(symbol), false) == null) + .min(canonicalComparator); + Optional complementConnecting = scopeComplementExpressions.stream() + .filter(expression -> SymbolsExtractor.extractAll(expression).isEmpty() || rewrite(expression, scope::contains, false) == null) + .min(canonicalComparator); + if (matchingConnecting.isPresent() && complementConnecting.isPresent() && !matchingConnecting.equals(complementConnecting)) { + scopeStraddlingEqualities.add(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, matchingConnecting.get(), complementConnecting.get())); + } + + // Compile the scope straddling equality expressions. + // scopeStraddlingExpressions couldn't be pushed to either side, + // therefore there needs to be an equality generated with + // one of the scopes (either matching or complement). + List straddlingExpressions = new ArrayList<>(); + if (matchingCanonical != null) { + straddlingExpressions.add(matchingCanonical); + } + else if (complementCanonical != null) { + straddlingExpressions.add(complementCanonical); + } + straddlingExpressions.addAll(scopeStraddlingExpressions); + Expression connectingCanonical = getCanonical(straddlingExpressions.stream()); if (connectingCanonical != null) { - connectingExpressions.stream() + straddlingExpressions.stream() .filter(expression -> !expression.equals(connectingCanonical)) .map(expression -> new ComparisonExpression(ComparisonExpression.Operator.EQUAL, connectingCanonical, expression)) .forEach(scopeStraddlingEqualities::add); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java b/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java index a1803ba70ec2..5fb02b62ae8f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/AbstractPredicatePushdownTest.java @@ -13,6 +13,7 @@ */ package io.trino.sql.planner; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.Session; import io.trino.sql.planner.assertions.BasePlanTest; @@ -39,6 +40,7 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.semiJoin; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; +import static io.trino.sql.planner.plan.JoinNode.Type.INNER; import static io.trino.sql.planner.plan.JoinNode.Type.LEFT; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -465,6 +467,66 @@ WITH t(a) AS (VALUES 'a', 'b') output(values("field", "field_0"))); } + @Test + public void testSimplifyNonInferrableInheritedPredicate() + { + assertPlan("SELECT * FROM (SELECT * FROM nation WHERE nationkey = regionkey AND regionkey = 5) a, nation b WHERE a.nationkey = b.nationkey AND a.nationkey + 11 > 15", + output( + join(INNER, builder -> builder + .equiCriteria(ImmutableList.of()) + .left( + filter("((L_NATIONKEY = L_REGIONKEY) AND (L_REGIONKEY = BIGINT '5'))", + tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey", "L_REGIONKEY", "regionkey")))) + .right( + anyTree( + filter("R_NATIONKEY = BIGINT '5'", + tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey")))))))); + } + + @Test + public void testDoesNotCreatePredicateFromInferredPredicate() + { + assertPlan("SELECT * FROM (SELECT *, nationkey + 1 as nationkey2 FROM nation) a JOIN nation b ON a.nationkey = b.nationkey", + output( + join(INNER, builder -> builder + .equiCriteria("L_NATIONKEY", "R_NATIONKEY") + .left( + filter("true", // DF filter + tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey")))) + .right( + anyTree( + tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey"))))))); + + assertPlan("SELECT * FROM (SELECT * FROM nation WHERE nationkey = 5) a JOIN (SELECT * FROM nation WHERE nationkey = 5) b ON a.nationkey = b.nationkey", + output( + join(INNER, builder -> builder + .equiCriteria(ImmutableList.of()) + .left( + filter("L_NATIONKEY = BIGINT '5'", + tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey")))) + .right( + anyTree( + filter("R_NATIONKEY = BIGINT '5'", + tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey")))))))); + } + + @Test + public void testSimplifiesStraddlingPredicate() + { + assertPlan("SELECT * FROM (SELECT * FROM NATION WHERE nationkey = 5) a JOIN nation b ON a.nationkey = b.nationkey AND a.nationkey = a.regionkey + b.regionkey", + output( + filter("L_REGIONKEY + R_REGIONKEY = BIGINT '5'", + join(INNER, builder -> builder + .equiCriteria(ImmutableList.of()) + .left( + filter("L_NATIONKEY = BIGINT '5'", + tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey", "L_REGIONKEY", "regionkey")))) + .right( + anyTree( + filter("R_NATIONKEY = BIGINT '5'", + tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey", "R_REGIONKEY", "regionkey"))))))))); + } + protected Session noSemiJoinRewrite() { return Session.builder(getQueryRunner().getDefaultSession()) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java index b3174c7a48a7..2c03203834d5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEqualityInference.java @@ -65,6 +65,29 @@ public class TestEqualityInference private final TestingFunctionResolution functionResolution = new TestingFunctionResolution(); private final Metadata metadata = functionResolution.getMetadata(); + @Test + public void testDoesNotInferRedundantStraddlingPredicates() + { + EqualityInference inference = new EqualityInference( + metadata, + equals("a1", "b1"), + equals(add(nameReference("a1"), number(1)), number(0)), + equals(nameReference("a2"), add(nameReference("a1"), number(2))), + equals(nameReference("a1"), add("a3", "b3")), + equals(nameReference("b2"), add("a4", "b4"))); + EqualityInference.EqualityPartition partition = inference.generateEqualitiesPartitionedBy(symbols("a1", "a2", "a3", "a4")); + assertThat(partition.getScopeEqualities()).containsExactly( + equals(number(0), add(nameReference("a1"), number(1))), + equals(nameReference("a2"), add(nameReference("a1"), number(2)))); + assertThat(partition.getScopeComplementEqualities()).containsExactly( + equals(number(0), add(nameReference("b1"), number(1)))); + // there shouldn't be equality a2 = b1 + 1 as it can be derived from a2 = a1 + 1, a1 = b1 + assertThat(partition.getScopeStraddlingEqualities()).containsExactly( + equals("a1", "b1"), + equals(nameReference("a1"), add("a3", "b3")), + equals(nameReference("b2"), add("a4", "b4"))); + } + @Test public void testTransitivity() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java index 1c4e627e4742..92fa340168ef 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPredicatePushdownWithoutDynamicFilter.java @@ -13,6 +13,7 @@ */ package io.trino.sql.planner; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.Session; import io.trino.sql.planner.plan.ExchangeNode; @@ -23,6 +24,7 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; +import static io.trino.sql.planner.assertions.PlanMatchPattern.output; import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.semiJoin; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; @@ -173,4 +175,31 @@ public void testNonStraddlingJoinExpression() anyTree( tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey"))))))); } + + @Override + @Test + public void testDoesNotCreatePredicateFromInferredPredicate() + { + assertPlan("SELECT * FROM (SELECT *, nationkey + 1 as nationkey2 FROM nation) a JOIN nation b ON a.nationkey = b.nationkey", + output( + join(INNER, builder -> builder + .equiCriteria("L_NATIONKEY", "R_NATIONKEY") + .left( + tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey"))) + .right( + anyTree( + tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey"))))))); + + assertPlan("SELECT * FROM (SELECT * FROM nation WHERE nationkey = 5) a JOIN (SELECT * FROM nation WHERE nationkey = 5) b ON a.nationkey = b.nationkey", + output( + join(INNER, builder -> builder + .equiCriteria(ImmutableList.of()) + .left( + filter("L_NATIONKEY = BIGINT '5'", + tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey")))) + .right( + anyTree( + filter("R_NATIONKEY = BIGINT '5'", + tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey")))))))); + } }