From b541da6a277ed4eadc6bd300e1ee91dd12af1df4 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Fri, 27 Sep 2019 13:21:59 +0200 Subject: [PATCH] Add join type and filter to UnnestNode This is a preparatory step for supporting LEFT, RIGHT, FULL and INNER JOIN involving UNNEST with non-trivial join conditions. --- .../planner/EffectivePredicateExtractor.java | 20 +++++++++ .../sql/planner/ExpressionExtractor.java | 8 ++++ .../sql/planner/LocalExecutionPlanner.java | 4 +- .../sql/planner/RelationPlanner.java | 41 ++++++++++++++----- .../sql/planner/SubqueryPlanner.java | 15 +++++++ .../iterative/rule/ExtractSpatialJoins.java | 3 +- .../HashGenerationOptimizer.java | 3 +- .../optimizations/PredicatePushDown.java | 6 ++- .../optimizations/PropertyDerivations.java | 15 ++++++- .../PruneUnreferencedOutputs.java | 8 +++- .../StreamPropertyDerivations.java | 13 +++++- .../UnaliasSymbolReferences.java | 9 +++- .../sql/planner/plan/UnnestNode.java | 23 ++++++++--- .../sql/planner/planprinter/PlanPrinter.java | 18 ++++++-- .../sanity/ValidateDependenciesChecker.java | 16 +++++--- .../io/prestosql/util/GraphvizPrinter.java | 17 ++++++-- 16 files changed, 184 insertions(+), 35 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/EffectivePredicateExtractor.java b/presto-main/src/main/java/io/prestosql/sql/planner/EffectivePredicateExtractor.java index 526ef79c6bf7..03cf2ab26c97 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/EffectivePredicateExtractor.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/EffectivePredicateExtractor.java @@ -40,6 +40,7 @@ import io.prestosql.sql.planner.plan.TableScanNode; import io.prestosql.sql.planner.plan.TopNNode; import io.prestosql.sql.planner.plan.UnionNode; +import io.prestosql.sql.planner.plan.UnnestNode; import io.prestosql.sql.planner.plan.ValuesNode; import io.prestosql.sql.planner.plan.WindowNode; import io.prestosql.sql.tree.ComparisonExpression; @@ -240,6 +241,25 @@ public Expression visitUnion(UnionNode node, Void context) return deriveCommonPredicates(node, source -> node.outputSymbolMap(source).entries()); } + @Override + public Expression visitUnnest(UnnestNode node, Void context) + { + Expression sourcePredicate = node.getSource().accept(this, context); + + switch (node.getJoinType()) { + case INNER: + case LEFT: + return pullExpressionThroughSymbols( + combineConjuncts(node.getFilter().orElse(TRUE_LITERAL), sourcePredicate), + node.getOutputSymbols()); + case RIGHT: + case FULL: + return TRUE_LITERAL; + default: + throw new UnsupportedOperationException("Unknown UNNEST join type: " + node.getJoinType()); + } + } + @Override public Expression visitJoin(JoinNode node, Void context) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionExtractor.java b/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionExtractor.java index 1f4786603df9..bbf6ee85c41f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionExtractor.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/ExpressionExtractor.java @@ -23,6 +23,7 @@ import io.prestosql.sql.planner.plan.JoinNode; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.UnnestNode; import io.prestosql.sql.planner.plan.ValuesNode; import io.prestosql.sql.tree.Expression; @@ -124,6 +125,13 @@ public Void visitJoin(JoinNode node, Void context) return super.visitJoin(node, context); } + @Override + public Void visitUnnest(UnnestNode node, Void context) + { + node.getFilter().ifPresent(consumer); + return super.visitUnnest(node, context); + } + @Override public Void visitValues(ValuesNode node, Void context) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java index 375edfeb1b2e..9d5efb753e20 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java @@ -261,6 +261,7 @@ import static io.prestosql.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static io.prestosql.sql.planner.plan.JoinNode.Type.FULL; import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; +import static io.prestosql.sql.planner.plan.JoinNode.Type.LEFT; import static io.prestosql.sql.planner.plan.JoinNode.Type.RIGHT; import static io.prestosql.sql.planner.plan.TableWriterNode.CreateTarget; import static io.prestosql.sql.planner.plan.TableWriterNode.InsertTarget; @@ -1361,6 +1362,7 @@ public PhysicalOperation visitUnnest(UnnestNode node, LocalExecutionPlanContext outputMappings.put(ordinalitySymbol.get(), channel); channel++; } + boolean outer = node.getJoinType() == LEFT || node.getJoinType() == FULL; OperatorFactory operatorFactory = new UnnestOperatorFactory( context.getNextOperatorId(), node.getId(), @@ -1369,7 +1371,7 @@ public PhysicalOperation visitUnnest(UnnestNode node, LocalExecutionPlanContext unnestChannels, unnestTypes.build(), ordinalityType.isPresent(), - node.isOuter()); + outer); return new PhysicalOperation(operatorFactory, outputMappings.build(), context, source); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java index eafb67809a08..ec457c39541d 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java @@ -101,7 +101,6 @@ import static io.prestosql.sql.tree.Join.Type.CROSS; import static io.prestosql.sql.tree.Join.Type.IMPLICIT; import static io.prestosql.sql.tree.Join.Type.INNER; -import static io.prestosql.sql.tree.Join.Type.RIGHT; import static java.util.Objects.requireNonNull; class RelationPlanner @@ -224,15 +223,11 @@ protected RelationPlan visitJoin(Join node, Void context) Optional unnest = getUnnest(node.getRight()); if (unnest.isPresent()) { if (node.getType() == CROSS || node.getType() == IMPLICIT) { - return planJoinUnnest(leftPlan, node, unnest.get(), false); + return planJoinUnnest(leftPlan, node, unnest.get()); } checkState(node.getCriteria().isPresent(), "missing Join criteria"); if (node.getCriteria().get() instanceof JoinOn && ((JoinOn) node.getCriteria().get()).getExpression().equals(TRUE_LITERAL)) { - if (node.getType() == RIGHT || node.getType() == INNER) { - return planJoinUnnest(leftPlan, node, unnest.get(), false); - } - // LEFT or FULL join - return planJoinUnnest(leftPlan, node, unnest.get(), true); + return planJoinUnnest(leftPlan, node, unnest.get()); } throw notSupportedException(unnest.get(), "UNNEST in conditional JOIN"); } @@ -610,7 +605,7 @@ private static boolean isEqualComparisonExpression(Expression conjunct) return conjunct instanceof ComparisonExpression && ((ComparisonExpression) conjunct).getOperator() == ComparisonExpression.Operator.EQUAL; } - private RelationPlan planJoinUnnest(RelationPlan leftPlan, Join joinNode, Unnest node, boolean outer) + private RelationPlan planJoinUnnest(RelationPlan leftPlan, Join joinNode, Unnest node) { RelationType unnestOutputDescriptor = analysis.getOutputDescriptor(node); // Create symbols for the result of unnesting @@ -655,7 +650,33 @@ else if (type instanceof MapType) { Optional ordinalitySymbol = node.isWithOrdinality() ? Optional.of(unnestedSymbolsIterator.next()) : Optional.empty(); checkState(!unnestedSymbolsIterator.hasNext(), "Not all output symbols were matched with input symbols"); - UnnestNode unnestNode = new UnnestNode(idAllocator.getNextId(), projectNode, leftPlan.getFieldMappings(), unnestSymbols.build(), ordinalitySymbol, outer); + Optional filterExpression = Optional.empty(); + if (joinNode.getCriteria().isPresent()) { + JoinCriteria criteria = joinNode.getCriteria().get(); + if (criteria instanceof NaturalJoin) { + throw notSupportedException(joinNode, "Natural join involving UNNEST not supported"); + } + if (criteria instanceof JoinUsing) { + throw notSupportedException(joinNode, "USING not supported for join involving UNNEST"); + } + Expression filter = (Expression) getOnlyElement(criteria.getNodes()); + if (filter.equals(TRUE_LITERAL)) { + filterExpression = Optional.of(filter); + } + else { //TODO rewrite filter to support non-trivial join criteria + throw notSupportedException(joinNode, "JOIN involving UNNEST on condition other than TRUE"); + } + } + + UnnestNode unnestNode = new UnnestNode( + idAllocator.getNextId(), + projectNode, + leftPlan.getFieldMappings(), + unnestSymbols.build(), + ordinalitySymbol, + JoinNode.Type.typeConvert(joinNode.getType()), + filterExpression); + return new RelationPlan(unnestNode, analysis.getScope(joinNode), unnestNode.getOutputSymbols()); } @@ -757,7 +778,7 @@ else if (type instanceof MapType) { checkState(!unnestedSymbolsIterator.hasNext(), "Not all output symbols were matched with input symbols"); ValuesNode valuesNode = new ValuesNode(idAllocator.getNextId(), argumentSymbols.build(), ImmutableList.of(values.build())); - UnnestNode unnestNode = new UnnestNode(idAllocator.getNextId(), valuesNode, ImmutableList.of(), unnestSymbols.build(), ordinalitySymbol, false); + UnnestNode unnestNode = new UnnestNode(idAllocator.getNextId(), valuesNode, ImmutableList.of(), unnestSymbols.build(), ordinalitySymbol, JoinNode.Type.INNER, Optional.empty()); return new RelationPlan(unnestNode, scope, unnestedSymbols); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java index ac9d681f5e6f..a77d571c7df0 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java @@ -28,6 +28,7 @@ import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.ProjectNode; import io.prestosql.sql.planner.plan.SimplePlanRewriter; +import io.prestosql.sql.planner.plan.UnnestNode; import io.prestosql.sql.planner.plan.ValuesNode; import io.prestosql.sql.tree.DefaultExpressionTraversalVisitor; import io.prestosql.sql.tree.DereferenceExpression; @@ -582,6 +583,20 @@ public PlanNode visitFilter(FilterNode node, RewriteContext context) return new FilterNode(node.getId(), rewrittenNode.getSource(), replaceExpression(rewrittenNode.getPredicate(), mapping)); } + @Override + public PlanNode visitUnnest(UnnestNode node, RewriteContext context) + { + UnnestNode rewrittenNode = (UnnestNode) context.defaultRewrite(node); + return new UnnestNode( + node.getId(), + rewrittenNode.getSource(), + rewrittenNode.getReplicateSymbols(), + rewrittenNode.getUnnestSymbols(), + rewrittenNode.getOrdinalitySymbol(), + rewrittenNode.getJoinType(), + rewrittenNode.getFilter().map(expression -> replaceExpression(expression, mapping))); + } + @Override public PlanNode visitValues(ValuesNode node, RewriteContext context) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java index dda1bc74f477..f9a9bdcf6a53 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -605,7 +605,8 @@ private static PlanNode addPartitioningNodes(Metadata metadata, Context context, node.getOutputSymbols(), ImmutableMap.of(partitionsSymbol, ImmutableList.of(partitionSymbol)), Optional.empty(), - false); + INNER, + Optional.empty()); } private static boolean containsNone(Collection values, Collection testValues) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/HashGenerationOptimizer.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/HashGenerationOptimizer.java index e57a1531ba17..aa4e4233777e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/HashGenerationOptimizer.java @@ -662,7 +662,8 @@ public PlanWithProperties visitUnnest(UnnestNode node, HashComputationSet parent .build(), node.getUnnestSymbols(), node.getOrdinalitySymbol(), - node.isOuter()), + node.getJoinType(), + node.getFilter()), hashSymbols); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java index 606a79e224ab..0d709345c8a4 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PredicatePushDown.java @@ -1290,7 +1290,11 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext context) { Expression inheritedPredicate = context.get(); + if (node.getJoinType() == RIGHT || node.getJoinType() == FULL) { + return new FilterNode(idAllocator.getNextId(), node, inheritedPredicate); + } + //TODO for LEFT or INNER join type, push down UnnestNode's filter on replicate symbols EqualityInference equalityInference = EqualityInference.newInstance(inheritedPredicate); List pushdownConjuncts = new ArrayList<>(); @@ -1324,7 +1328,7 @@ public PlanNode visitUnnest(UnnestNode node, RewriteContext context) PlanNode output = node; if (rewrittenSource != node.getSource()) { - output = new UnnestNode(node.getId(), rewrittenSource, node.getReplicateSymbols(), node.getUnnestSymbols(), node.getOrdinalitySymbol(), node.isOuter()); + output = new UnnestNode(node.getId(), rewrittenSource, node.getReplicateSymbols(), node.getUnnestSymbols(), node.getOrdinalitySymbol(), node.getJoinType(), node.getFilter()); } if (!postUnnestConjuncts.isEmpty()) { output = new FilterNode(idAllocator.getNextId(), output, combineConjuncts(postUnnestConjuncts)); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java index 971b456b2f6a..5698842736b8 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PropertyDerivations.java @@ -682,12 +682,25 @@ public ActualProperties visitUnnest(UnnestNode node, List inpu { Set passThroughInputs = ImmutableSet.copyOf(node.getReplicateSymbols()); - return Iterables.getOnlyElement(inputProperties).translate(column -> { + ActualProperties translatedProperties = Iterables.getOnlyElement(inputProperties).translate(column -> { if (passThroughInputs.contains(column)) { return Optional.of(column); } return Optional.empty(); }); + + switch (node.getJoinType()) { + case INNER: + case LEFT: + return translatedProperties; + case RIGHT: + case FULL: + return ActualProperties.builderFrom(translatedProperties) + .local(ImmutableList.of()) + .build(); + default: + throw new UnsupportedOperationException("Unknown UNNEST join type: " + node.getJoinType()); + } } @Override diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java index 5bd1b3fb4d16..51ded7dcfcae 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -524,9 +524,15 @@ public PlanNode visitUnnest(UnnestNode node, RewriteContext> context ImmutableSet.Builder expectedInputs = ImmutableSet.builder() .addAll(replicateSymbols) .addAll(unnestSymbols.keySet()); + ImmutableSet.Builder unnestedSymbols = ImmutableSet.builder(); + for (List symbols : unnestSymbols.values()) { + unnestedSymbols.addAll(symbols); + } + Set expectedFilterSymbols = Sets.difference(SymbolsExtractor.extractUnique(node.getFilter().orElse(TRUE_LITERAL)), unnestedSymbols.build()); + expectedInputs.addAll(expectedFilterSymbols); PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); - return new UnnestNode(node.getId(), source, replicateSymbols, unnestSymbols, ordinalitySymbol, node.isOuter()); + return new UnnestNode(node.getId(), source, replicateSymbols, unnestSymbols, ordinalitySymbol, node.getJoinType(), node.getFilter()); } @Override diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java index 2cee31bfddd9..99a0cd7cee1f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/StreamPropertyDerivations.java @@ -430,12 +430,23 @@ public StreamProperties visitUnnest(UnnestNode node, List inpu // We can describe properties in terms of inputs that are projected unmodified (i.e., not the unnested symbols) Set passThroughInputs = ImmutableSet.copyOf(node.getReplicateSymbols()); - return properties.translate(column -> { + StreamProperties translatedProperties = properties.translate(column -> { if (passThroughInputs.contains(column)) { return Optional.of(column); } return Optional.empty(); }); + + switch (node.getJoinType()) { + case INNER: + case LEFT: + return translatedProperties; + case RIGHT: + case FULL: + return translatedProperties.unordered(true); + default: + throw new UnsupportedOperationException("Unknown UNNEST join type: " + node.getJoinType()); + } } @Override diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java index 317751c6eb4d..fb1f3a26090b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -183,7 +183,14 @@ public PlanNode visitUnnest(UnnestNode node, RewriteContext context) for (Map.Entry> entry : node.getUnnestSymbols().entrySet()) { builder.put(canonicalize(entry.getKey()), entry.getValue()); } - return new UnnestNode(node.getId(), source, canonicalizeAndDistinct(node.getReplicateSymbols()), builder.build(), node.getOrdinalitySymbol(), node.isOuter()); + return new UnnestNode( + node.getId(), + source, + canonicalizeAndDistinct(node.getReplicateSymbols()), + builder.build(), + node.getOrdinalitySymbol(), + node.getJoinType(), + node.getFilter().map(this::canonicalize)); } @Override diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/UnnestNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/UnnestNode.java index 6c62c7f5ea99..c77cfff18daf 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/UnnestNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/UnnestNode.java @@ -19,6 +19,8 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.plan.JoinNode.Type; +import io.prestosql.sql.tree.Expression; import javax.annotation.concurrent.Immutable; @@ -37,7 +39,8 @@ public class UnnestNode private final List replicateSymbols; private final Map> unnestSymbols; private final Optional ordinalitySymbol; - private final boolean outer; + private final Type joinType; + private final Optional filter; @JsonCreator public UnnestNode( @@ -46,7 +49,8 @@ public UnnestNode( @JsonProperty("replicateSymbols") List replicateSymbols, @JsonProperty("unnestSymbols") Map> unnestSymbols, @JsonProperty("ordinalitySymbol") Optional ordinalitySymbol, - @JsonProperty("outer") boolean outer) + @JsonProperty("joinType") Type joinType, + @JsonProperty("filter") Optional filter) { super(id); this.source = requireNonNull(source, "source is null"); @@ -60,7 +64,8 @@ public UnnestNode( } this.unnestSymbols = builder.build(); this.ordinalitySymbol = requireNonNull(ordinalitySymbol, "ordinalitySymbol is null"); - this.outer = outer; + this.joinType = requireNonNull(joinType, "type is null"); + this.filter = requireNonNull(filter, "filter is null"); } @Override @@ -98,9 +103,15 @@ public Optional getOrdinalitySymbol() } @JsonProperty - public boolean isOuter() + public Type getJoinType() { - return outer; + return joinType; + } + + @JsonProperty + public Optional getFilter() + { + return filter; } @Override @@ -118,6 +129,6 @@ public R accept(PlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new UnnestNode(getId(), Iterables.getOnlyElement(newChildren), replicateSymbols, unnestSymbols, ordinalitySymbol, outer); + return new UnnestNode(getId(), Iterables.getOnlyElement(newChildren), replicateSymbols, unnestSymbols, ordinalitySymbol, joinType, filter); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/planprinter/PlanPrinter.java b/presto-main/src/main/java/io/prestosql/sql/planner/planprinter/PlanPrinter.java index de57d0e802be..7f4352b7dad0 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/planprinter/PlanPrinter.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/planprinter/PlanPrinter.java @@ -851,9 +851,21 @@ private void printTableScanInfo(NodeRepresentation nodeOutput, TableScanNode nod @Override public Void visitUnnest(UnnestNode node, Void context) { - addNode(node, - "Unnest", - format("[replicate=%s, unnest=%s]", formatOutputs(types, node.getReplicateSymbols()), formatOutputs(types, node.getUnnestSymbols().keySet()))); + String name; + if (node.getFilter().isPresent()) { + name = node.getJoinType().getJoinLabel() + " Unnest"; + } + else if (!node.getReplicateSymbols().isEmpty()) { + name = "CrossJoin Unnest"; + } + else { + name = "Unnest"; + } + addNode( + node, + name, + format("[replicate=%s, unnest=%s", formatOutputs(types, node.getReplicateSymbols()), formatOutputs(types, node.getUnnestSymbols().keySet())) + + (node.getFilter().isPresent() ? format(", filter=%s]", node.getFilter().get().toString()) : "]")); return processChildren(node, context); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java index 751053680ec0..10f231f919ff 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; @@ -78,6 +79,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.prestosql.sql.planner.optimizations.IndexJoinOptimizer.IndexKeyTracer; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; /** * Ensures that all dependencies (i.e., symbols in expressions) for a plan node are provided by its source nodes @@ -484,12 +486,16 @@ public Void visitUnnest(UnnestNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); - Set required = ImmutableSet.builder() + ImmutableSet.Builder required = ImmutableSet.builder() .addAll(node.getReplicateSymbols()) - .addAll(node.getUnnestSymbols().keySet()) - .build(); - - checkDependencies(source.getOutputSymbols(), required, "Invalid node. Dependencies (%s) not in source plan output (%s)", required, source.getOutputSymbols()); + .addAll(node.getUnnestSymbols().keySet()); + ImmutableSet.Builder unnestedSymbols = ImmutableSet.builder(); + for (List symbols : node.getUnnestSymbols().values()) { + unnestedSymbols.addAll(symbols); + } + Set expectedFilterSymbols = Sets.difference(SymbolsExtractor.extractUnique(node.getFilter().orElse(TRUE_LITERAL)), unnestedSymbols.build()); + required.addAll(expectedFilterSymbols); + checkDependencies(source.getOutputSymbols(), required.build(), "Invalid node. Dependencies (%s) not in source plan output (%s)", required, source.getOutputSymbols()); return null; } diff --git a/presto-main/src/main/java/io/prestosql/util/GraphvizPrinter.java b/presto-main/src/main/java/io/prestosql/util/GraphvizPrinter.java index 16ba5bd5fe55..2058a68ef475 100644 --- a/presto-main/src/main/java/io/prestosql/util/GraphvizPrinter.java +++ b/presto-main/src/main/java/io/prestosql/util/GraphvizPrinter.java @@ -387,12 +387,23 @@ public Void visitProject(ProjectNode node, Void context) @Override public Void visitUnnest(UnnestNode node, Void context) { - if (!node.getOrdinalitySymbol().isPresent()) { - printNode(node, format("Unnest[%s]", node.getUnnestSymbols().keySet()), NODE_COLORS.get(NodeType.UNNEST)); + StringBuilder label = new StringBuilder(); + if (node.getFilter().isPresent()) { + label.append(node.getJoinType().getJoinLabel()) + .append(" Unnest"); + } + else if (!node.getReplicateSymbols().isEmpty()) { + label.append("CrossJoin Unnest"); } else { - printNode(node, format("Unnest[%s (ordinality)]", node.getUnnestSymbols().keySet()), NODE_COLORS.get(NodeType.UNNEST)); + label.append("Unnest"); } + label.append(format(" [%s", node.getUnnestSymbols().keySet())) + .append(node.getOrdinalitySymbol().isPresent() ? " (ordinality)]" : "]"); + + String details = node.getFilter().isPresent() ? " filter " + node.getFilter().get().toString() : ""; + + printNode(node, label.toString(), details, NODE_COLORS.get(NodeType.UNNEST)); return node.getSource().accept(this, context); }