From 0f2b3dd972fc8a7d589273fcf94975a8ba303fc0 Mon Sep 17 00:00:00 2001 From: Yi He Date: Wed, 14 Aug 2019 17:16:20 -0700 Subject: [PATCH] Move TranslateExpressions above last InlineProjections --- .../presto/sql/planner/PlanOptimizers.java | 24 ++-- .../planner/RowExpressionVariableInliner.java | 14 ++- .../iterative/rule/InlineProjections.java | 103 ++++++++++++++---- .../sql/planner/plan/AssignmentUtils.java | 9 +- .../iterative/rule/TestInlineProjections.java | 51 ++++++++- 5 files changed, 157 insertions(+), 44 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 202a58d1fb56..aec7bbe1e054 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -232,7 +232,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of( - new InlineProjections(), + new InlineProjections(metadata.getFunctionManager()), new RemoveRedundantIdentityProjections())); IterativeOptimizer projectionPushDown = new IterativeOptimizer( @@ -347,7 +347,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of( - new InlineProjections(), + new InlineProjections(metadata.getFunctionManager()), new RemoveRedundantIdentityProjections(), new TransformCorrelatedSingleRowSubqueryToProject())), new CheckSubqueryNodesAreRewritten(), @@ -441,7 +441,7 @@ public PlanOptimizers( ImmutableSet.>builder() .add(new RemoveRedundantIdentityProjections()) .addAll(new ExtractSpatialJoins(metadata, splitManager, pageSourceManager, sqlParser).rules()) - .add(new InlineProjections()) + .add(new InlineProjections(metadata.getFunctionManager())) .build())); if (!forceSingleNode) { @@ -480,24 +480,24 @@ public PlanOptimizers( builder.add(new UnaliasSymbolReferences()); // Run unalias after merging projections to simplify projections more efficiently builder.add(new PruneUnreferencedOutputs()); + // TODO: move this before optimization if possible!! + // Replace all expressions with row expressions builder.add(new IterativeOptimizer( ruleStats, statsCalculator, costCalculator, - ImmutableSet.>builder() - .add(new RemoveRedundantIdentityProjections()) - .add(new PushRemoteExchangeThroughAssignUniqueId()) - .add(new InlineProjections()) - .build())); + new TranslateExpressions(metadata, sqlParser).rules())); + // After this point, all planNodes should not contain OriginalExpression - // TODO: move this before optimization if possible!! - // Replace all expressions with row expressions builder.add(new IterativeOptimizer( ruleStats, statsCalculator, costCalculator, - new TranslateExpressions(metadata, sqlParser).rules())); - // After this point, all planNodes should not contain OriginalExpression + ImmutableSet.>builder() + .add(new RemoveRedundantIdentityProjections()) + .add(new PushRemoteExchangeThroughAssignUniqueId()) + .add(new InlineProjections(metadata.getFunctionManager())) + .build())); // Optimizers above this don't understand local exchanges, so be careful moving this. builder.add(new AddLocalExchanges(metadata, sqlParser)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionVariableInliner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionVariableInliner.java index 30867745b446..ce531cbfbcf4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionVariableInliner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionVariableInliner.java @@ -22,6 +22,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.function.Function; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -30,23 +31,28 @@ public final class RowExpressionVariableInliner extends RowExpressionRewriter { private final Set excludedNames = new HashSet<>(); - private final Map mapping; + private final Function mapping; - private RowExpressionVariableInliner(Map mapping) + private RowExpressionVariableInliner(Function mapping) { this.mapping = mapping; } - public static RowExpression inlineVariables(Map mapping, RowExpression expression) + public static RowExpression inlineVariables(Function mapping, RowExpression expression) { return RowExpressionTreeRewriter.rewriteWith(new RowExpressionVariableInliner(mapping), expression); } + public static RowExpression inlineVariables(Map mapping, RowExpression expression) + { + return inlineVariables(mapping::get, expression); + } + @Override public RowExpression rewriteVariableReference(VariableReferenceExpression node, Void context, RowExpressionTreeRewriter treeRewriter) { if (!excludedNames.contains(node.getName())) { - RowExpression result = mapping.get(node); + RowExpression result = mapping.apply(node); checkState(result != null, "Cannot resolve symbol %s", node.getName()); return result; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java index deea53cf63ef..3f1b53e11626 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java @@ -16,15 +16,21 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionManager; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.DefaultRowExpressionTraversalVisitor; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.ExpressionVariableInliner; +import com.facebook.presto.sql.planner.RowExpressionVariableInliner; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.VariablesExtractor; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.Assignments.Builder; import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Literal; @@ -34,6 +40,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Function; @@ -46,6 +53,7 @@ import static com.facebook.presto.sql.planner.plan.Patterns.source; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static java.util.stream.Collectors.toSet; /** @@ -62,6 +70,13 @@ public class InlineProjections private static final Pattern PATTERN = project() .with(source().matching(project().capturedAs(CHILD))); + private final FunctionResolution functionResolution; + + public InlineProjections(FunctionManager functionManager) + { + this.functionResolution = new FunctionResolution(functionManager); + } + @Override public Pattern getPattern() { @@ -84,18 +99,22 @@ public Result apply(ProjectNode parent, Captures captures, Context context) .entrySet().stream() .collect(Collectors.toMap( Map.Entry::getKey, - entry -> castToRowExpression(inlineReferences(castToExpression(entry.getValue()), assignments, context.getVariableAllocator().getTypes())))); + entry -> inlineReferences(entry.getValue(), assignments, context.getVariableAllocator().getTypes()))); // Synthesize identity assignments for the inputs of expressions that were inlined // to place in the child projection. // If all assignments end up becoming identity assignments, they'll get pruned by // other rules + boolean allTranslated = child.getAssignments().entrySet() + .stream() + .map(Map.Entry::getValue) + .noneMatch(OriginalExpressionUtils::isExpression); + Set inputs = child.getAssignments() .entrySet().stream() .filter(entry -> targets.contains(entry.getKey())) .map(Map.Entry::getValue) - .map(OriginalExpressionUtils::castToExpression) - .flatMap(entry -> VariablesExtractor.extractAll(entry, context.getVariableAllocator().getTypes()).stream()) + .flatMap(expression -> extractDependencies(expression, context.getVariableAllocator().getTypes()).stream()) .collect(toSet()); Builder childAssignments = Assignments.builder(); @@ -105,7 +124,12 @@ public Result apply(ProjectNode parent, Captures captures, Context context) } } for (VariableReferenceExpression input : inputs) { - childAssignments.put(identityAsSymbolReference(input)); + if (allTranslated) { + childAssignments.put(input, input); + } + else { + childAssignments.put(identityAsSymbolReference(input)); + } } return Result.ofPlanNode( @@ -118,16 +142,18 @@ public Result apply(ProjectNode parent, Captures captures, Context context) Assignments.copyOf(parentAssignments))); } - private Expression inlineReferences(Expression expression, Assignments assignments, TypeProvider types) + private RowExpression inlineReferences(RowExpression expression, Assignments assignments, TypeProvider types) { - Function mapping = variable -> { - if (assignments.get(variable) == null) { - return new SymbolReference(variable.getName()); - } - return castToExpression(assignments.get(variable)); - }; - - return ExpressionVariableInliner.inlineVariables(mapping, expression, types); + if (isExpression(expression)) { + Function mapping = variable -> { + if (assignments.get(variable) == null) { + return new SymbolReference(variable.getName()); + } + return castToExpression(assignments.get(variable)); + }; + return castToRowExpression(ExpressionVariableInliner.inlineVariables(mapping, castToExpression(expression), types)); + } + return RowExpressionVariableInliner.inlineVariables(variable -> assignments.getMap().getOrDefault(variable, variable), expression); } private Sets.SetView extractInliningTargets(ProjectNode parent, ProjectNode child, Context context) @@ -141,25 +167,25 @@ private Sets.SetView extractInliningTargets(Project // which come from the child, as opposed to an enclosing scope. Set childOutputSet = ImmutableSet.copyOf(child.getOutputVariables()); + TypeProvider types = context.getVariableAllocator().getTypes(); Map dependencies = parent.getAssignments() .getExpressions() .stream() - .map(OriginalExpressionUtils::castToExpression) - .flatMap(expression -> VariablesExtractor.extractAll(expression, context.getVariableAllocator().getTypes()).stream()) + .flatMap(expression -> extractDependencies(expression, context.getVariableAllocator().getTypes()).stream()) .filter(childOutputSet::contains) .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); // find references to simple constants Set constants = dependencies.keySet().stream() - .filter(input -> castToExpression(child.getAssignments().get(input)) instanceof Literal) + .filter(input -> isConstant(child.getAssignments().get(input))) .collect(toSet()); // exclude any complex inputs to TRY expressions. Inlining them would potentially // change the semantics of those expressions Set tryArguments = parent.getAssignments() .getExpressions().stream() - .flatMap(expression -> extractTryArguments(castToExpression(expression), context.getVariableAllocator().getTypes()).stream()) + .flatMap(expression -> extractTryArguments(expression, types).stream()) .collect(toSet()); Set singletons = dependencies.entrySet().stream() @@ -172,12 +198,43 @@ private Sets.SetView extractInliningTargets(Project return Sets.union(singletons, constants); } - private Set extractTryArguments(Expression expression, TypeProvider types) + private Set extractTryArguments(RowExpression expression, TypeProvider types) { - return AstUtils.preOrder(expression) - .filter(TryExpression.class::isInstance) - .map(TryExpression.class::cast) - .flatMap(tryExpression -> VariablesExtractor.extractAll(tryExpression, types).stream()) - .collect(toSet()); + if (isExpression(expression)) { + return AstUtils.preOrder(castToExpression(expression)) + .filter(TryExpression.class::isInstance) + .map(TryExpression.class::cast) + .flatMap(tryExpression -> VariablesExtractor.extractAll(tryExpression, types).stream()) + .collect(toSet()); + } + ImmutableSet.Builder builder = ImmutableSet.builder(); + expression.accept(new DefaultRowExpressionTraversalVisitor>() + { + @Override + public Void visitCall(CallExpression call, ImmutableSet.Builder context) + { + if (functionResolution.isTryFunction(call.getFunctionHandle())) { + context.addAll(VariablesExtractor.extractAll(call)); + } + return super.visitCall(call, context); + } + }, builder); + return builder.build(); + } + + private static List extractDependencies(RowExpression expression, TypeProvider types) + { + if (isExpression(expression)) { + return VariablesExtractor.extractAll(castToExpression(expression), types); + } + return VariablesExtractor.extractAll(expression); + } + + private static boolean isConstant(RowExpression expression) + { + if (isExpression(expression)) { + return castToExpression(expression) instanceof Literal; + } + return expression instanceof ConstantExpression; } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java index 036e0f7758dd..bc1e26d73aff 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java @@ -28,6 +28,7 @@ import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static java.util.Arrays.asList; import static java.util.Collections.singletonMap; @@ -68,8 +69,12 @@ public static Assignments identityAssignments(Collection p.project( Assignments.builder() @@ -70,10 +70,55 @@ public void test() values(ImmutableMap.of("x", 0))))); } + @Test + public void testRowExpression() + { + // TODO add testing to expressions that need desugaring like 'try' + tester().assertThat(new InlineProjections(getFunctionManager())) + .on(p -> { + p.variable("symbol"); + p.variable("complex"); + p.variable("literal"); + p.variable("complex_2"); + p.variable("x"); + return p.project( + Assignments.builder() + .put(p.variable("identity"), p.rowExpression("symbol")) // identity + .put(p.variable("multi_complex_1"), p.rowExpression("complex + 1")) // complex expression referenced multiple times + .put(p.variable("multi_complex_2"), p.rowExpression("complex + 2")) // complex expression referenced multiple times + .put(p.variable("multi_literal_1"), p.rowExpression("literal + 1")) // literal referenced multiple times + .put(p.variable("multi_literal_2"), p.rowExpression("literal + 2")) // literal referenced multiple times + .put(p.variable("single_complex"), p.rowExpression("complex_2 + 2")) // complex expression reference only once + .build(), + p.project(Assignments.builder() + .put(p.variable("symbol"), p.rowExpression("x")) + .put(p.variable("complex"), p.rowExpression("x * 2")) + .put(p.variable("literal"), p.rowExpression("1")) + .put(p.variable("complex_2"), p.rowExpression("x - 1")) + .build(), + p.values(p.variable("x")))); + }) + .matches( + project( + ImmutableMap.builder() + .put("out1", PlanMatchPattern.expression("x")) + .put("out2", PlanMatchPattern.expression("y + 1")) + .put("out3", PlanMatchPattern.expression("y + 2")) + .put("out4", PlanMatchPattern.expression("1 + 1")) + .put("out5", PlanMatchPattern.expression("1 + 2")) + .put("out6", PlanMatchPattern.expression("x - 1 + 2")) + .build(), + project( + ImmutableMap.of( + "x", PlanMatchPattern.expression("x"), + "y", PlanMatchPattern.expression("x * 2")), + values(ImmutableMap.of("x", 0))))); + } + @Test public void testIdentityProjections() { - tester().assertThat(new InlineProjections()) + tester().assertThat(new InlineProjections(getFunctionManager())) .on(p -> p.project( assignment(p.variable("output"), expression("value")), @@ -86,7 +131,7 @@ public void testIdentityProjections() @Test public void testSubqueryProjections() { - tester().assertThat(new InlineProjections()) + tester().assertThat(new InlineProjections(getFunctionManager())) .on(p -> p.project( identityAssignmentsAsSymbolReferences(p.variable("fromOuterScope"), p.variable("value")),