Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for LEFT/RIGHT/FULL/INNER lateral join #390

Merged
merged 4 commits into from
Apr 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@
import io.prestosql.sql.tree.InPredicate;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it could be hard, but can divide this commit to first introduce the expression into LateralJoinNode and then we could have a commit which would add RIGHT and FULL to LateralJoinType.

import io.prestosql.sql.tree.Intersect;
import io.prestosql.sql.tree.Join;
import io.prestosql.sql.tree.JoinCriteria;
import io.prestosql.sql.tree.JoinUsing;
import io.prestosql.sql.tree.LambdaArgumentDeclaration;
import io.prestosql.sql.tree.Lateral;
import io.prestosql.sql.tree.NaturalJoin;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.Query;
Expand Down Expand Up @@ -89,8 +91,10 @@
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.prestosql.sql.analyzer.SemanticExceptions.notSupportedException;
import static io.prestosql.sql.planner.plan.AggregationNode.singleGroupingSet;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static io.prestosql.sql.tree.Join.Type.INNER;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -219,9 +223,6 @@ protected RelationPlan visitJoin(Join node, Void context)

Optional<Lateral> lateral = getLateral(node.getRight());
if (lateral.isPresent()) {
if (node.getType() != Join.Type.CROSS && node.getType() != Join.Type.IMPLICIT) {
throw notSupportedException(lateral.get(), "LATERAL on other than the right side of CROSS JOIN");
}
return planLateralJoin(node, leftPlan, lateral.get());
}

Expand Down Expand Up @@ -537,7 +538,47 @@ private RelationPlan planLateralJoin(Join join, RelationPlan leftPlan, Lateral l
PlanBuilder leftPlanBuilder = initializePlanBuilder(leftPlan);
PlanBuilder rightPlanBuilder = initializePlanBuilder(rightPlan);

PlanBuilder planBuilder = subqueryPlanner.appendLateralJoin(leftPlanBuilder, rightPlanBuilder, lateral.getQuery(), true, LateralJoinNode.Type.INNER);
Expression filterExpression;
if (!join.getCriteria().isPresent()) {
filterExpression = TRUE_LITERAL;
}
else {
JoinCriteria criteria = join.getCriteria().get();
if (criteria instanceof JoinUsing || criteria instanceof NaturalJoin) {
throw notSupportedException(join, "Lateral join with criteria other than ON");
}
filterExpression = (Expression) getOnlyElement(criteria.getNodes());
}

List<Symbol> rewriterOutputSymbols = ImmutableList.<Symbol>builder()
.addAll(leftPlan.getFieldMappings())
.addAll(rightPlan.getFieldMappings())
.build();

// this node is not used in the plan. It is only used for creating the TranslationMap.
PlanNode dummy = new ValuesNode(
idAllocator.getNextId(),
ImmutableList.<Symbol>builder()
.addAll(leftPlanBuilder.getRoot().getOutputSymbols())
.addAll(rightPlanBuilder.getRoot().getOutputSymbols())
.build(),
ImmutableList.of());

RelationPlan intermediateRelationPlan = new RelationPlan(dummy, analysis.getScope(join), rewriterOutputSymbols);
TranslationMap translationMap = new TranslationMap(intermediateRelationPlan, analysis, lambdaDeclarationToSymbolMap);
translationMap.setFieldMappings(rewriterOutputSymbols);
translationMap.putExpressionMappingsFrom(leftPlanBuilder.getTranslations());
translationMap.putExpressionMappingsFrom(rightPlanBuilder.getTranslations());

Expression rewrittenFilterCondition = translationMap.rewrite(filterExpression);

PlanBuilder planBuilder = subqueryPlanner.appendLateralJoin(
leftPlanBuilder,
rightPlanBuilder,
lateral.getQuery(),
true,
LateralJoinNode.Type.typeConvert(join.getType()),
rewrittenFilterCondition);

List<Symbol> outputSymbols = ImmutableList.<Symbol>builder()
.addAll(leftPlan.getRoot().getOutputSymbols())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.SimplePlanRewriter;
import io.prestosql.sql.planner.plan.ValuesNode;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.DefaultExpressionTraversalVisitor;
import io.prestosql.sql.tree.DereferenceExpression;
import io.prestosql.sql.tree.ExistsPredicate;
Expand Down Expand Up @@ -61,6 +60,7 @@
import static io.prestosql.sql.analyzer.SemanticExceptions.notSupportedException;
import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression;
import static io.prestosql.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static io.prestosql.sql.tree.ComparisonExpression.Operator.EQUAL;
import static io.prestosql.sql.util.AstUtils.nodeContains;
import static java.lang.String.format;
Expand Down Expand Up @@ -227,10 +227,10 @@ private PlanBuilder appendScalarSubqueryApplyNode(PlanBuilder subPlan, SubqueryE
}

// The subquery's EnforceSingleRowNode always produces a row, so the join is effectively INNER
return appendLateralJoin(subPlan, subqueryPlan, scalarSubquery.getQuery(), correlationAllowed, LateralJoinNode.Type.INNER);
return appendLateralJoin(subPlan, subqueryPlan, scalarSubquery.getQuery(), correlationAllowed, LateralJoinNode.Type.INNER, TRUE_LITERAL);
}

public PlanBuilder appendLateralJoin(PlanBuilder subPlan, PlanBuilder subqueryPlan, Query query, boolean correlationAllowed, LateralJoinNode.Type type)
public PlanBuilder appendLateralJoin(PlanBuilder subPlan, PlanBuilder subqueryPlan, Query query, boolean correlationAllowed, LateralJoinNode.Type type, Expression filterCondition)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Joins don't have "filter" conditions. Rather, it's a "join" condition or criteria. The distinction is subtle but important. A filter determines whether a row is preserved or removed, but the join condition determines whether the rows are considered as candidates to be joined. Depending on the type of join, rows that don't satisfy the join condition may be dropped or joined with a synthetic row containing nulls (e.g., for outer joins). So I'd rename this argument to just criteria.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm aware of the naming bias.
However, JoinNode has got both members: criteria and filter.
I chose the name filter for LateralJoinNode member, because it is actually the counterpart of JoinNode's filter and it becomes JoinNode's filter later in the course of planning.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. Ok, let's leave it for now. We can clean this up in the future for both LateralJoinNode and JoinNode

{
PlanNode subqueryNode = subqueryPlan.getRoot();
Map<Expression, Expression> correlation = extractCorrelation(subPlan, subqueryNode);
Expand All @@ -247,6 +247,7 @@ public PlanBuilder appendLateralJoin(PlanBuilder subPlan, PlanBuilder subqueryPl
subqueryNode,
ImmutableList.copyOf(SymbolsExtractor.extractUnique(correlation.values())),
type,
filterCondition,
query),
analysis.getParameters());
}
Expand Down Expand Up @@ -279,7 +280,7 @@ private PlanBuilder appendExistSubqueryApplyNode(PlanBuilder subPlan, ExistsPred

PlanNode subqueryPlanRoot = subqueryPlan.getRoot();
if (isAggregationWithEmptyGroupBy(subqueryPlanRoot)) {
subPlan.getTranslations().put(existsPredicate, BooleanLiteral.TRUE_LITERAL);
subPlan.getTranslations().put(existsPredicate, TRUE_LITERAL);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

separete commit

return subPlan;
}

Expand All @@ -288,7 +289,7 @@ private PlanBuilder appendExistSubqueryApplyNode(PlanBuilder subPlan, ExistsPred

Symbol exists = symbolAllocator.newSymbol("exists", BOOLEAN);
subPlan.getTranslations().put(existsPredicate, exists);
ExistsPredicate rewrittenExistsPredicate = new ExistsPredicate(BooleanLiteral.TRUE_LITERAL);
ExistsPredicate rewrittenExistsPredicate = new ExistsPredicate(TRUE_LITERAL);
return appendApplyNode(
subPlan,
existsPredicate.getSubquery(),
Expand Down Expand Up @@ -499,7 +500,7 @@ private PlanBuilder createPlanBuilder(Node node)
private Set<Expression> extractOuterColumnReferences(PlanNode planNode)
{
// at this point all the column references are already rewritten to SymbolReference
// when reference expression is not rewritten that means it cannot be satisfied within given PlaNode
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this fix to a separate commit, since it's unrelated to this change.

// when reference expression is not rewritten that means it cannot be satisfied within given PlanNode
// see that TranslationMap only resolves (local) fields in current scope
return ExpressionExtractor.extractExpressions(planNode).stream()
.flatMap(expression -> extractColumnReferences(expression, analysis.getColumnReferences()).stream())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@
import io.prestosql.sql.planner.plan.PlanNode;

import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isScalar;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.filter;
import static io.prestosql.sql.planner.plan.Patterns.lateralJoin;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;

public class RemoveUnreferencedScalarLateralNodes
implements Rule<LateralJoinNode>
{
private static final Pattern<LateralJoinNode> PATTERN = lateralJoin();
private static final Pattern<LateralJoinNode> PATTERN = lateralJoin()
.with(filter().equalTo(TRUE_LITERAL));

@Override
public Pattern<LateralJoinNode> getPattern()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.LateralJoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.tree.Expression;

import java.util.Optional;

import static io.prestosql.matching.Pattern.nonEmpty;
import static io.prestosql.sql.ExpressionUtils.combineConjuncts;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.correlation;
import static io.prestosql.sql.planner.plan.Patterns.lateralJoin;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;

/**
* Tries to decorrelate subquery and rewrite it using normal join.
Expand All @@ -53,18 +56,24 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context
PlanNodeDecorrelator planNodeDecorrelator = new PlanNodeDecorrelator(context.getIdAllocator(), context.getLookup());
Optional<DecorrelatedNode> decorrelatedNodeOptional = planNodeDecorrelator.decorrelateFilters(subquery, lateralJoinNode.getCorrelation());

return decorrelatedNodeOptional.map(decorrelatedNode ->
Result.ofPlanNode(new JoinNode(
context.getIdAllocator().getNextId(),
lateralJoinNode.getType().toJoinNodeType(),
lateralJoinNode.getInput(),
decorrelatedNode.getNode(),
ImmutableList.of(),
lateralJoinNode.getOutputSymbols(),
decorrelatedNode.getCorrelatedPredicates(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty()))).orElseGet(Result::empty);
return decorrelatedNodeOptional
.map(decorrelatedNode -> {
Expression joinFilter = combineConjuncts(
decorrelatedNode.getCorrelatedPredicates().orElse(TRUE_LITERAL),
lateralJoinNode.getFilter());
return Result.ofPlanNode(new JoinNode(
context.getIdAllocator().getNextId(),
lateralJoinNode.getType().toJoinNodeType(),
lateralJoinNode.getInput(),
decorrelatedNode.getNode(),
ImmutableList.of(),
lateralJoinNode.getOutputSymbols(),
joinFilter.equals(TRUE_LITERAL) ? Optional.empty() : Optional.of(joinFilter),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty()));
})
.orElseGet(Result::empty);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
import static io.prestosql.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isScalar;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.correlation;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.filter;
import static io.prestosql.sql.planner.plan.Patterns.lateralJoin;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static io.prestosql.util.MorePredicates.isInstanceOfAny;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -67,7 +69,8 @@ public class TransformCorrelatedScalarAggregationToJoin
implements Rule<LateralJoinNode>
{
private static final Pattern<LateralJoinNode> PATTERN = lateralJoin()
.with(nonEmpty(correlation()));
.with(nonEmpty(correlation()))
.with(filter().equalTo(TRUE_LITERAL)); // todo non-trivial join filter: adding filter/project on top of aggregation

@Override
public Pattern<LateralJoinNode> getPattern()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality;
import static io.prestosql.sql.planner.plan.LateralJoinNode.Type.LEFT;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.correlation;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.filter;
import static io.prestosql.sql.planner.plan.Patterns.lateralJoin;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;

Expand Down Expand Up @@ -81,7 +82,8 @@ public class TransformCorrelatedScalarSubquery
implements Rule<LateralJoinNode>
{
private static final Pattern<LateralJoinNode> PATTERN = lateralJoin()
.with(nonEmpty(correlation()));
.with(nonEmpty(correlation()))
.with(filter().equalTo(TRUE_LITERAL));

@Override
public Pattern getPattern()
Expand Down Expand Up @@ -116,6 +118,7 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context
rewrittenSubquery,
lateralJoinNode.getCorrelation(),
producesSingleRow ? lateralJoinNode.getType() : LEFT,
lateralJoinNode.getFilter(),
lateralJoinNode.getOriginSubquery()));
}

Expand All @@ -130,6 +133,7 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context
rewrittenSubquery,
lateralJoinNode.getCorrelation(),
LEFT,
lateralJoinNode.getFilter(),
lateralJoinNode.getOriginSubquery());

Symbol isDistinct = context.getSymbolAllocator().newSymbol("is_distinct", BooleanType.BOOLEAN);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import java.util.List;

import static io.prestosql.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.filter;
import static io.prestosql.sql.planner.plan.Patterns.lateralJoin;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;

/**
* This optimizer can rewrite correlated single row subquery to projection in a way described here:
Expand All @@ -47,7 +49,8 @@
public class TransformCorrelatedSingleRowSubqueryToProject
implements Rule<LateralJoinNode>
{
private static final Pattern<LateralJoinNode> PATTERN = lateralJoin();
private static final Pattern<LateralJoinNode> PATTERN = lateralJoin()
.with(filter().equalTo(TRUE_LITERAL));

@Override
public Pattern<LateralJoinNode> getPattern()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ private Optional<PlanNode> rewriteToNonDefaultAggregation(ApplyNode applyNode, C
subquery,
applyNode.getCorrelation(),
LEFT,
TRUE_LITERAL,
applyNode.getOriginSubquery()),
assignments.build()));
}
Expand All @@ -171,6 +172,7 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode parent, Context context)
Assignments.of(exists, new ComparisonExpression(GREATER_THAN, count.toSymbolReference(), new Cast(new LongLiteral("0"), BIGINT.toString())))),
parent.getCorrelation(),
INNER,
TRUE_LITERAL,
parent.getOriginSubquery());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.LateralJoinNode;
import io.prestosql.sql.tree.Expression;

import java.util.Optional;

import static io.prestosql.matching.Pattern.empty;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.correlation;
import static io.prestosql.sql.planner.plan.Patterns.lateralJoin;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;

public class TransformUncorrelatedLateralToJoin
implements Rule<LateralJoinNode>
Expand All @@ -52,10 +54,19 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context
.addAll(lateralJoinNode.getInput().getOutputSymbols())
.addAll(lateralJoinNode.getSubquery().getOutputSymbols())
.build(),
Optional.empty(),
filter(lateralJoinNode.getFilter()),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty()));
}

private Optional<Expression> filter(Expression lateralJoinFilter)
{
if (lateralJoinFilter.equals(TRUE_LITERAL)) {
return Optional.empty();
}

return Optional.of(lateralJoinFilter);
}
}
Loading