Do not generate redundant straddling predicates by equality inference #16520

Merged
merged 2 commits into from
Dec 22, 2023
@@ -34,11 +34,10 @@
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.google.common.collect.ImmutableList.toImmutableList;
@@ -215,17 +214,34 @@ public EqualityPartition generateEqualitiesPartitionedBy(Set<Symbol> scope)
.forEach(scopeComplementEqualities::add);
}

// Compile the scope straddling equality expressions
List<Expression> connectingExpressions = new ArrayList<>();
connectingExpressions.add(matchingCanonical);
connectingExpressions.add(complementCanonical);
connectingExpressions.addAll(scopeStraddlingExpressions);
connectingExpressions = connectingExpressions.stream()
.filter(Objects::nonNull)
.collect(Collectors.toList());
Expression connectingCanonical = getCanonical(connectingExpressions.stream());
// Compile single equality between matching and complement scope.
// Only consider expressions that don't have derived expression in other scope.
// Otherwise, redundant equality would be generated.
Optional<Expression> matchingConnecting = scopeExpressions.stream()
.filter(expression -> SymbolsExtractor.extractAll(expression).isEmpty() || rewrite(expression, symbol -> !scope.contains(symbol), false) == null)
.min(canonicalComparator);
Optional<Expression> complementConnecting = scopeComplementExpressions.stream()
.filter(expression -> SymbolsExtractor.extractAll(expression).isEmpty() || rewrite(expression, scope::contains, false) == null)
.min(canonicalComparator);
if (matchingConnecting.isPresent() && complementConnecting.isPresent() && !matchingConnecting.equals(complementConnecting)) {
scopeStraddlingEqualities.add(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, matchingConnecting.get(), complementConnecting.get()));
}

// Compile the scope straddling equality expressions.
// scopeStraddlingExpressions couldn't be pushed to either side,
// therefore there needs to be an equality generated with
// one of the scopes (either matching or complement).
List<Expression> straddlingExpressions = new ArrayList<>();
if (matchingCanonical != null) {
straddlingExpressions.add(matchingCanonical);
}
else if (complementCanonical != null) {
straddlingExpressions.add(complementCanonical);
}
Comment on lines +235 to +240
Member

What if both are non-null? What are the implications of adding one but not the other?

Member Author

@sopel39 sopel39 Dec 22, 2023

So you have three groups of expressions:

scope
complement
straddling (these are guaranteed to reference symbols from both scope and non-scope)

All of these expressions are equal to each other.
Therefore you need to generate equalities:

  • within scope
  • within complement
  • within straddling
  • between scope<->complement
  • between ONE OF (scope, complement) <-> straddling

Let's say:

scope = (l.x; l.y)
complement  = (r.x; r.y)
straddling = (l.i + r.j; l.m + r.n)

Then you generate:

for scope: l.x = l.y
for complement: r.x = r.y
for straddling: l.i + r.j = l.m + r.n

You also need to generate one canonical equality between scope and complement:

l.x = r.y

But you also need to connect the straddling group with the rest of the expressions somehow. You do that by choosing the canonical expression from EITHER scope or complement and generating an equality with any straddling expression, e.g.:

l.x = l.i + r.j 

This way you generate all the needed equalities, but no more.
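The scheme above can be sketched with plain strings standing in for expressions. This is only an illustration, not the actual EqualityInference code: the class and method names are hypothetical, and the first element of each group is assumed to be its canonical expression (the real code picks it with canonicalComparator). Because of that assumption the scope<->complement equality comes out as l.x = r.x rather than l.x = r.y; any single pair across the two groups connects them equally well.

```java
import java.util.ArrayList;
import java.util.List;

public class EqualityPartitionSketch
{
    // Link every member of a group to the group's first element,
    // which this sketch treats as the canonical expression (an assumption).
    static List<String> within(List<String> group)
    {
        List<String> equalities = new ArrayList<>();
        for (int i = 1; i < group.size(); i++) {
            equalities.add(group.get(0) + " = " + group.get(i));
        }
        return equalities;
    }

    static List<String> generate(List<String> scope, List<String> complement, List<String> straddling)
    {
        List<String> equalities = new ArrayList<>();
        equalities.addAll(within(scope));
        equalities.addAll(within(complement));
        equalities.addAll(within(straddling));

        // Exactly one equality between scope and complement.
        if (!scope.isEmpty() && !complement.isEmpty()) {
            equalities.add(scope.get(0) + " = " + complement.get(0));
        }

        // Exactly one equality tying the straddling group to ONE side:
        // prefer scope, fall back to complement, never use both.
        if (!straddling.isEmpty()) {
            String anchor = null;
            if (!scope.isEmpty()) {
                anchor = scope.get(0);
            }
            else if (!complement.isEmpty()) {
                anchor = complement.get(0);
            }
            if (anchor != null) {
                equalities.add(anchor + " = " + straddling.get(0));
            }
        }
        return equalities;
    }

    public static void main(String[] args)
    {
        generate(
                List.of("l.x", "l.y"),
                List.of("r.x", "r.y"),
                List.of("l.i + r.j", "l.m + r.n"))
                .forEach(System.out::println);
    }
}
```

For the six expressions in the example this yields five equalities, which is the minimum needed to connect six equal expressions; adding the complement canonical as a second anchor for the straddling group would add a sixth, derivable one.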

If you were to choose the canonical expression from both scope and complement, you would generate a redundant equality. This matters especially for constants, e.g.:

equalitySet = [1, l.r, r.y, l.i + r.j]
scope = [1, l.r]
complement = [1, r.y]
straddling = [l.i + r.j]

generated equalities:

1 = l.r; 1 = r.y; 1 = l.i + r.j

If you were to choose the canonical expression from both scope and complement, it would generate the equalities:

1 = l.r; 1 = r.y; 1 = 1; 1 = l.i + r.j

because 1 is canonical on both the scope side and the complement side.
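The same constant hazard applies to the single scope<->complement equality, where the diff guards with !matchingConnecting.equals(complementConnecting). A minimal sketch of that guard, using strings in place of expressions; the class name, method name, and the guard flag are hypothetical and exist only to show the difference between the guarded and unguarded behavior:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

public class ConstantCanonicalSketch
{
    // Emit the scope<->complement equality only when both canonicals exist.
    // With the guard on, identical canonicals (e.g. the constant 1 on both
    // sides) produce nothing instead of a trivial "1 = 1".
    static List<String> connect(Optional<String> matching, Optional<String> complement, boolean guard)
    {
        List<String> equalities = new ArrayList<>();
        boolean bothPresent = matching.isPresent() && complement.isPresent();
        if (bothPresent && (!guard || !matching.equals(complement))) {
            equalities.add(matching.get() + " = " + complement.get());
        }
        return equalities;
    }

    public static void main(String[] args)
    {
        // Guard off: the constant is canonical on both sides, so "1 = 1" appears.
        System.out.println(connect(Optional.of("1"), Optional.of("1"), false));
        // Guard on: the trivial equality is skipped.
        System.out.println(connect(Optional.of("1"), Optional.of("1"), true));
        // Distinct canonicals are still connected.
        System.out.println(connect(Optional.of("l.x"), Optional.of("r.y"), true));
    }
}
```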

straddlingExpressions.addAll(scopeStraddlingExpressions);
Expression connectingCanonical = getCanonical(straddlingExpressions.stream());
if (connectingCanonical != null) {
connectingExpressions.stream()
straddlingExpressions.stream()
.filter(expression -> !expression.equals(connectingCanonical))
.map(expression -> new ComparisonExpression(ComparisonExpression.Operator.EQUAL, connectingCanonical, expression))
.forEach(scopeStraddlingEqualities::add);
@@ -1008,16 +1008,10 @@ private InnerJoinPushDownResult processInnerJoin(
ImmutableSet<Symbol> leftScope = ImmutableSet.copyOf(leftSymbols);
ImmutableSet<Symbol> rightScope = ImmutableSet.copyOf(rightSymbols);

// Attempt to simplify the effective left/right predicates with the predicate we're pushing down
// This, effectively, inlines any constants derived from such predicate
EqualityInference predicateInference = new EqualityInference(metadata, inheritedPredicate);
Expression simplifiedLeftEffectivePredicate = predicateInference.rewrite(leftEffectivePredicate, leftScope);
Expression simplifiedRightEffectivePredicate = predicateInference.rewrite(rightEffectivePredicate, rightScope);

// Generate equality inferences
EqualityInference allInference = new EqualityInference(metadata, inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate, simplifiedLeftEffectivePredicate, simplifiedRightEffectivePredicate);
EqualityInference allInferenceWithoutLeftInferred = new EqualityInference(metadata, inheritedPredicate, rightEffectivePredicate, joinPredicate, simplifiedRightEffectivePredicate);
EqualityInference allInferenceWithoutRightInferred = new EqualityInference(metadata, inheritedPredicate, leftEffectivePredicate, joinPredicate, simplifiedLeftEffectivePredicate);
EqualityInference allInference = new EqualityInference(metadata, inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate);
EqualityInference allInferenceWithoutLeftInferred = new EqualityInference(metadata, inheritedPredicate, rightEffectivePredicate, joinPredicate);
EqualityInference allInferenceWithoutRightInferred = new EqualityInference(metadata, inheritedPredicate, leftEffectivePredicate, joinPredicate);

// Add equalities from the inference back in
leftPushDownConjuncts.addAll(allInferenceWithoutLeftInferred.generateEqualitiesPartitionedBy(leftScope).getScopeEqualities());
@@ -1043,13 +1037,13 @@ private InnerJoinPushDownResult processInnerJoin(
});

// See if we can push the right effective predicate to the left side
EqualityInference.nonInferrableConjuncts(metadata, simplifiedRightEffectivePredicate)
EqualityInference.nonInferrableConjuncts(metadata, rightEffectivePredicate)
.map(conjunct -> allInference.rewrite(conjunct, leftScope))
.filter(Objects::nonNull)
.forEach(leftPushDownConjuncts::add);

// See if we can push the left effective predicate to the right side
EqualityInference.nonInferrableConjuncts(metadata, simplifiedLeftEffectivePredicate)
EqualityInference.nonInferrableConjuncts(metadata, leftEffectivePredicate)
.map(conjunct -> allInference.rewrite(conjunct, rightScope))
.filter(Objects::nonNull)
.forEach(rightPushDownConjuncts::add);
@@ -13,6 +13,7 @@
*/
package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.sql.planner.assertions.BasePlanTest;
@@ -39,6 +40,7 @@
import static io.trino.sql.planner.assertions.PlanMatchPattern.semiJoin;
import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan;
import static io.trino.sql.planner.assertions.PlanMatchPattern.values;
import static io.trino.sql.planner.plan.JoinNode.Type.INNER;
import static io.trino.sql.planner.plan.JoinNode.Type.LEFT;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;

@@ -465,6 +467,66 @@ WITH t(a) AS (VALUES 'a', 'b')
output(values("field", "field_0")));
}

@Test
public void testSimplifyNonInferrableInheritedPredicate()
{
assertPlan("SELECT * FROM (SELECT * FROM nation WHERE nationkey = regionkey AND regionkey = 5) a, nation b WHERE a.nationkey = b.nationkey AND a.nationkey + 11 > 15",
output(
join(INNER, builder -> builder
.equiCriteria(ImmutableList.of())
.left(
filter("((L_NATIONKEY = L_REGIONKEY) AND (L_REGIONKEY = BIGINT '5'))",
tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey", "L_REGIONKEY", "regionkey"))))
.right(
anyTree(
filter("R_NATIONKEY = BIGINT '5'",
tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey"))))))));
}

@Test
public void testDoesNotCreatePredicateFromInferredPredicate()
{
assertPlan("SELECT * FROM (SELECT *, nationkey + 1 as nationkey2 FROM nation) a JOIN nation b ON a.nationkey = b.nationkey",
output(
join(INNER, builder -> builder
.equiCriteria("L_NATIONKEY", "R_NATIONKEY")
.left(
filter("true", // DF filter
tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey"))))
.right(
anyTree(
tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey")))))));

assertPlan("SELECT * FROM (SELECT * FROM nation WHERE nationkey = 5) a JOIN (SELECT * FROM nation WHERE nationkey = 5) b ON a.nationkey = b.nationkey",
output(
join(INNER, builder -> builder
.equiCriteria(ImmutableList.of())
.left(
filter("L_NATIONKEY = BIGINT '5'",
tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey"))))
.right(
anyTree(
filter("R_NATIONKEY = BIGINT '5'",
tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey"))))))));
}

@Test
public void testSimplifiesStraddlingPredicate()
{
assertPlan("SELECT * FROM (SELECT * FROM NATION WHERE nationkey = 5) a JOIN nation b ON a.nationkey = b.nationkey AND a.nationkey = a.regionkey + b.regionkey",
output(
filter("L_REGIONKEY + R_REGIONKEY = BIGINT '5'",
join(INNER, builder -> builder
.equiCriteria(ImmutableList.of())
.left(
filter("L_NATIONKEY = BIGINT '5'",
tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey", "L_REGIONKEY", "regionkey"))))
.right(
anyTree(
filter("R_NATIONKEY = BIGINT '5'",
tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey", "R_REGIONKEY", "regionkey")))))))));
}

protected Session noSemiJoinRewrite()
{
return Session.builder(getQueryRunner().getDefaultSession())
@@ -65,6 +65,29 @@ public class TestEqualityInference
private final TestingFunctionResolution functionResolution = new TestingFunctionResolution();
private final Metadata metadata = functionResolution.getMetadata();

@Test
public void testDoesNotInferRedundantStraddlingPredicates()
{
EqualityInference inference = new EqualityInference(
metadata,
equals("a1", "b1"),
equals(add(nameReference("a1"), number(1)), number(0)),
equals(nameReference("a2"), add(nameReference("a1"), number(2))),
equals(nameReference("a1"), add("a3", "b3")),
equals(nameReference("b2"), add("a4", "b4")));
EqualityInference.EqualityPartition partition = inference.generateEqualitiesPartitionedBy(symbols("a1", "a2", "a3", "a4"));
assertThat(partition.getScopeEqualities()).containsExactly(
equals(number(0), add(nameReference("a1"), number(1))),
equals(nameReference("a2"), add(nameReference("a1"), number(2))));
assertThat(partition.getScopeComplementEqualities()).containsExactly(
equals(number(0), add(nameReference("b1"), number(1))));
// there shouldn't be equality a2 = b1 + 2 as it can be derived from a2 = a1 + 2, a1 = b1
assertThat(partition.getScopeStraddlingEqualities()).containsExactly(
equals("a1", "b1"),
equals(nameReference("a1"), add("a3", "b3")),
equals(nameReference("b2"), add("a4", "b4")));
}

@Test
public void testTransitivity()
{
@@ -13,6 +13,7 @@
*/
package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.sql.planner.plan.ExchangeNode;
@@ -23,6 +24,7 @@
import static io.trino.sql.planner.assertions.PlanMatchPattern.filter;
import static io.trino.sql.planner.assertions.PlanMatchPattern.join;
import static io.trino.sql.planner.assertions.PlanMatchPattern.node;
import static io.trino.sql.planner.assertions.PlanMatchPattern.output;
import static io.trino.sql.planner.assertions.PlanMatchPattern.project;
import static io.trino.sql.planner.assertions.PlanMatchPattern.semiJoin;
import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan;
@@ -173,4 +175,31 @@ public void testNonStraddlingJoinExpression()
anyTree(
tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")))))));
}

@Override
@Test
public void testDoesNotCreatePredicateFromInferredPredicate()
{
assertPlan("SELECT * FROM (SELECT *, nationkey + 1 as nationkey2 FROM nation) a JOIN nation b ON a.nationkey = b.nationkey",
output(
join(INNER, builder -> builder
.equiCriteria("L_NATIONKEY", "R_NATIONKEY")
.left(
tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey")))
.right(
anyTree(
tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey")))))));

assertPlan("SELECT * FROM (SELECT * FROM nation WHERE nationkey = 5) a JOIN (SELECT * FROM nation WHERE nationkey = 5) b ON a.nationkey = b.nationkey",
output(
join(INNER, builder -> builder
.equiCriteria(ImmutableList.of())
.left(
filter("L_NATIONKEY = BIGINT '5'",
tableScan("nation", ImmutableMap.of("L_NATIONKEY", "nationkey"))))
.right(
anyTree(
filter("R_NATIONKEY = BIGINT '5'",
tableScan("nation", ImmutableMap.of("R_NATIONKEY", "nationkey"))))))));
}
}