Skip to content

Commit

Permalink
Move TranslateExpressions above last UnaliasSymbolReferences
Browse files Browse the repository at this point in the history
  • Loading branch information
hellium01 authored and highker committed Aug 22, 2019
1 parent 0f2b3dd commit 9a0824d
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ public PlanOptimizers(
new RewriteSpatialPartitioningAggregation(metadata)))
.build()),
simplifyOptimizer,
new UnaliasSymbolReferences(),
new UnaliasSymbolReferences(metadata.getFunctionManager()),
new IterativeOptimizer(
ruleStats,
statsCalculator,
Expand Down Expand Up @@ -368,7 +368,7 @@ public PlanOptimizers(
inlineProjections,
simplifyOptimizer, // Re-run the SimplifyExpressions to simplify any recomposed expressions from other optimizations
projectionPushDown,
new UnaliasSymbolReferences(), // Run again because predicate pushdown and projection pushdown might add more projections
new UnaliasSymbolReferences(metadata.getFunctionManager()), // Run again because predicate pushdown and projection pushdown might add more projections
new PruneUnreferencedOutputs(), // Make sure to run this before index join. Filtered projections may not have all the columns.
new IndexJoinOptimizer(metadata), // Run this after projections and filters have been fully simplified and pushed down
new IterativeOptimizer(
Expand Down Expand Up @@ -477,8 +477,6 @@ public PlanOptimizers(
builder.add(simplifyOptimizer); // Should be always run after PredicatePushDown
builder.add(projectionPushDown);
builder.add(inlineProjections);
builder.add(new UnaliasSymbolReferences()); // Run unalias after merging projections to simplify projections more efficiently
builder.add(new PruneUnreferencedOutputs());

// TODO: move this before optimization if possible!!
// Replace all expressions with row expressions
Expand All @@ -489,6 +487,8 @@ public PlanOptimizers(
new TranslateExpressions(metadata, sqlParser).rules()));
// After this point, all planNodes should not contain OriginalExpression

builder.add(new UnaliasSymbolReferences(metadata.getFunctionManager())); // Run unalias after merging projections to simplify projections more efficiently
builder.add(new PruneUnreferencedOutputs());
builder.add(new IterativeOptimizer(
ruleStats,
statsCalculator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.planner.plan.UnnestNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
Expand All @@ -85,7 +84,6 @@
import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported;
import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isScalar;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
Expand Down Expand Up @@ -195,10 +193,18 @@ public PlanNode visitJoin(JoinNode node, RewriteContext<Set<VariableReferenceExp
{
Set<VariableReferenceExpression> expectedFilterInputs = new HashSet<>();
if (node.getFilter().isPresent()) {
expectedFilterInputs = ImmutableSet.<VariableReferenceExpression>builder()
.addAll(VariablesExtractor.extractUnique(castToExpression(node.getFilter().get()), variableAllocator.getTypes()))
.addAll(context.get())
.build();
if (isExpression(node.getFilter().get())) {
expectedFilterInputs = ImmutableSet.<VariableReferenceExpression>builder()
.addAll(VariablesExtractor.extractUnique(castToExpression(node.getFilter().get()), variableAllocator.getTypes()))
.addAll(context.get())
.build();
}
else {
expectedFilterInputs = ImmutableSet.<VariableReferenceExpression>builder()
.addAll(VariablesExtractor.extractUnique(node.getFilter().get()))
.addAll(context.get())
.build();
}
}

ImmutableSet.Builder<VariableReferenceExpression> leftInputsBuilder = ImmutableSet.builder();
Expand Down Expand Up @@ -451,10 +457,19 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext<Set<VariableRe
@Override
public PlanNode visitFilter(FilterNode node, RewriteContext<Set<VariableReferenceExpression>> context)
{
Set<VariableReferenceExpression> expectedInputs = ImmutableSet.<VariableReferenceExpression>builder()
.addAll(VariablesExtractor.extractUnique(castToExpression(node.getPredicate()), variableAllocator.getTypes()))
.addAll(context.get())
.build();
Set<VariableReferenceExpression> expectedInputs;
if (isExpression(node.getPredicate())) {
expectedInputs = ImmutableSet.<VariableReferenceExpression>builder()
.addAll(VariablesExtractor.extractUnique(castToExpression(node.getPredicate()), variableAllocator.getTypes()))
.addAll(context.get())
.build();
}
else {
expectedInputs = ImmutableSet.<VariableReferenceExpression>builder()
.addAll(VariablesExtractor.extractUnique(node.getPredicate()))
.addAll(context.get())
.build();
}

PlanNode source = context.rewrite(node.getSource(), expectedInputs);

Expand Down Expand Up @@ -540,7 +555,12 @@ public PlanNode visitProject(ProjectNode node, RewriteContext<Set<VariableRefere
Assignments.Builder builder = Assignments.builder();
node.getAssignments().forEach((variable, expression) -> {
if (context.get().contains(variable)) {
expectedInputs.addAll(VariablesExtractor.extractUnique(castToExpression(expression), variableAllocator.getTypes()));
if (isExpression(expression)) {
expectedInputs.addAll(VariablesExtractor.extractUnique(castToExpression(expression), variableAllocator.getTypes()));
}
else {
expectedInputs.addAll(VariablesExtractor.extractUnique(expression));
}
builder.put(variable, expression);
}
});
Expand Down Expand Up @@ -798,10 +818,15 @@ public PlanNode visitApply(ApplyNode node, RewriteContext<Set<VariableReferenceE
Assignments.Builder subqueryAssignments = Assignments.builder();
for (Map.Entry<VariableReferenceExpression, RowExpression> entry : node.getSubqueryAssignments().getMap().entrySet()) {
VariableReferenceExpression output = entry.getKey();
Expression expression = castToExpression(entry.getValue());
RowExpression expression = entry.getValue();
if (context.get().contains(output)) {
subqueryAssignmentsVariablesBuilder.addAll(VariablesExtractor.extractUnique(expression, variableAllocator.getTypes()));
subqueryAssignments.put(output, castToRowExpression(expression));
if (isExpression(expression)) {
subqueryAssignmentsVariablesBuilder.addAll(VariablesExtractor.extractUnique(castToExpression(expression), variableAllocator.getTypes()));
}
else {
subqueryAssignmentsVariablesBuilder.addAll(VariablesExtractor.extractUnique(expression));
}
subqueryAssignments.put(output, expression);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.presto.Session;
import com.facebook.presto.execution.warnings.WarningCollector;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.spi.block.SortOrder;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.LimitNode;
Expand All @@ -31,6 +32,7 @@
import com.facebook.presto.sql.planner.ExpressionDeterminismEvaluator;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.RowExpressionVariableInliner;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.AggregationNode;
Expand Down Expand Up @@ -67,7 +69,8 @@
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.planner.plan.UnnestNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
Expand Down Expand Up @@ -115,6 +118,13 @@
public class UnaliasSymbolReferences
implements PlanOptimizer
{
private final FunctionManager functionManager;

/**
 * Creates the optimizer with the function manager it needs for row-expression handling.
 *
 * @param functionManager passed to every {@code Rewriter} built in {@code optimize}, which uses it
 *        to construct a {@code RowExpressionDeterminismEvaluator}; must not be null (checked via
 *        {@code requireNonNull})
 */
public UnaliasSymbolReferences(FunctionManager functionManager)
{
this.functionManager = requireNonNull(functionManager, "functionManager is null");
}

@Override
public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, PlanVariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
{
Expand All @@ -124,18 +134,20 @@ public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, Pla
requireNonNull(variableAllocator, "variableAllocator is null");
requireNonNull(idAllocator, "idAllocator is null");

return SimplePlanRewriter.rewriteWith(new Rewriter(types), plan);
return SimplePlanRewriter.rewriteWith(new Rewriter(types, functionManager), plan);
}

private static class Rewriter
extends SimplePlanRewriter<Void>
{
private final Map<String, String> mapping = new HashMap<>();
private final TypeProvider types;
private final RowExpressionDeterminismEvaluator determinismEvaluator;

private Rewriter(TypeProvider types)
private Rewriter(TypeProvider types, FunctionManager functionManager)
{
this.types = types;
this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionManager);
}

@Override
Expand Down Expand Up @@ -232,7 +244,7 @@ private List<RowExpression> canonicalizeCallExpression(CallExpression callExpres
// TODO: arguments will be pure RowExpression once we introduce subquery expression for RowExpression.
return callExpression.getArguments()
.stream()
.map(argument -> castToRowExpression(canonicalize(castToExpression(argument))))
.map(this::canonicalize)
.collect(toImmutableList());
}

Expand Down Expand Up @@ -372,12 +384,7 @@ public PlanNode visitValues(ValuesNode node, RewriteContext<Void> context)
{
List<List<RowExpression>> canonicalizedRows = node.getRows().stream()
.map(rowExpressions -> rowExpressions.stream()
.map(rowExpression -> {
if (isExpression(rowExpression)) {
return castToRowExpression(canonicalize(castToExpression(rowExpression)));
}
return rowExpression;
})
.map(this::canonicalize)
.collect(toImmutableList()))
.collect(toImmutableList());
List<VariableReferenceExpression> canonicalizedOutputVariables = canonicalizeAndDistinct(node.getOutputVariables());
Expand Down Expand Up @@ -434,7 +441,7 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
{
PlanNode source = context.rewrite(node.getSource());

return new FilterNode(node.getId(), source, castToRowExpression(canonicalize(castToExpression(node.getPredicate()))));
return new FilterNode(node.getId(), source, canonicalize(node.getPredicate()));
}

@Override
Expand Down Expand Up @@ -515,7 +522,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext<Void> context)
PlanNode right = context.rewrite(node.getRight());

List<JoinNode.EquiJoinClause> canonicalCriteria = canonicalizeJoinCriteria(node.getCriteria());
Optional<Expression> canonicalFilter = node.getFilter().map(OriginalExpressionUtils::castToExpression).map(this::canonicalize);
Optional<RowExpression> canonicalFilter = node.getFilter().map(this::canonicalize);
Optional<VariableReferenceExpression> canonicalLeftHashVariable = canonicalize(node.getLeftHashVariable());
Optional<VariableReferenceExpression> canonicalRightHashVariable = canonicalize(node.getRightHashVariable());

Expand All @@ -533,7 +540,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext<Void> context)
right,
canonicalCriteria,
canonicalizeAndDistinct(node.getOutputVariables()),
canonicalFilter.map(OriginalExpressionUtils::castToRowExpression),
canonicalFilter,
canonicalLeftHashVariable,
canonicalRightHashVariable,
node.getDistributionType());
Expand Down Expand Up @@ -563,7 +570,7 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext<Void> cont
PlanNode left = context.rewrite(node.getLeft());
PlanNode right = context.rewrite(node.getRight());

return new SpatialJoinNode(node.getId(), node.getType(), left, right, canonicalizeAndDistinct(node.getOutputVariables()), castToRowExpression(canonicalize(castToExpression(node.getFilter()))), canonicalize(node.getLeftPartitionVariable()), canonicalize(node.getRightPartitionVariable()), node.getKdbTree());
return new SpatialJoinNode(node.getId(), node.getType(), left, right, canonicalizeAndDistinct(node.getOutputVariables()), canonicalize(node.getFilter()), canonicalize(node.getLeftPartitionVariable()), canonicalize(node.getRightPartitionVariable()), node.getKdbTree());
}

@Override
Expand Down Expand Up @@ -630,19 +637,25 @@ private void map(VariableReferenceExpression variable, VariableReferenceExpressi

private Assignments canonicalize(Assignments oldAssignments)
{
Map<Expression, VariableReferenceExpression> computedExpressions = new HashMap<>();
Map<RowExpression, VariableReferenceExpression> computedExpressions = new HashMap<>();
Assignments.Builder assignments = Assignments.builder();
for (Map.Entry<VariableReferenceExpression, RowExpression> entry : oldAssignments.getMap().entrySet()) {
Expression expression = canonicalize(castToExpression(entry.getValue()));

if (expression instanceof SymbolReference) {
RowExpression expression = canonicalize(entry.getValue());
if (expression instanceof VariableReferenceExpression) {
// Always map a trivial variable projection
VariableReferenceExpression variable = (VariableReferenceExpression) expression;
if (!variable.getName().equals(entry.getKey().getName())) {
map(entry.getKey(), variable);
}
}
else if (isExpression(expression) && castToExpression(expression) instanceof SymbolReference) {
// Always map a trivial symbol projection
VariableReferenceExpression variable = new VariableReferenceExpression(((SymbolReference) expression).getName(), types.get(expression));
VariableReferenceExpression variable = new VariableReferenceExpression(Symbol.from(castToExpression(expression)).getName(), types.get(castToExpression(expression)));
if (!variable.getName().equals(entry.getKey().getName())) {
map(entry.getKey(), variable);
}
}
else if (ExpressionDeterminismEvaluator.isDeterministic(expression) && !(expression instanceof NullLiteral)) {
else if (!isNull(expression) && isDeterministic(expression)) {
// Try to map same deterministic expressions within a projection into the same symbol
// Omit NullLiterals since those have ambiguous types
VariableReferenceExpression computedVariable = computedExpressions.get(expression);
Expand All @@ -658,11 +671,27 @@ else if (ExpressionDeterminismEvaluator.isDeterministic(expression) && !(express
}

VariableReferenceExpression canonical = canonicalize(entry.getKey());
assignments.put(canonical, castToRowExpression(expression));
assignments.put(canonical, expression);
}
return assignments.build();
}

/**
 * Reports whether {@code expression} is deterministic, whichever representation it is in.
 *
 * An expression for which {@code isExpression} holds is unwrapped and checked with the static
 * {@code ExpressionDeterminismEvaluator}; anything else is checked with this rewriter's
 * {@code RowExpressionDeterminismEvaluator}.
 *
 * @param expression the (possibly wrapped) expression to inspect
 * @return the verdict of the evaluator matching the expression's representation
 */
private boolean isDeterministic(RowExpression expression)
{
    return isExpression(expression)
            ? ExpressionDeterminismEvaluator.isDeterministic(castToExpression(expression))
            : determinismEvaluator.isDeterministic(expression);
}

/**
 * Reports whether {@code expression} is a null literal, whichever representation it is in.
 *
 * An expression for which {@code isExpression} holds counts as null when the unwrapped AST node
 * is a {@code NullLiteral}; anything else is delegated to {@code Expressions.isNull}.
 *
 * @param expression the (possibly wrapped) expression to inspect
 * @return {@code true} when the expression represents a null literal/constant
 */
private static boolean isNull(RowExpression expression)
{
    return isExpression(expression)
            ? (castToExpression(expression) instanceof NullLiteral)
            : Expressions.isNull(expression);
}

private Symbol canonicalize(Symbol symbol)
{
String canonical = symbol.getName();
Expand All @@ -689,6 +718,15 @@ private Optional<VariableReferenceExpression> canonicalize(Optional<VariableRefe
return Optional.empty();
}

/**
 * Canonicalizes the variables referenced by {@code value}.
 *
 * Values for which {@code isExpression} does not hold are rewritten with
 * {@code RowExpressionVariableInliner}, using this rewriter's variable canonicalization as the
 * substitution. Values that still wrap an AST {@code Expression} are unwrapped, canonicalized
 * through the {@code Expression} overload, and wrapped again.
 *
 * @param value the (possibly wrapped) expression to canonicalize
 * @return the canonicalized expression, in the same representation as the input
 */
private RowExpression canonicalize(RowExpression value)
{
    if (!isExpression(value)) {
        return RowExpressionVariableInliner.inlineVariables(this::canonicalize, value);
    }

    // TODO: remove this branch once no UnaliasSymbolReferences instance runs before TranslateExpressions
    Expression original = castToExpression(value);
    return castToRowExpression(canonicalize(original));
}

private Expression canonicalize(Expression value)
{
return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ public static ConstantExpression constantNull(Type type)
return new ConstantExpression(null, type);
}

/**
 * Reports whether {@code expression} is a constant whose value is null.
 *
 * @param expression the expression to inspect
 * @return {@code true} only when {@code expression} is a {@code ConstantExpression} and its
 *         {@code isNull()} returns true; {@code false} for every other kind of expression
 */
public static boolean isNull(RowExpression expression)
{
    if (!(expression instanceof ConstantExpression)) {
        return false;
    }
    ConstantExpression constant = (ConstantExpression) expression;
    return constant.isNull();
}

public static CallExpression call(String displayName, FunctionHandle functionHandle, Type returnType, RowExpression... arguments)
{
return new CallExpression(displayName, functionHandle, returnType, Arrays.asList(arguments));
Expand Down
Loading

0 comments on commit 9a0824d

Please sign in to comment.