Skip to content

Commit

Permalink
Remove pruning of correlation list from project-off rule
Browse files Browse the repository at this point in the history
Do not prune correlation list in PruneCorrelatedJoinColumns rule.
Pruning of the correlation list should not be the responsibility of a project-off rule.
It does not require or use outer context.
It is now done in the rule PruneCorrelatedJoinCorrelation.
Moving this functionality to another rule was necessary
to complete migration of PruneUnreferencedOutputs optimizer
to iterative rules.
  • Loading branch information
kasiafi authored and martint committed Sep 17, 2020
1 parent b635826 commit 53c9d38
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@
import io.prestosql.sql.planner.plan.CorrelatedJoinNode;
import io.prestosql.sql.planner.plan.PlanNode;

import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Sets.intersection;
import static io.prestosql.sql.planner.SymbolsExtractor.extractUnique;
import static io.prestosql.sql.planner.iterative.rule.Util.restrictOutputs;
Expand All @@ -47,14 +45,8 @@
* - it is not a referenced output symbol,
* - it is not present in the join filter.
* <p>
* A symbol can be removed from the correlation list, when
* it is no longer present in the subquery.
* <p>
* Note: this rule does not remove any symbols from the subquery.
* However, the correlated symbol might have been removed from
* the subquery by another rule. This rule checks it so that it can
* update the correlation list and take the advantage of
* pruning the symbol if it is not referenced.
* Note: this rule does not remove any symbols from the correlation list.
* This is the responsibility of the PruneCorrelatedJoinCorrelation rule.
* <p>
* Transforms:
* <pre>
Expand All @@ -69,9 +61,9 @@
* <pre>
* - Project (a, c)
* - CorrelatedJoin
* correlation: []
* correlation: [corr]
* filter: a > d
* - Project (a)
* - Project (a, corr)
* - Input (a, b, corr)
* - Project (c, d)
* - Subquery (c, d, e)
Expand Down Expand Up @@ -103,15 +95,9 @@ protected Optional<PlanNode> pushDownProjectOff(Context context, CorrelatedJoinN
}
}

// extract actual correlation symbols
Set<Symbol> subquerySymbols = extractUnique(subquery, context.getLookup());
List<Symbol> newCorrelation = correlatedJoinNode.getCorrelation().stream()
.filter(subquerySymbols::contains)
.collect(toImmutableList());

Set<Symbol> referencedAndCorrelationSymbols = ImmutableSet.<Symbol>builder()
.addAll(referencedOutputs)
.addAll(newCorrelation)
.addAll(correlatedJoinNode.getCorrelation())
.build();

// remove unused input node, retain subquery
Expand All @@ -137,21 +123,19 @@ protected Optional<PlanNode> pushDownProjectOff(Context context, CorrelatedJoinN

Set<Symbol> referencedAndFilterAndCorrelationSymbols = ImmutableSet.<Symbol>builder()
.addAll(referencedAndFilterSymbols)
.addAll(newCorrelation)
.addAll(correlatedJoinNode.getCorrelation())
.build();

Optional<PlanNode> newInput = restrictOutputs(context.getIdAllocator(), input, referencedAndFilterAndCorrelationSymbols);

boolean pruned = newSubquery.isPresent()
|| newInput.isPresent()
|| newCorrelation.size() < correlatedJoinNode.getCorrelation().size();
boolean pruned = newSubquery.isPresent() || newInput.isPresent();

if (pruned) {
return Optional.of(new CorrelatedJoinNode(
correlatedJoinNode.getId(),
newInput.orElse(input),
newSubquery.orElse(subquery),
newCorrelation,
correlatedJoinNode.getCorrelation(),
correlatedJoinNode.getType(),
correlatedJoinNode.getFilter(),
correlatedJoinNode.getOriginSubquery()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,12 @@ public void testRemoveUnusedCorrelatedJoinNode()
tester().assertThat(new PruneCorrelatedJoinColumns())
.on(p -> {
Symbol a = p.symbol("a");
Symbol correlationSymbol = p.symbol("correlation_symbol");
Symbol b = p.symbol("b");
return p.project(
Assignments.identity(b),
p.correlatedJoin(
ImmutableList.of(correlationSymbol),
p.values(1, a, correlationSymbol),
ImmutableList.of(),
p.values(1, a),
p.values(b)));
})
.matches(
Expand Down Expand Up @@ -181,34 +180,7 @@ public void testPruneUnreferencedInputSymbol()
}

// Checks how PruneCorrelatedJoinColumns treats a symbol that is listed in the
// correlation list but is not produced or used by the subquery (values(b)).
@Test
public void testRemoveSymbolFromCorrelationList()
{
// symbol is removed from the correlation list, but it cannot be pruned because it's present in the join filter
tester().assertThat(new PruneCorrelatedJoinColumns())
.on(p -> {
Symbol a = p.symbol("a");
Symbol correlationSymbol = p.symbol("correlation_symbol");
Symbol b = p.symbol("b");
// Input plan: Project(a, b) over a LEFT correlated join whose correlation list is
// [correlation_symbol] and whose filter is "b > correlation_symbol".
return p.project(
Assignments.identity(a, b),
p.correlatedJoin(
ImmutableList.of(correlationSymbol),
p.values(a, correlationSymbol),
LEFT,
new ComparisonExpression(GREATER_THAN, b.toSymbolReference(), correlationSymbol.toSymbolReference()),
p.values(b)));
})
// Expected plan: the correlation list is empty, while the input side still exposes
// both "a" and "correlation_symbol" (no pruning projection is placed above the values node).
.matches(
project(
ImmutableMap.of("b", PlanMatchPattern.expression("b")),
correlatedJoin(
ImmutableList.of(),
values("a", "correlation_symbol"),
values("b"))));
}

@Test
public void testPruneUnreferencedCorrelationSymbol()
public void testDoNotPruneUnreferencedCorrelationSymbol()
{
tester().assertThat(new PruneCorrelatedJoinColumns())
.on(p -> {
Expand All @@ -224,15 +196,7 @@ public void testPruneUnreferencedCorrelationSymbol()
TRUE_LITERAL,
p.values(b)));
})
.matches(
project(
ImmutableMap.of("b", PlanMatchPattern.expression("b")),
correlatedJoin(
ImmutableList.of(),
project(
ImmutableMap.of("a", PlanMatchPattern.expression("a")),
values("a", "correlation_symbol")),
values("b"))));
.doesNotFire();
}

@Test
Expand Down

0 comments on commit 53c9d38

Please sign in to comment.