Skip to content

Commit

Permalink
Move TranslateExpressions above last InlineProjections
Browse files Browse the repository at this point in the history
  • Loading branch information
hellium01 authored and highker committed Aug 22, 2019
1 parent 8b5d746 commit 0f2b3dd
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public PlanOptimizers(
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(
new InlineProjections(),
new InlineProjections(metadata.getFunctionManager()),
new RemoveRedundantIdentityProjections()));

IterativeOptimizer projectionPushDown = new IterativeOptimizer(
Expand Down Expand Up @@ -347,7 +347,7 @@ public PlanOptimizers(
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(
new InlineProjections(),
new InlineProjections(metadata.getFunctionManager()),
new RemoveRedundantIdentityProjections(),
new TransformCorrelatedSingleRowSubqueryToProject())),
new CheckSubqueryNodesAreRewritten(),
Expand Down Expand Up @@ -441,7 +441,7 @@ public PlanOptimizers(
ImmutableSet.<Rule<?>>builder()
.add(new RemoveRedundantIdentityProjections())
.addAll(new ExtractSpatialJoins(metadata, splitManager, pageSourceManager, sqlParser).rules())
.add(new InlineProjections())
.add(new InlineProjections(metadata.getFunctionManager()))
.build()));

if (!forceSingleNode) {
Expand Down Expand Up @@ -480,24 +480,24 @@ public PlanOptimizers(
builder.add(new UnaliasSymbolReferences()); // Run unalias after merging projections to simplify projections more efficiently
builder.add(new PruneUnreferencedOutputs());

// TODO: move this before optimization if possible!!
// Replace all expressions with row expressions
builder.add(new IterativeOptimizer(
ruleStats,
statsCalculator,
costCalculator,
ImmutableSet.<Rule<?>>builder()
.add(new RemoveRedundantIdentityProjections())
.add(new PushRemoteExchangeThroughAssignUniqueId())
.add(new InlineProjections())
.build()));
new TranslateExpressions(metadata, sqlParser).rules()));
// After this point, all planNodes should not contain OriginalExpression

// TODO: move this before optimization if possible!!
// Replace all expressions with row expressions
builder.add(new IterativeOptimizer(
ruleStats,
statsCalculator,
costCalculator,
new TranslateExpressions(metadata, sqlParser).rules()));
// After this point, all planNodes should not contain OriginalExpression
ImmutableSet.<Rule<?>>builder()
.add(new RemoveRedundantIdentityProjections())
.add(new PushRemoteExchangeThroughAssignUniqueId())
.add(new InlineProjections(metadata.getFunctionManager()))
.build()));

// Optimizers above this don't understand local exchanges, so be careful moving this.
builder.add(new AddLocalExchanges(metadata, sqlParser));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
Expand All @@ -30,23 +31,28 @@ public final class RowExpressionVariableInliner
extends RowExpressionRewriter<Void>
{
private final Set<String> excludedNames = new HashSet<>();
private final Map<VariableReferenceExpression, RowExpression> mapping;
private final Function<VariableReferenceExpression, RowExpression> mapping;

private RowExpressionVariableInliner(Map<VariableReferenceExpression, RowExpression> mapping)
private RowExpressionVariableInliner(Function<VariableReferenceExpression, RowExpression> mapping)
{
this.mapping = mapping;
}

public static RowExpression inlineVariables(Map<VariableReferenceExpression, RowExpression> mapping, RowExpression expression)
public static RowExpression inlineVariables(Function<VariableReferenceExpression, RowExpression> mapping, RowExpression expression)
{
return RowExpressionTreeRewriter.rewriteWith(new RowExpressionVariableInliner(mapping), expression);
}

public static RowExpression inlineVariables(Map<VariableReferenceExpression, RowExpression> mapping, RowExpression expression)
{
return inlineVariables(mapping::get, expression);
}

@Override
public RowExpression rewriteVariableReference(VariableReferenceExpression node, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
{
if (!excludedNames.contains(node.getName())) {
RowExpression result = mapping.get(node);
RowExpression result = mapping.apply(node);
checkState(result != null, "Cannot resolve symbol %s", node.getName());
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.DefaultRowExpressionTraversalVisitor;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.ExpressionVariableInliner;
import com.facebook.presto.sql.planner.RowExpressionVariableInliner;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.Assignments.Builder;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.Literal;
Expand All @@ -34,6 +40,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
Expand All @@ -46,6 +53,7 @@
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression;
import static java.util.stream.Collectors.toSet;

/**
Expand All @@ -62,6 +70,13 @@ public class InlineProjections
private static final Pattern<ProjectNode> PATTERN = project()
.with(source().matching(project().capturedAs(CHILD)));

private final FunctionResolution functionResolution;

public InlineProjections(FunctionManager functionManager)
{
this.functionResolution = new FunctionResolution(functionManager);
}

@Override
public Pattern<ProjectNode> getPattern()
{
Expand All @@ -84,18 +99,22 @@ public Result apply(ProjectNode parent, Captures captures, Context context)
.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
entry -> castToRowExpression(inlineReferences(castToExpression(entry.getValue()), assignments, context.getVariableAllocator().getTypes()))));
entry -> inlineReferences(entry.getValue(), assignments, context.getVariableAllocator().getTypes())));

// Synthesize identity assignments for the inputs of expressions that were inlined
// to place in the child projection.
// If all assignments end up becoming identity assignments, they'll get pruned by
// other rules
boolean allTranslated = child.getAssignments().entrySet()
.stream()
.map(Map.Entry::getValue)
.noneMatch(OriginalExpressionUtils::isExpression);

Set<VariableReferenceExpression> inputs = child.getAssignments()
.entrySet().stream()
.filter(entry -> targets.contains(entry.getKey()))
.map(Map.Entry::getValue)
.map(OriginalExpressionUtils::castToExpression)
.flatMap(entry -> VariablesExtractor.extractAll(entry, context.getVariableAllocator().getTypes()).stream())
.flatMap(expression -> extractDependencies(expression, context.getVariableAllocator().getTypes()).stream())
.collect(toSet());

Builder childAssignments = Assignments.builder();
Expand All @@ -105,7 +124,12 @@ public Result apply(ProjectNode parent, Captures captures, Context context)
}
}
for (VariableReferenceExpression input : inputs) {
childAssignments.put(identityAsSymbolReference(input));
if (allTranslated) {
childAssignments.put(input, input);
}
else {
childAssignments.put(identityAsSymbolReference(input));
}
}

return Result.ofPlanNode(
Expand All @@ -118,16 +142,18 @@ public Result apply(ProjectNode parent, Captures captures, Context context)
Assignments.copyOf(parentAssignments)));
}

private Expression inlineReferences(Expression expression, Assignments assignments, TypeProvider types)
private RowExpression inlineReferences(RowExpression expression, Assignments assignments, TypeProvider types)
{
Function<VariableReferenceExpression, Expression> mapping = variable -> {
if (assignments.get(variable) == null) {
return new SymbolReference(variable.getName());
}
return castToExpression(assignments.get(variable));
};

return ExpressionVariableInliner.inlineVariables(mapping, expression, types);
if (isExpression(expression)) {
Function<VariableReferenceExpression, Expression> mapping = variable -> {
if (assignments.get(variable) == null) {
return new SymbolReference(variable.getName());
}
return castToExpression(assignments.get(variable));
};
return castToRowExpression(ExpressionVariableInliner.inlineVariables(mapping, castToExpression(expression), types));
}
return RowExpressionVariableInliner.inlineVariables(variable -> assignments.getMap().getOrDefault(variable, variable), expression);
}

private Sets.SetView<VariableReferenceExpression> extractInliningTargets(ProjectNode parent, ProjectNode child, Context context)
Expand All @@ -141,25 +167,25 @@ private Sets.SetView<VariableReferenceExpression> extractInliningTargets(Project
// which come from the child, as opposed to an enclosing scope.

Set<VariableReferenceExpression> childOutputSet = ImmutableSet.copyOf(child.getOutputVariables());
TypeProvider types = context.getVariableAllocator().getTypes();

Map<VariableReferenceExpression, Long> dependencies = parent.getAssignments()
.getExpressions()
.stream()
.map(OriginalExpressionUtils::castToExpression)
.flatMap(expression -> VariablesExtractor.extractAll(expression, context.getVariableAllocator().getTypes()).stream())
.flatMap(expression -> extractDependencies(expression, context.getVariableAllocator().getTypes()).stream())
.filter(childOutputSet::contains)
.collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));

// find references to simple constants
Set<VariableReferenceExpression> constants = dependencies.keySet().stream()
.filter(input -> castToExpression(child.getAssignments().get(input)) instanceof Literal)
.filter(input -> isConstant(child.getAssignments().get(input)))
.collect(toSet());

// exclude any complex inputs to TRY expressions. Inlining them would potentially
// change the semantics of those expressions
Set<VariableReferenceExpression> tryArguments = parent.getAssignments()
.getExpressions().stream()
.flatMap(expression -> extractTryArguments(castToExpression(expression), context.getVariableAllocator().getTypes()).stream())
.flatMap(expression -> extractTryArguments(expression, types).stream())
.collect(toSet());

Set<VariableReferenceExpression> singletons = dependencies.entrySet().stream()
Expand All @@ -172,12 +198,43 @@ private Sets.SetView<VariableReferenceExpression> extractInliningTargets(Project
return Sets.union(singletons, constants);
}

private Set<VariableReferenceExpression> extractTryArguments(Expression expression, TypeProvider types)
private Set<VariableReferenceExpression> extractTryArguments(RowExpression expression, TypeProvider types)
{
return AstUtils.preOrder(expression)
.filter(TryExpression.class::isInstance)
.map(TryExpression.class::cast)
.flatMap(tryExpression -> VariablesExtractor.extractAll(tryExpression, types).stream())
.collect(toSet());
if (isExpression(expression)) {
return AstUtils.preOrder(castToExpression(expression))
.filter(TryExpression.class::isInstance)
.map(TryExpression.class::cast)
.flatMap(tryExpression -> VariablesExtractor.extractAll(tryExpression, types).stream())
.collect(toSet());
}
ImmutableSet.Builder<VariableReferenceExpression> builder = ImmutableSet.builder();
expression.accept(new DefaultRowExpressionTraversalVisitor<ImmutableSet.Builder<VariableReferenceExpression>>()
{
@Override
public Void visitCall(CallExpression call, ImmutableSet.Builder<VariableReferenceExpression> context)
{
if (functionResolution.isTryFunction(call.getFunctionHandle())) {
context.addAll(VariablesExtractor.extractAll(call));
}
return super.visitCall(call, context);
}
}, builder);
return builder.build();
}

private static List<VariableReferenceExpression> extractDependencies(RowExpression expression, TypeProvider types)
{
if (isExpression(expression)) {
return VariablesExtractor.extractAll(castToExpression(expression), types);
}
return VariablesExtractor.extractAll(expression);
}

private static boolean isConstant(RowExpression expression)
{
if (isExpression(expression)) {
return castToExpression(expression) instanceof Literal;
}
return expression instanceof ConstantExpression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonMap;

Expand Down Expand Up @@ -68,8 +69,12 @@ public static Assignments identityAssignments(Collection<VariableReferenceExpres
public static boolean isIdentity(Assignments assignments, VariableReferenceExpression output)
{
//TODO this will be checking against VariableExpression once getOutput returns VariableReferenceExpression
Expression expression = castToExpression(assignments.get(output));
return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName());
RowExpression value = assignments.get(output);
if (isExpression(value)) {
Expression expression = castToExpression(value);
return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName());
}
return value instanceof VariableReferenceExpression && ((VariableReferenceExpression) value).getName().equals(output.getName());
}

@Deprecated
Expand Down
Loading

0 comments on commit 0f2b3dd

Please sign in to comment.