Skip to content

Commit

Permalink
Add join type and filter to UnnestNode
Browse files Browse the repository at this point in the history
This is a preparatory step for supporting LEFT, RIGHT, FULL
and INNER JOIN involving  UNNEST with non-trivial join conditions.
  • Loading branch information
kasiafi authored and martint committed Sep 28, 2019
1 parent 8638a8c commit b541da6
Show file tree
Hide file tree
Showing 16 changed files with 184 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.planner.plan.TopNNode;
import io.prestosql.sql.planner.plan.UnionNode;
import io.prestosql.sql.planner.plan.UnnestNode;
import io.prestosql.sql.planner.plan.ValuesNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.tree.ComparisonExpression;
Expand Down Expand Up @@ -240,6 +241,25 @@ public Expression visitUnion(UnionNode node, Void context)
return deriveCommonPredicates(node, source -> node.outputSymbolMap(source).entries());
}

@Override
public Expression visitUnnest(UnnestNode node, Void context)
{
Expression sourcePredicate = node.getSource().accept(this, context);

switch (node.getJoinType()) {
case INNER:
case LEFT:
return pullExpressionThroughSymbols(
combineConjuncts(node.getFilter().orElse(TRUE_LITERAL), sourcePredicate),
node.getOutputSymbols());
case RIGHT:
case FULL:
return TRUE_LITERAL;
default:
throw new UnsupportedOperationException("Unknown UNNEST join type: " + node.getJoinType());
}
}

@Override
public Expression visitJoin(JoinNode node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.UnnestNode;
import io.prestosql.sql.planner.plan.ValuesNode;
import io.prestosql.sql.tree.Expression;

Expand Down Expand Up @@ -124,6 +125,13 @@ public Void visitJoin(JoinNode node, Void context)
return super.visitJoin(node, context);
}

@Override
public Void visitUnnest(UnnestNode node, Void context)
{
node.getFilter().ifPresent(consumer);
return super.visitUnnest(node, context);
}

@Override
public Void visitValues(ValuesNode node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@
import static io.prestosql.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static io.prestosql.sql.planner.plan.JoinNode.Type.FULL;
import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER;
import static io.prestosql.sql.planner.plan.JoinNode.Type.LEFT;
import static io.prestosql.sql.planner.plan.JoinNode.Type.RIGHT;
import static io.prestosql.sql.planner.plan.TableWriterNode.CreateTarget;
import static io.prestosql.sql.planner.plan.TableWriterNode.InsertTarget;
Expand Down Expand Up @@ -1361,6 +1362,7 @@ public PhysicalOperation visitUnnest(UnnestNode node, LocalExecutionPlanContext
outputMappings.put(ordinalitySymbol.get(), channel);
channel++;
}
boolean outer = node.getJoinType() == LEFT || node.getJoinType() == FULL;
OperatorFactory operatorFactory = new UnnestOperatorFactory(
context.getNextOperatorId(),
node.getId(),
Expand All @@ -1369,7 +1371,7 @@ public PhysicalOperation visitUnnest(UnnestNode node, LocalExecutionPlanContext
unnestChannels,
unnestTypes.build(),
ordinalityType.isPresent(),
node.isOuter());
outer);
return new PhysicalOperation(operatorFactory, outputMappings.build(), context, source);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@
import static io.prestosql.sql.tree.Join.Type.CROSS;
import static io.prestosql.sql.tree.Join.Type.IMPLICIT;
import static io.prestosql.sql.tree.Join.Type.INNER;
import static io.prestosql.sql.tree.Join.Type.RIGHT;
import static java.util.Objects.requireNonNull;

class RelationPlanner
Expand Down Expand Up @@ -224,15 +223,11 @@ protected RelationPlan visitJoin(Join node, Void context)
Optional<Unnest> unnest = getUnnest(node.getRight());
if (unnest.isPresent()) {
if (node.getType() == CROSS || node.getType() == IMPLICIT) {
return planJoinUnnest(leftPlan, node, unnest.get(), false);
return planJoinUnnest(leftPlan, node, unnest.get());
}
checkState(node.getCriteria().isPresent(), "missing Join criteria");
if (node.getCriteria().get() instanceof JoinOn && ((JoinOn) node.getCriteria().get()).getExpression().equals(TRUE_LITERAL)) {
if (node.getType() == RIGHT || node.getType() == INNER) {
return planJoinUnnest(leftPlan, node, unnest.get(), false);
}
// LEFT or FULL join
return planJoinUnnest(leftPlan, node, unnest.get(), true);
return planJoinUnnest(leftPlan, node, unnest.get());
}
throw notSupportedException(unnest.get(), "UNNEST in conditional JOIN");
}
Expand Down Expand Up @@ -610,7 +605,7 @@ private static boolean isEqualComparisonExpression(Expression conjunct)
return conjunct instanceof ComparisonExpression && ((ComparisonExpression) conjunct).getOperator() == ComparisonExpression.Operator.EQUAL;
}

private RelationPlan planJoinUnnest(RelationPlan leftPlan, Join joinNode, Unnest node, boolean outer)
private RelationPlan planJoinUnnest(RelationPlan leftPlan, Join joinNode, Unnest node)
{
RelationType unnestOutputDescriptor = analysis.getOutputDescriptor(node);
// Create symbols for the result of unnesting
Expand Down Expand Up @@ -655,7 +650,33 @@ else if (type instanceof MapType) {
Optional<Symbol> ordinalitySymbol = node.isWithOrdinality() ? Optional.of(unnestedSymbolsIterator.next()) : Optional.empty();
checkState(!unnestedSymbolsIterator.hasNext(), "Not all output symbols were matched with input symbols");

UnnestNode unnestNode = new UnnestNode(idAllocator.getNextId(), projectNode, leftPlan.getFieldMappings(), unnestSymbols.build(), ordinalitySymbol, outer);
Optional<Expression> filterExpression = Optional.empty();
if (joinNode.getCriteria().isPresent()) {
JoinCriteria criteria = joinNode.getCriteria().get();
if (criteria instanceof NaturalJoin) {
throw notSupportedException(joinNode, "Natural join involving UNNEST not supported");
}
if (criteria instanceof JoinUsing) {
throw notSupportedException(joinNode, "USING not supported for join involving UNNEST");
}
Expression filter = (Expression) getOnlyElement(criteria.getNodes());
if (filter.equals(TRUE_LITERAL)) {
filterExpression = Optional.of(filter);
}
else { //TODO rewrite filter to support non-trivial join criteria
throw notSupportedException(joinNode, "JOIN involving UNNEST on condition other than TRUE");
}
}

UnnestNode unnestNode = new UnnestNode(
idAllocator.getNextId(),
projectNode,
leftPlan.getFieldMappings(),
unnestSymbols.build(),
ordinalitySymbol,
JoinNode.Type.typeConvert(joinNode.getType()),
filterExpression);

return new RelationPlan(unnestNode, analysis.getScope(joinNode), unnestNode.getOutputSymbols());
}

Expand Down Expand Up @@ -757,7 +778,7 @@ else if (type instanceof MapType) {
checkState(!unnestedSymbolsIterator.hasNext(), "Not all output symbols were matched with input symbols");
ValuesNode valuesNode = new ValuesNode(idAllocator.getNextId(), argumentSymbols.build(), ImmutableList.of(values.build()));

UnnestNode unnestNode = new UnnestNode(idAllocator.getNextId(), valuesNode, ImmutableList.of(), unnestSymbols.build(), ordinalitySymbol, false);
UnnestNode unnestNode = new UnnestNode(idAllocator.getNextId(), valuesNode, ImmutableList.of(), unnestSymbols.build(), ordinalitySymbol, JoinNode.Type.INNER, Optional.empty());
return new RelationPlan(unnestNode, scope, unnestedSymbols);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.SimplePlanRewriter;
import io.prestosql.sql.planner.plan.UnnestNode;
import io.prestosql.sql.planner.plan.ValuesNode;
import io.prestosql.sql.tree.DefaultExpressionTraversalVisitor;
import io.prestosql.sql.tree.DereferenceExpression;
Expand Down Expand Up @@ -582,6 +583,20 @@ public PlanNode visitFilter(FilterNode node, RewriteContext<Void> context)
return new FilterNode(node.getId(), rewrittenNode.getSource(), replaceExpression(rewrittenNode.getPredicate(), mapping));
}

@Override
public PlanNode visitUnnest(UnnestNode node, RewriteContext<Void> context)
{
UnnestNode rewrittenNode = (UnnestNode) context.defaultRewrite(node);
return new UnnestNode(
node.getId(),
rewrittenNode.getSource(),
rewrittenNode.getReplicateSymbols(),
rewrittenNode.getUnnestSymbols(),
rewrittenNode.getOrdinalitySymbol(),
rewrittenNode.getJoinType(),
rewrittenNode.getFilter().map(expression -> replaceExpression(expression, mapping)));
}

@Override
public PlanNode visitValues(ValuesNode node, RewriteContext<Void> context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,8 @@ private static PlanNode addPartitioningNodes(Metadata metadata, Context context,
node.getOutputSymbols(),
ImmutableMap.of(partitionsSymbol, ImmutableList.of(partitionSymbol)),
Optional.empty(),
false);
INNER,
Optional.empty());
}

private static boolean containsNone(Collection<Symbol> values, Collection<Symbol> testValues)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,8 @@ public PlanWithProperties visitUnnest(UnnestNode node, HashComputationSet parent
.build(),
node.getUnnestSymbols(),
node.getOrdinalitySymbol(),
node.isOuter()),
node.getJoinType(),
node.getFilter()),
hashSymbols);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1290,7 +1290,11 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext<Expression
public PlanNode visitUnnest(UnnestNode node, RewriteContext<Expression> context)
{
Expression inheritedPredicate = context.get();
if (node.getJoinType() == RIGHT || node.getJoinType() == FULL) {
return new FilterNode(idAllocator.getNextId(), node, inheritedPredicate);
}

//TODO for LEFT or INNER join type, push down UnnestNode's filter on replicate symbols
EqualityInference equalityInference = EqualityInference.newInstance(inheritedPredicate);

List<Expression> pushdownConjuncts = new ArrayList<>();
Expand Down Expand Up @@ -1324,7 +1328,7 @@ public PlanNode visitUnnest(UnnestNode node, RewriteContext<Expression> context)

PlanNode output = node;
if (rewrittenSource != node.getSource()) {
output = new UnnestNode(node.getId(), rewrittenSource, node.getReplicateSymbols(), node.getUnnestSymbols(), node.getOrdinalitySymbol(), node.isOuter());
output = new UnnestNode(node.getId(), rewrittenSource, node.getReplicateSymbols(), node.getUnnestSymbols(), node.getOrdinalitySymbol(), node.getJoinType(), node.getFilter());
}
if (!postUnnestConjuncts.isEmpty()) {
output = new FilterNode(idAllocator.getNextId(), output, combineConjuncts(postUnnestConjuncts));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -682,12 +682,25 @@ public ActualProperties visitUnnest(UnnestNode node, List<ActualProperties> inpu
{
Set<Symbol> passThroughInputs = ImmutableSet.copyOf(node.getReplicateSymbols());

return Iterables.getOnlyElement(inputProperties).translate(column -> {
ActualProperties translatedProperties = Iterables.getOnlyElement(inputProperties).translate(column -> {
if (passThroughInputs.contains(column)) {
return Optional.of(column);
}
return Optional.empty();
});

switch (node.getJoinType()) {
case INNER:
case LEFT:
return translatedProperties;
case RIGHT:
case FULL:
return ActualProperties.builderFrom(translatedProperties)
.local(ImmutableList.of())
.build();
default:
throw new UnsupportedOperationException("Unknown UNNEST join type: " + node.getJoinType());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -524,9 +524,15 @@ public PlanNode visitUnnest(UnnestNode node, RewriteContext<Set<Symbol>> context
ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
.addAll(replicateSymbols)
.addAll(unnestSymbols.keySet());
ImmutableSet.Builder<Symbol> unnestedSymbols = ImmutableSet.builder();
for (List<Symbol> symbols : unnestSymbols.values()) {
unnestedSymbols.addAll(symbols);
}
Set<Symbol> expectedFilterSymbols = Sets.difference(SymbolsExtractor.extractUnique(node.getFilter().orElse(TRUE_LITERAL)), unnestedSymbols.build());
expectedInputs.addAll(expectedFilterSymbols);

PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
return new UnnestNode(node.getId(), source, replicateSymbols, unnestSymbols, ordinalitySymbol, node.isOuter());
return new UnnestNode(node.getId(), source, replicateSymbols, unnestSymbols, ordinalitySymbol, node.getJoinType(), node.getFilter());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -430,12 +430,23 @@ public StreamProperties visitUnnest(UnnestNode node, List<StreamProperties> inpu

// We can describe properties in terms of inputs that are projected unmodified (i.e., not the unnested symbols)
Set<Symbol> passThroughInputs = ImmutableSet.copyOf(node.getReplicateSymbols());
return properties.translate(column -> {
StreamProperties translatedProperties = properties.translate(column -> {
if (passThroughInputs.contains(column)) {
return Optional.of(column);
}
return Optional.empty();
});

switch (node.getJoinType()) {
case INNER:
case LEFT:
return translatedProperties;
case RIGHT:
case FULL:
return translatedProperties.unordered(true);
default:
throw new UnsupportedOperationException("Unknown UNNEST join type: " + node.getJoinType());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,14 @@ public PlanNode visitUnnest(UnnestNode node, RewriteContext<Void> context)
for (Map.Entry<Symbol, List<Symbol>> entry : node.getUnnestSymbols().entrySet()) {
builder.put(canonicalize(entry.getKey()), entry.getValue());
}
return new UnnestNode(node.getId(), source, canonicalizeAndDistinct(node.getReplicateSymbols()), builder.build(), node.getOrdinalitySymbol(), node.isOuter());
return new UnnestNode(
node.getId(),
source,
canonicalizeAndDistinct(node.getReplicateSymbols()),
builder.build(),
node.getOrdinalitySymbol(),
node.getJoinType(),
node.getFilter().map(this::canonicalize));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.plan.JoinNode.Type;
import io.prestosql.sql.tree.Expression;

import javax.annotation.concurrent.Immutable;

Expand All @@ -37,7 +39,8 @@ public class UnnestNode
private final List<Symbol> replicateSymbols;
private final Map<Symbol, List<Symbol>> unnestSymbols;
private final Optional<Symbol> ordinalitySymbol;
private final boolean outer;
private final Type joinType;
private final Optional<Expression> filter;

@JsonCreator
public UnnestNode(
Expand All @@ -46,7 +49,8 @@ public UnnestNode(
@JsonProperty("replicateSymbols") List<Symbol> replicateSymbols,
@JsonProperty("unnestSymbols") Map<Symbol, List<Symbol>> unnestSymbols,
@JsonProperty("ordinalitySymbol") Optional<Symbol> ordinalitySymbol,
@JsonProperty("outer") boolean outer)
@JsonProperty("joinType") Type joinType,
@JsonProperty("filter") Optional<Expression> filter)
{
super(id);
this.source = requireNonNull(source, "source is null");
Expand All @@ -60,7 +64,8 @@ public UnnestNode(
}
this.unnestSymbols = builder.build();
this.ordinalitySymbol = requireNonNull(ordinalitySymbol, "ordinalitySymbol is null");
this.outer = outer;
this.joinType = requireNonNull(joinType, "type is null");
this.filter = requireNonNull(filter, "filter is null");
}

@Override
Expand Down Expand Up @@ -98,9 +103,15 @@ public Optional<Symbol> getOrdinalitySymbol()
}

@JsonProperty
public boolean isOuter()
public Type getJoinType()
{
return outer;
return joinType;
}

@JsonProperty
public Optional<Expression> getFilter()
{
return filter;
}

@Override
Expand All @@ -118,6 +129,6 @@ public <R, C> R accept(PlanVisitor<R, C> visitor, C context)
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
return new UnnestNode(getId(), Iterables.getOnlyElement(newChildren), replicateSymbols, unnestSymbols, ordinalitySymbol, outer);
return new UnnestNode(getId(), Iterables.getOnlyElement(newChildren), replicateSymbols, unnestSymbols, ordinalitySymbol, joinType, filter);
}
}
Loading

0 comments on commit b541da6

Please sign in to comment.