Skip to content

Commit

Permalink
Do not generate redundant straddling predicates by equality inference
Browse files Browse the repository at this point in the history
When there are predicates like a1 = b1 and a2 = a1 + 1,
equality inference would derive the straddling predicate
a2 = b1 + 1, which is redundant given a1 = b1 and a2 = a1 + 1.
This commit makes sure that such redundant straddling predicates
are not generated.
  • Loading branch information
sopel39 committed Dec 22, 2023
1 parent fb190ef commit d3da589
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,10 @@
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.google.common.collect.ImmutableList.toImmutableList;
Expand Down Expand Up @@ -215,17 +214,34 @@ public EqualityPartition generateEqualitiesPartitionedBy(Set<Symbol> scope)
.forEach(scopeComplementEqualities::add);
}

// Compile the scope straddling equality expressions
List<Expression> connectingExpressions = new ArrayList<>();
connectingExpressions.add(matchingCanonical);
connectingExpressions.add(complementCanonical);
connectingExpressions.addAll(scopeStraddlingExpressions);
connectingExpressions = connectingExpressions.stream()
.filter(Objects::nonNull)
.collect(Collectors.toList());
Expression connectingCanonical = getCanonical(connectingExpressions.stream());
// Compile single equality between matching and complement scope.
// Only consider expressions that don't have derived expression in other scope.
// Otherwise, redundant equality would be generated.
Optional<Expression> matchingConnecting = scopeExpressions.stream()
.filter(expression -> SymbolsExtractor.extractAll(expression).isEmpty() || rewrite(expression, symbol -> !scope.contains(symbol), false) == null)
.min(canonicalComparator);
Optional<Expression> complementConnecting = scopeComplementExpressions.stream()
.filter(expression -> SymbolsExtractor.extractAll(expression).isEmpty() || rewrite(expression, scope::contains, false) == null)
.min(canonicalComparator);
if (matchingConnecting.isPresent() && complementConnecting.isPresent() && !matchingConnecting.equals(complementConnecting)) {
scopeStraddlingEqualities.add(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, matchingConnecting.get(), complementConnecting.get()));
}

// Compile the scope straddling equality expressions.
// scopeStraddlingExpressions couldn't be pushed to either side,
// therefore there needs to be an equality generated with
// one of the scopes (either matching or complement).
List<Expression> straddlingExpressions = new ArrayList<>();
if (matchingCanonical != null) {
straddlingExpressions.add(matchingCanonical);
}
else if (complementCanonical != null) {
straddlingExpressions.add(complementCanonical);
}
straddlingExpressions.addAll(scopeStraddlingExpressions);
Expression connectingCanonical = getCanonical(straddlingExpressions.stream());
if (connectingCanonical != null) {
connectingExpressions.stream()
straddlingExpressions.stream()
.filter(expression -> !expression.equals(connectingCanonical))
.map(expression -> new ComparisonExpression(ComparisonExpression.Operator.EQUAL, connectingCanonical, expression))
.forEach(scopeStraddlingEqualities::add);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.sql.planner.assertions.BasePlanTest;
Expand All @@ -39,6 +40,7 @@
import static io.trino.sql.planner.assertions.PlanMatchPattern.semiJoin;
import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan;
import static io.trino.sql.planner.assertions.PlanMatchPattern.values;
import static io.trino.sql.planner.plan.JoinNode.Type.INNER;
import static io.trino.sql.planner.plan.JoinNode.Type.LEFT;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;

Expand Down Expand Up @@ -465,6 +467,66 @@ WITH t(a) AS (VALUES 'a', 'b')
output(values("field", "field_0")));
}

/**
 * Plans a join whose left input already carries
 * {@code nationkey = regionkey AND regionkey = 5} and checks the resulting plan shape.
 * <p>
 * The expected pattern is an INNER join without equi-join criteria, the original
 * two-conjunct filter above the left scan, and a {@code nationkey = 5} filter above
 * the right scan. No node for {@code a.nationkey + 11 > 15} appears in the pattern
 * (presumably it is simplified away because nationkey is pinned to 5 -- TODO confirm).
 */
@Test
public void testSimplifyNonInferrableInheritedPredicate()
{
assertPlan("SELECT * FROM (SELECT * FROM nation WHERE nationkey = regionkey AND regionkey = 5) a, nation b WHERE a.nationkey = b.nationkey AND a.nationkey + 11 > 15",
output(
join(INNER, builder -> builder
// no equi-join clauses are expected between the two sides
.equiCriteria(ImmutableList.of())
.left(
// left side keeps the subquery's own predicate
filter("((L_NATIONKEY = L_REGIONKEY) AND (L_REGIONKEY = BIGINT '5'))",
tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey", "L_REGIONKEY", "regionkey"))))
.right(
anyTree(
// right side is expected to be filtered by the constant as well
filter("R_NATIONKEY = BIGINT '5'",
tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey"))))))));
}

/**
 * Checks two join plans for the absence of extra predicates.
 * <p>
 * First query: the left input projects {@code nationkey + 1 AS nationkey2}; the expected
 * plan is an INNER join on {@code L_NATIONKEY = R_NATIONKEY} with only a {@code true}
 * filter on the left and a plain table scan on the right, i.e. no filter mentioning
 * the projected expression.
 * <p>
 * Second query: both inputs are filtered by {@code nationkey = 5}; the expected plan is
 * an INNER join without equi-join criteria and with a {@code nationkey = 5} filter on
 * each side.
 */
@Test
public void testDoesNotCreatePredicateFromInferredPredicate()
{
assertPlan("SELECT * FROM (SELECT *, nationkey + 1 as nationkey2 FROM nation) a JOIN nation b ON a.nationkey = b.nationkey",
output(
join(INNER, builder -> builder
.equiCriteria("L_NATIONKEY", "R_NATIONKEY")
.left(
filter("true", // DF filter
tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey"))))
.right(
anyTree(
// no filter node is expected above the right scan
tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey")))))));

assertPlan("SELECT * FROM (SELECT * FROM nation WHERE nationkey = 5) a JOIN (SELECT * FROM nation WHERE nationkey = 5) b ON a.nationkey = b.nationkey",
output(
join(INNER, builder -> builder
// no equi-join clauses are expected between the two sides
.equiCriteria(ImmutableList.of())
.left(
filter("L_NATIONKEY = BIGINT '5'",
tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey"))))
.right(
anyTree(
filter("R_NATIONKEY = BIGINT '5'",
tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey"))))))));
}

/**
 * Plans a join whose ON clause contains {@code a.nationkey = a.regionkey + b.regionkey},
 * a predicate referencing columns of both inputs, while the left input is already
 * filtered by {@code nationkey = 5}.
 * <p>
 * The expected plan has a filter {@code L_REGIONKEY + R_REGIONKEY = 5} above an INNER
 * join without equi-join criteria, and a {@code nationkey = 5} filter on each join input.
 */
@Test
public void testSimplifiesStraddlingPredicate()
{
assertPlan("SELECT * FROM (SELECT * FROM NATION WHERE nationkey = 5) a JOIN nation b ON a.nationkey = b.nationkey AND a.nationkey = a.regionkey + b.regionkey",
output(
// the two-sided predicate is expected above the join, expressed against the constant
filter("L_REGIONKEY + R_REGIONKEY = BIGINT '5'",
join(INNER, builder -> builder
// no equi-join clauses are expected between the two sides
.equiCriteria(ImmutableList.of())
.left(
filter("L_NATIONKEY = BIGINT '5'",
tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey", "L_REGIONKEY", "regionkey"))))
.right(
anyTree(
filter("R_NATIONKEY = BIGINT '5'",
tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey", "R_REGIONKEY", "regionkey")))))))));
}

protected Session noSemiJoinRewrite()
{
return Session.builder(getQueryRunner().getDefaultSession())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,29 @@ public class TestEqualityInference
private final TestingFunctionResolution functionResolution = new TestingFunctionResolution();
private final Metadata metadata = functionResolution.getMetadata();

/**
 * Builds an {@code EqualityInference} from five equalities, partitions the derived
 * equalities by the scope {@code {a1, a2, a3, a4}}, and asserts the exact contents of
 * the scope, scope-complement and scope-straddling partitions.
 */
@Test
public void testDoesNotInferRedundantStraddlingPredicates()
{
EqualityInference inference = new EqualityInference(
metadata,
equals("a1", "b1"),
equals(add(nameReference("a1"), number(1)), number(0)),
equals(nameReference("a2"), add(nameReference("a1"), number(2))),
equals(nameReference("a1"), add("a3", "b3")),
equals(nameReference("b2"), add("a4", "b4")));
EqualityInference.EqualityPartition partition = inference.generateEqualitiesPartitionedBy(symbols("a1", "a2", "a3", "a4"));
// equalities that reference only in-scope symbols (a1..a4)
assertThat(partition.getScopeEqualities()).containsExactly(
equals(number(0), add(nameReference("a1"), number(1))),
equals(nameReference("a2"), add(nameReference("a1"), number(2))));
// equalities that reference only out-of-scope symbols (b*)
assertThat(partition.getScopeComplementEqualities()).containsExactly(
equals(number(0), add(nameReference("b1"), number(1))));
// there shouldn't be equality a2 = b1 + 2 as it can be derived from a2 = a1 + 2, a1 = b1
assertThat(partition.getScopeStraddlingEqualities()).containsExactly(
equals("a1", "b1"),
equals(nameReference("a1"), add("a3", "b3")),
equals(nameReference("b2"), add("a4", "b4")));
}

@Test
public void testTransitivity()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.sql.planner.plan.ExchangeNode;
Expand All @@ -23,6 +24,7 @@
import static io.trino.sql.planner.assertions.PlanMatchPattern.filter;
import static io.trino.sql.planner.assertions.PlanMatchPattern.join;
import static io.trino.sql.planner.assertions.PlanMatchPattern.node;
import static io.trino.sql.planner.assertions.PlanMatchPattern.output;
import static io.trino.sql.planner.assertions.PlanMatchPattern.project;
import static io.trino.sql.planner.assertions.PlanMatchPattern.semiJoin;
import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan;
Expand Down Expand Up @@ -173,4 +175,31 @@ public void testNonStraddlingJoinExpression()
anyTree(
tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")))))));
}

/**
 * Overrides the inherited test with the plan shapes expected in this test class.
 * <p>
 * First query: the left input projects {@code nationkey + 1 AS nationkey2}; the expected
 * plan is an INNER join on {@code L_NATIONKEY = R_NATIONKEY} whose left side is a bare
 * table scan with no filter node above it
 * (NOTE(review): presumably because this class plans without dynamic filtering -- confirm).
 * <p>
 * Second query: both inputs are filtered by {@code nationkey = 5}; the expected plan is
 * an INNER join without equi-join criteria and with a {@code nationkey = 5} filter on
 * each side.
 */
@Override
@Test
public void testDoesNotCreatePredicateFromInferredPredicate()
{
assertPlan("SELECT * FROM (SELECT *, nationkey + 1 as nationkey2 FROM nation) a JOIN nation b ON a.nationkey = b.nationkey",
output(
join(INNER, builder -> builder
.equiCriteria("L_NATIONKEY", "R_NATIONKEY")
.left(
// left side is matched directly as a table scan, without any filter
tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey")))
.right(
anyTree(
tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey")))))));

assertPlan("SELECT * FROM (SELECT * FROM nation WHERE nationkey = 5) a JOIN (SELECT * FROM nation WHERE nationkey = 5) b ON a.nationkey = b.nationkey",
output(
join(INNER, builder -> builder
// no equi-join clauses are expected between the two sides
.equiCriteria(ImmutableList.of())
.left(
filter("L_NATIONKEY = BIGINT '5'",
tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey"))))
.right(
anyTree(
filter("R_NATIONKEY = BIGINT '5'",
tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey"))))))));
}
}

0 comments on commit d3da589

Please sign in to comment.