From 68823b1111191d75b9c1e3d9f93f2bec9c090cae Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Fri, 29 Mar 2019 21:24:20 +0100 Subject: [PATCH 1/4] Fix typo --- .../src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java index 615ea6b042fa..d2ee6ffae8d2 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java @@ -499,7 +499,7 @@ private PlanBuilder createPlanBuilder(Node node) private Set extractOuterColumnReferences(PlanNode planNode) { // at this point all the column references are already rewritten to SymbolReference - // when reference expression is not rewritten that means it cannot be satisfied within given PlaNode + // when reference expression is not rewritten that means it cannot be satisfied within given PlanNode // see that TranslationMap only resolves (local) fields in current scope return ExpressionExtractor.extractExpressions(planNode).stream() .flatMap(expression -> extractColumnReferences(expression, analysis.getColumnReferences()).stream()) From 61a4ff07e5c1965d8ea6e3e880ea68a86cabc928 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Fri, 29 Mar 2019 22:17:22 +0100 Subject: [PATCH 2/4] Remove unnecessary comment --- .../src/main/java/io/prestosql/sql/planner/plan/JoinNode.java | 1 - 1 file changed, 1 deletion(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/JoinNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/JoinNode.java index bba73f38e9fd..3da2f970711b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/JoinNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/JoinNode.java @@ -201,7 +201,6 @@ public String getJoinLabel() public static Type typeConvert(Join.Type joinType) { - // Omit SEMI join types because they must be inferred by the planner and not part of the SQL parse tree switch (joinType) { case CROSS: case IMPLICIT: From e4617f1f435c6ae46bd514592ef6fcd4051119e2 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Fri, 29 Mar 2019 22:36:19 +0100 Subject: [PATCH 3/4] Static import TRUE_LITERAL --- .../main/java/io/prestosql/sql/planner/SubqueryPlanner.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java index d2ee6ffae8d2..b9c63e890b01 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java @@ -29,7 +29,6 @@ import io.prestosql.sql.planner.plan.ProjectNode; import io.prestosql.sql.planner.plan.SimplePlanRewriter; import io.prestosql.sql.planner.plan.ValuesNode; -import io.prestosql.sql.tree.BooleanLiteral; import io.prestosql.sql.tree.DefaultExpressionTraversalVisitor; import io.prestosql.sql.tree.DereferenceExpression; import io.prestosql.sql.tree.ExistsPredicate; @@ -61,6 +60,7 @@ import static io.prestosql.sql.analyzer.SemanticExceptions.notSupportedException; import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; import static io.prestosql.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; import static io.prestosql.sql.tree.ComparisonExpression.Operator.EQUAL; import static io.prestosql.sql.util.AstUtils.nodeContains; import static java.lang.String.format; @@ -279,7 +279,7 @@ private PlanBuilder appendExistSubqueryApplyNode(PlanBuilder subPlan, ExistsPred PlanNode subqueryPlanRoot = subqueryPlan.getRoot(); if (isAggregationWithEmptyGroupBy(subqueryPlanRoot)) { - subPlan.getTranslations().put(existsPredicate, BooleanLiteral.TRUE_LITERAL); + subPlan.getTranslations().put(existsPredicate, TRUE_LITERAL); return subPlan; } @@ -288,7 +288,7 @@ private PlanBuilder appendExistSubqueryApplyNode(PlanBuilder subPlan, ExistsPred Symbol exists = symbolAllocator.newSymbol("exists", BOOLEAN); subPlan.getTranslations().put(existsPredicate, exists); - ExistsPredicate rewrittenExistsPredicate = new ExistsPredicate(BooleanLiteral.TRUE_LITERAL); + ExistsPredicate rewrittenExistsPredicate = new ExistsPredicate(TRUE_LITERAL); return appendApplyNode( subPlan, existsPredicate.getSubquery(), From 38113c0c911ff28f2d14c5057bdbe63acb59fc1c Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Tue, 5 Mar 2019 22:37:31 +0100 Subject: [PATCH 4/4] Add support for LEFT/RIGHT/FULL/INNER lateral join --- .../sql/planner/RelationPlanner.java | 49 +++++++++++++++++-- .../sql/planner/SubqueryPlanner.java | 5 +- .../RemoveUnreferencedScalarLateralNodes.java | 5 +- .../TransformCorrelatedLateralJoinToJoin.java | 35 ++++++++----- ...formCorrelatedScalarAggregationToJoin.java | 5 +- .../TransformCorrelatedScalarSubquery.java | 6 ++- ...mCorrelatedSingleRowSubqueryToProject.java | 5 +- .../TransformExistsApplyToLateralNode.java | 2 + .../TransformUncorrelatedLateralToJoin.java | 13 ++++- .../PruneUnreferencedOutputs.java | 48 ++++++++++++++---- ...uantifiedComparisonApplyToLateralJoin.java | 1 + .../UnaliasSymbolReferences.java | 2 +- .../sql/planner/plan/LateralJoinNode.java | 36 +++++++++++++- .../prestosql/sql/planner/plan/Patterns.java | 5 ++ .../sql/planner/planprinter/PlanPrinter.java | 7 ++- .../sanity/ValidateDependenciesChecker.java | 9 ++++ .../io/prestosql/util/GraphvizPrinter.java | 10 +++- .../iterative/rule/test/PlanBuilder.java | 8 ++- .../prestosql/tests/AbstractTestQueries.java | 41 ++++++++++++---- 19 files changed, 242 insertions(+), 50 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java index 4311dda48844..e66cd20727bc 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java @@ -59,9 +59,11 @@ import io.prestosql.sql.tree.InPredicate; import io.prestosql.sql.tree.Intersect; import io.prestosql.sql.tree.Join; +import io.prestosql.sql.tree.JoinCriteria; import io.prestosql.sql.tree.JoinUsing; import io.prestosql.sql.tree.LambdaArgumentDeclaration; import io.prestosql.sql.tree.Lateral; +import io.prestosql.sql.tree.NaturalJoin; import io.prestosql.sql.tree.NodeRef; import io.prestosql.sql.tree.QualifiedName; import io.prestosql.sql.tree.Query; @@ -89,8 +91,10 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; import static io.prestosql.sql.analyzer.SemanticExceptions.notSupportedException; import static io.prestosql.sql.planner.plan.AggregationNode.singleGroupingSet; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; import static io.prestosql.sql.tree.Join.Type.INNER; import static java.util.Objects.requireNonNull; @@ -219,9 +223,6 @@ protected RelationPlan visitJoin(Join node, Void context) Optional lateral = getLateral(node.getRight()); if (lateral.isPresent()) { - if (node.getType() != Join.Type.CROSS && node.getType() != Join.Type.IMPLICIT) { - throw notSupportedException(lateral.get(), "LATERAL on other than the right side of CROSS JOIN"); - } return planLateralJoin(node, leftPlan, lateral.get()); } @@ -537,7 +538,47 @@ private RelationPlan planLateralJoin(Join join, RelationPlan leftPlan, Lateral l PlanBuilder leftPlanBuilder = initializePlanBuilder(leftPlan); PlanBuilder rightPlanBuilder = initializePlanBuilder(rightPlan); - PlanBuilder planBuilder = subqueryPlanner.appendLateralJoin(leftPlanBuilder, rightPlanBuilder, lateral.getQuery(), true, LateralJoinNode.Type.INNER); + Expression filterExpression; + if (!join.getCriteria().isPresent()) { + filterExpression = TRUE_LITERAL; + } + else { + JoinCriteria criteria = join.getCriteria().get(); + if (criteria instanceof JoinUsing || criteria instanceof NaturalJoin) { + throw notSupportedException(join, "Lateral join with criteria other than ON"); + } + filterExpression = (Expression) getOnlyElement(criteria.getNodes()); + } + + List rewriterOutputSymbols = ImmutableList.builder() + .addAll(leftPlan.getFieldMappings()) + .addAll(rightPlan.getFieldMappings()) + .build(); + + // this node is not used in the plan. It is only used for creating the TranslationMap. + PlanNode dummy = new ValuesNode( + idAllocator.getNextId(), + ImmutableList.builder() + .addAll(leftPlanBuilder.getRoot().getOutputSymbols()) + .addAll(rightPlanBuilder.getRoot().getOutputSymbols()) + .build(), + ImmutableList.of()); + + RelationPlan intermediateRelationPlan = new RelationPlan(dummy, analysis.getScope(join), rewriterOutputSymbols); + TranslationMap translationMap = new TranslationMap(intermediateRelationPlan, analysis, lambdaDeclarationToSymbolMap); + translationMap.setFieldMappings(rewriterOutputSymbols); + translationMap.putExpressionMappingsFrom(leftPlanBuilder.getTranslations()); + translationMap.putExpressionMappingsFrom(rightPlanBuilder.getTranslations()); + + Expression rewrittenFilterCondition = translationMap.rewrite(filterExpression); + + PlanBuilder planBuilder = subqueryPlanner.appendLateralJoin( + leftPlanBuilder, + rightPlanBuilder, + lateral.getQuery(), + true, + LateralJoinNode.Type.typeConvert(join.getType()), + rewrittenFilterCondition); List outputSymbols = ImmutableList.builder() .addAll(leftPlan.getRoot().getOutputSymbols()) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java index b9c63e890b01..e0e5e56bda42 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java @@ -227,10 +227,10 @@ private PlanBuilder appendScalarSubqueryApplyNode(PlanBuilder subPlan, SubqueryE } // The subquery's EnforceSingleRowNode always produces a row, so the join is effectively INNER - return appendLateralJoin(subPlan, subqueryPlan, scalarSubquery.getQuery(), correlationAllowed, LateralJoinNode.Type.INNER); + return appendLateralJoin(subPlan, subqueryPlan, scalarSubquery.getQuery(), correlationAllowed, LateralJoinNode.Type.INNER, TRUE_LITERAL); } - public PlanBuilder appendLateralJoin(PlanBuilder subPlan, PlanBuilder subqueryPlan, Query query, boolean correlationAllowed, LateralJoinNode.Type type) + public PlanBuilder appendLateralJoin(PlanBuilder subPlan, PlanBuilder subqueryPlan, Query query, boolean correlationAllowed, LateralJoinNode.Type type, Expression filterCondition) { PlanNode subqueryNode = subqueryPlan.getRoot(); Map correlation = extractCorrelation(subPlan, subqueryNode); @@ -247,6 +247,7 @@ public PlanBuilder appendLateralJoin(PlanBuilder subPlan, PlanBuilder subqueryPl subqueryNode, ImmutableList.copyOf(SymbolsExtractor.extractUnique(correlation.values())), type, + filterCondition, query), analysis.getParameters()); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java index 1cc061f901f4..39155de646c2 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnreferencedScalarLateralNodes.java @@ -21,12 +21,15 @@ import io.prestosql.sql.planner.plan.PlanNode; import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isScalar; +import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.filter; import static io.prestosql.sql.planner.plan.Patterns.lateralJoin; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; public class RemoveUnreferencedScalarLateralNodes implements Rule { - private static final Pattern PATTERN = lateralJoin(); + private static final Pattern PATTERN = lateralJoin() + .with(filter().equalTo(TRUE_LITERAL)); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedLateralJoinToJoin.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedLateralJoinToJoin.java index 9aeadf98b1f4..8bd9033c2142 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedLateralJoinToJoin.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedLateralJoinToJoin.java @@ -22,12 +22,15 @@ import io.prestosql.sql.planner.plan.JoinNode; import io.prestosql.sql.planner.plan.LateralJoinNode; import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.tree.Expression; import java.util.Optional; import static io.prestosql.matching.Pattern.nonEmpty; +import static io.prestosql.sql.ExpressionUtils.combineConjuncts; import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.correlation; import static io.prestosql.sql.planner.plan.Patterns.lateralJoin; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; /** * Tries to decorrelate subquery and rewrite it using normal join. @@ -53,18 +56,24 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context PlanNodeDecorrelator planNodeDecorrelator = new PlanNodeDecorrelator(context.getIdAllocator(), context.getLookup()); Optional decorrelatedNodeOptional = planNodeDecorrelator.decorrelateFilters(subquery, lateralJoinNode.getCorrelation()); - return decorrelatedNodeOptional.map(decorrelatedNode -> - Result.ofPlanNode(new JoinNode( - context.getIdAllocator().getNextId(), - lateralJoinNode.getType().toJoinNodeType(), - lateralJoinNode.getInput(), - decorrelatedNode.getNode(), - ImmutableList.of(), - lateralJoinNode.getOutputSymbols(), - decorrelatedNode.getCorrelatedPredicates(), - Optional.empty(), - Optional.empty(), - Optional.empty(), - Optional.empty()))).orElseGet(Result::empty); + return decorrelatedNodeOptional + .map(decorrelatedNode -> { + Expression joinFilter = combineConjuncts( + decorrelatedNode.getCorrelatedPredicates().orElse(TRUE_LITERAL), + lateralJoinNode.getFilter()); + return Result.ofPlanNode(new JoinNode( + context.getIdAllocator().getNextId(), + lateralJoinNode.getType().toJoinNodeType(), + lateralJoinNode.getInput(), + decorrelatedNode.getNode(), + ImmutableList.of(), + lateralJoinNode.getOutputSymbols(), + joinFilter.equals(TRUE_LITERAL) ? Optional.empty() : Optional.of(joinFilter), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty())); + }) + .orElseGet(Result::empty); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java index 205d7732a153..7c9861b22e33 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java @@ -31,7 +31,9 @@ import static io.prestosql.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isScalar; import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.correlation; +import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.filter; import static io.prestosql.sql.planner.plan.Patterns.lateralJoin; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; import static io.prestosql.util.MorePredicates.isInstanceOfAny; import static java.util.Objects.requireNonNull; @@ -67,7 +69,8 @@ public class TransformCorrelatedScalarAggregationToJoin implements Rule { private static final Pattern PATTERN = lateralJoin() - .with(nonEmpty(correlation())); + .with(nonEmpty(correlation())) + .with(filter().equalTo(TRUE_LITERAL)); // todo non-trivial join filter: adding filter/project on top of aggregation @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java index 8bd074110f4e..64e1126af4b0 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java @@ -46,6 +46,7 @@ import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality; import static io.prestosql.sql.planner.plan.LateralJoinNode.Type.LEFT; import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.correlation; +import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.filter; import static io.prestosql.sql.planner.plan.Patterns.lateralJoin; import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -81,7 +82,8 @@ public class TransformCorrelatedScalarSubquery implements Rule { private static final Pattern PATTERN = lateralJoin() - .with(nonEmpty(correlation())); + .with(nonEmpty(correlation())) + .with(filter().equalTo(TRUE_LITERAL)); @Override public Pattern getPattern() @@ -116,6 +118,7 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context rewrittenSubquery, lateralJoinNode.getCorrelation(), producesSingleRow ? lateralJoinNode.getType() : LEFT, + lateralJoinNode.getFilter(), lateralJoinNode.getOriginSubquery())); } @@ -130,6 +133,7 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context rewrittenSubquery, lateralJoinNode.getCorrelation(), LEFT, + lateralJoinNode.getFilter(), lateralJoinNode.getOriginSubquery()); Symbol isDistinct = context.getSymbolAllocator().newSymbol("is_distinct", BooleanType.BOOLEAN); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java index 23413e22a10e..de97c2befa17 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java @@ -25,7 +25,9 @@ import java.util.List; import static io.prestosql.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.filter; import static io.prestosql.sql.planner.plan.Patterns.lateralJoin; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; /** * This optimizer can rewrite correlated single row subquery to projection in a way described here: @@ -47,7 +49,8 @@ public class TransformCorrelatedSingleRowSubqueryToProject implements Rule { - private static final Pattern PATTERN = lateralJoin(); + private static final Pattern PATTERN = lateralJoin() + .with(filter().equalTo(TRUE_LITERAL)); @Override public Pattern getPattern() diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java index 028fb4e7c9ab..5625b3075b5b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java @@ -145,6 +145,7 @@ private Optional rewriteToNonDefaultAggregation(ApplyNode applyNode, C subquery, applyNode.getCorrelation(), LEFT, + TRUE_LITERAL, applyNode.getOriginSubquery()), assignments.build())); } @@ -171,6 +172,7 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode parent, Context context) Assignments.of(exists, new ComparisonExpression(GREATER_THAN, count.toSymbolReference(), new Cast(new LongLiteral("0"), BIGINT.toString())))), parent.getCorrelation(), INNER, + TRUE_LITERAL, parent.getOriginSubquery()); } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java index 58066a1698b3..708661d87b88 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformUncorrelatedLateralToJoin.java @@ -20,12 +20,14 @@ import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.plan.JoinNode; import io.prestosql.sql.planner.plan.LateralJoinNode; +import io.prestosql.sql.tree.Expression; import java.util.Optional; import static io.prestosql.matching.Pattern.empty; import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.correlation; import static io.prestosql.sql.planner.plan.Patterns.lateralJoin; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; public class TransformUncorrelatedLateralToJoin implements Rule @@ -52,10 +54,19 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context .addAll(lateralJoinNode.getInput().getOutputSymbols()) .addAll(lateralJoinNode.getSubquery().getOutputSymbols()) .build(), - Optional.empty(), + filter(lateralJoinNode.getFilter()), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())); } + + private Optional filter(Expression lateralJoinFilter) + { + if (lateralJoinFilter.equals(TRUE_LITERAL)) { + return Optional.empty(); + } + + return Optional.of(lateralJoinFilter); + } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java index f568bb10b3b8..fd1a124a77c5 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -86,7 +86,12 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.concat; import static com.google.common.collect.Sets.intersection; +import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar; import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isScalar; +import static io.prestosql.sql.planner.plan.LateralJoinNode.Type.INNER; +import static io.prestosql.sql.planner.plan.LateralJoinNode.Type.LEFT; +import static io.prestosql.sql.planner.plan.LateralJoinNode.Type.RIGHT; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; import static java.util.Objects.requireNonNull; /** @@ -812,11 +817,25 @@ public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext> context) { - PlanNode subquery = context.rewrite(node.getSubquery(), context.get()); + Set expectedFilterSymbols = SymbolsExtractor.extractUnique(node.getFilter()); + + Set expectedFilterAndContextSymbols = ImmutableSet.builder() + .addAll(expectedFilterSymbols) + .addAll(context.get()) + .build(); + + PlanNode subquery = context.rewrite(node.getSubquery(), expectedFilterAndContextSymbols); // remove unused lateral nodes - if (intersection(ImmutableSet.copyOf(subquery.getOutputSymbols()), context.get()).isEmpty() && isScalar(subquery)) { - return context.rewrite(node.getInput(), context.get()); + if (intersection(ImmutableSet.copyOf(subquery.getOutputSymbols()), context.get()).isEmpty()) { + // remove unused lateral subquery of inner join + if (node.getType() == INNER && isScalar(subquery) && node.getFilter().equals(TRUE_LITERAL)) { + return context.rewrite(node.getInput(), context.get()); + } + // remove unused lateral subquery of left join + if (node.getType() == LEFT && isAtMostScalar(subquery)) { + return context.rewrite(node.getInput(), context.get()); + } } // prune not used correlation symbols @@ -825,18 +844,29 @@ public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext inputContext = ImmutableSet.builder() - .addAll(context.get()) + Set expectedCorrelationAndContextSymbols = ImmutableSet.builder() .addAll(newCorrelation) + .addAll(context.get()) + .build(); + Set inputContext = ImmutableSet.builder() + .addAll(expectedCorrelationAndContextSymbols) + .addAll(expectedFilterSymbols) .build(); PlanNode input = context.rewrite(node.getInput(), inputContext); - // remove unused lateral nodes - if (intersection(ImmutableSet.copyOf(input.getOutputSymbols()), inputContext).isEmpty() && isScalar(input)) { - return subquery; + // remove unused input nodes + if (intersection(ImmutableSet.copyOf(input.getOutputSymbols()), expectedCorrelationAndContextSymbols).isEmpty()) { + // remove unused input of inner join + if (node.getType() == INNER && isScalar(input) && node.getFilter().equals(TRUE_LITERAL)) { + return subquery; + } + // remove unused input of right join + if (node.getType() == RIGHT && isAtMostScalar(input)) { + return subquery; + } } - return new LateralJoinNode(node.getId(), input, subquery, newCorrelation, node.getType(), node.getOriginSubquery()); + return new LateralJoinNode(node.getId(), input, subquery, newCorrelation, node.getType(), node.getFilter(), node.getOriginSubquery()); } } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java index ae451abb397a..3ee82bdf775d 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java @@ -174,6 +174,7 @@ countNonNullValue, new Aggregation( subqueryPlan, node.getCorrelation(), LateralJoinNode.Type.INNER, + TRUE_LITERAL, node.getOriginSubquery()); Expression valueComparedToSubquery = rewriteUsingBounds(quantifiedComparison, minValue, maxValue, countAllValue, countNonNullValue); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java index c35919cf46d8..61fdd78d8002 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -454,7 +454,7 @@ public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext cont PlanNode subquery = context.rewrite(node.getSubquery()); List canonicalCorrelation = canonicalizeAndDistinct(node.getCorrelation()); - return new LateralJoinNode(node.getId(), source, subquery, canonicalCorrelation, node.getType(), node.getOriginSubquery()); + return new LateralJoinNode(node.getId(), source, subquery, canonicalCorrelation, node.getType(), canonicalize(node.getFilter()), node.getOriginSubquery()); } @Override diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/LateralJoinNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/LateralJoinNode.java index c31141458331..7784cd05fd2f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/LateralJoinNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/LateralJoinNode.java @@ -17,6 +17,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.Join; import io.prestosql.sql.tree.Node; import javax.annotation.concurrent.Immutable; @@ -40,7 +42,9 @@ public class LateralJoinNode public enum Type { INNER(JoinNode.Type.INNER), - LEFT(JoinNode.Type.LEFT); + LEFT(JoinNode.Type.LEFT), + RIGHT(JoinNode.Type.RIGHT), + FULL(JoinNode.Type.FULL); Type(JoinNode.Type joinNodeType) { @@ -53,6 +57,24 @@ public JoinNode.Type toJoinNodeType() { return joinNodeType; } + + public static Type typeConvert(Join.Type joinType) + { + switch (joinType) { + case CROSS: + case IMPLICIT: + case INNER: + return Type.INNER; + case LEFT: + return Type.LEFT; + case RIGHT: + return Type.RIGHT; + case FULL: + return Type.FULL; + default: + throw new UnsupportedOperationException("Unsupported join type: " + joinType); + } + } } private final PlanNode input; @@ -63,6 +85,7 @@ public JoinNode.Type toJoinNodeType() */ private final List correlation; private final Type type; + private final Expression filter; /** * HACK! @@ -77,12 +100,14 @@ public LateralJoinNode( @JsonProperty("subquery") PlanNode subquery, @JsonProperty("correlation") List correlation, @JsonProperty("type") Type type, + @JsonProperty("filter") Expression filter, @JsonProperty("originSubquery") Node originSubquery) { super(id); requireNonNull(input, "input is null"); requireNonNull(subquery, "right is null"); requireNonNull(correlation, "correlation is null"); + requireNonNull(filter, "filter is null"); requireNonNull(originSubquery, "originSubquery is null"); checkArgument(input.getOutputSymbols().containsAll(correlation), "Input does not contain symbols from correlation"); @@ -91,6 +116,7 @@ public LateralJoinNode( this.subquery = subquery; this.correlation = ImmutableList.copyOf(correlation); this.type = type; + this.filter = filter; this.originSubquery = originSubquery; } @@ -118,6 +144,12 @@ public Type getType() return type; } + @JsonProperty("filter") + public Expression getFilter() + { + return filter; + } + @JsonProperty("originSubquery") public Node getOriginSubquery() { @@ -144,7 +176,7 @@ public List getOutputSymbols() public PlanNode replaceChildren(List newChildren) { checkArgument(newChildren.size() == 2, "expected newChildren to contain 2 nodes"); - return new LateralJoinNode(getId(), newChildren.get(0), newChildren.get(1), correlation, type, originSubquery); + return new LateralJoinNode(getId(), newChildren.get(0), newChildren.get(1), correlation, type, filter, originSubquery); } @Override diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/Patterns.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/Patterns.java index 15f5e15cd5c2..41d45e2e0459 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/Patterns.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/Patterns.java @@ -232,6 +232,11 @@ public static Property subquery() { return property("subquery", LateralJoinNode::getSubquery); } + + public static Property filter() + { + return property("filter", LateralJoinNode::getFilter); + } } public static class Limit diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/planprinter/PlanPrinter.java b/presto-main/src/main/java/io/prestosql/sql/planner/planprinter/PlanPrinter.java index 661448baa9a2..a3e42d93bb4b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/planprinter/PlanPrinter.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/planprinter/PlanPrinter.java @@ -128,6 +128,7 @@ import static io.prestosql.sql.planner.planprinter.TextRenderer.formatDouble; import static io.prestosql.sql.planner.planprinter.TextRenderer.formatPositions; import static io.prestosql.sql.planner.planprinter.TextRenderer.indentString; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; import static java.lang.String.format; import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; @@ -1043,7 +1044,11 @@ public Void visitApply(ApplyNode node, Void context) @Override public Void visitLateralJoin(LateralJoinNode node, Void context) { - addNode(node, "Lateral", format("[%s]", node.getCorrelation())); + addNode(node, + "Lateral", + format("[%s%s]", + node.getCorrelation(), + node.getFilter().equals(TRUE_LITERAL) ? "" : " " + node.getFilter())); return processChildren(node, context); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java index ea0949bbc717..421c1583e9c0 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/ValidateDependenciesChecker.java @@ -646,6 +646,15 @@ public Void visitLateralJoin(LateralJoinNode node, Set boundSymbols) node.getCorrelation(), "not all LATERAL correlation symbols are used in subquery"); + Set inputs = ImmutableSet.builder() + .addAll(createInputs(node.getInput(), boundSymbols)) + .addAll(createInputs(node.getSubquery(), boundSymbols)) + .build(); + + Set filterSymbols = SymbolsExtractor.extractUnique(node.getFilter()); + + checkDependencies(inputs, filterSymbols, "filter symbols (%s) not in sources (%s)", filterSymbols, inputs); + return null; } diff --git a/presto-main/src/main/java/io/prestosql/util/GraphvizPrinter.java b/presto-main/src/main/java/io/prestosql/util/GraphvizPrinter.java index 4fe48ba84d7e..ab9d5e3dea20 100644 --- a/presto-main/src/main/java/io/prestosql/util/GraphvizPrinter.java +++ b/presto-main/src/main/java/io/prestosql/util/GraphvizPrinter.java @@ -72,6 +72,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Maps.immutableEnumMap; import static io.prestosql.sql.planner.plan.ExchangeNode.Type.REPARTITION; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; import static java.lang.String.format; public final class GraphvizPrinter @@ -514,8 +515,13 @@ public Void visitAssignUniqueId(AssignUniqueId node, Void context) @Override public Void visitLateralJoin(LateralJoinNode node, Void context) { - String parameters = Joiner.on(",").join(node.getCorrelation()); - printNode(node, "LateralJoin", parameters, NODE_COLORS.get(NodeType.JOIN)); + String correlationSymbols = Joiner.on(",").join(node.getCorrelation()); + String filterExpression = ""; + if (!node.getFilter().equals(TRUE_LITERAL)) { + filterExpression = " " + node.getFilter().toString(); + } + + printNode(node, "LateralJoin", correlationSymbols + filterExpression, NODE_COLORS.get(NodeType.JOIN)); node.getInput().accept(this, context); node.getSubquery().accept(this, context); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java index 2228a3cb7945..0f9ddf2300db 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java @@ -97,6 +97,7 @@ import static io.prestosql.spi.type.VarbinaryType.VARBINARY; import static io.prestosql.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.prestosql.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; +import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; import static io.prestosql.util.MoreLists.nElements; import static java.lang.String.format; import static java.util.Collections.emptyList; @@ -362,9 +363,14 @@ public AssignUniqueId assignUniqueId(Symbol unique, PlanNode source) } public LateralJoinNode lateral(List correlation, PlanNode input, PlanNode subquery) + { + return lateral(correlation, input, LateralJoinNode.Type.INNER, TRUE_LITERAL, subquery); + } + + public LateralJoinNode lateral(List correlation, PlanNode input, LateralJoinNode.Type type, Expression filter, PlanNode subquery) { NullLiteral originSubquery = new NullLiteral(); // does not matter for tests - return new LateralJoinNode(idAllocator.getNextId(), input, subquery, correlation, LateralJoinNode.Type.INNER, originSubquery); + return new LateralJoinNode(idAllocator.getNextId(), input, subquery, correlation, type, filter, originSubquery); } public TableScanNode tableScan(List symbols, Map assignments) diff --git a/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueries.java b/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueries.java index 47d1cd712513..f2936c142be9 100644 --- a/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/io/prestosql/tests/AbstractTestQueries.java @@ -4993,6 +4993,7 @@ public void testLateralJoin() "SELECT name FROM nation, LATERAL (SELECT 1 WHERE false)", "SELECT 1 WHERE false"); + // unused scalar subquery is removed assertQuery( "SELECT name FROM nation, LATERAL (SELECT 1)", "SELECT name FROM nation"); @@ -5001,6 +5002,21 @@ public void testLateralJoin() "SELECT name FROM nation, LATERAL (SELECT 1 WHERE name = 'ola')", "SELECT 1 WHERE false"); + // unused at-most-scalar subquery is removed + assertQuery( + "SELECT name FROM nation LEFT JOIN LATERAL (SELECT 1 WHERE name = 'ola') ON true", + "SELECT name FROM nation"); + + // unused scalar input is removed + assertQuery( + "SELECT n FROM (VALUES 1) t(a), LATERAL (SELECT name FROM region) r(n)", + "SELECT name FROM region"); + + // unused at-most-scalar input is removed + assertQuery( + "SELECT n FROM (SELECT 1 FROM (VALUES 1) WHERE rand() = 5) t(a) RIGHT JOIN LATERAL (SELECT name FROM region) r(n) ON true", + "SELECT name FROM region"); + assertQuery( "SELECT nationkey, a FROM nation, LATERAL (SELECT max(region.name) FROM region WHERE region.regionkey <= nation.regionkey) t(a) ORDER BY nationkey LIMIT 1", "VALUES (0, 'AFRICA')"); @@ -5041,16 +5057,21 @@ public void testLateralJoin() assertQuery( "SELECT * FROM (VALUES 2) a(x) CROSS JOIN LATERAL(SELECT x, x + 1)", "SELECT 2, 2, 3"); - - assertQueryFails( - "SELECT * FROM (VALUES array[2, 2]) a(x) LEFT OUTER JOIN LATERAL(VALUES x) ON true", - "line .*: LATERAL on other than the right side of CROSS JOIN is not supported"); - assertQueryFails( - "SELECT * FROM (VALUES array[2, 2]) a(x) RIGHT OUTER JOIN LATERAL(VALUES x) ON true", - "line .*: LATERAL on other than the right side of CROSS JOIN is not supported"); - assertQueryFails( - "SELECT * FROM (VALUES array[2, 2]) a(x) FULL OUTER JOIN LATERAL(VALUES x) ON true", - "line .*: LATERAL on other than the right side of CROSS JOIN is not supported"); + assertQuery( + "SELECT r.name, a FROM region r LEFT JOIN LATERAL (SELECT name FROM nation WHERE r.regionkey = nation.regionkey) n(a) ON r.name > a ORDER BY r.name LIMIT 1", + "SELECT 'AFRICA', NULL"); + assertQuery( + "SELECT r.name, a FROM region r RIGHT JOIN LATERAL (SELECT name FROM nation WHERE r.regionkey = nation.regionkey) n(a) ON r.name > a ORDER BY a LIMIT 1", + "SELECT NULL, 'ALGERIA'"); + assertQuery( + "SELECT * FROM (VALUES 1) a(x) FULL JOIN LATERAL(SELECT y FROM (VALUES 2) b(y) WHERE y > x) ON x=y", + "VALUES (1, NULL), (NULL, 2)"); + assertQuery( + "SELECT * FROM (VALUES 1, 2, 3) a(x) FULL JOIN LATERAL(SELECT z FROM (VALUES 1, 2, 3, 5) b(z) WHERE z != x) ON x != 1 AND z != 5", + "VALUES (1, NULL), (2, 3), (2, 1), (3, 2), (3, 1), (NULL, 5)"); + assertQuery( + "SELECT * FROM (VALUES 1, 2) a(x) JOIN LATERAL(SELECT y FROM (VALUES 2, 3) b(y) WHERE y > x) c(z) ON z > 2*x", + "VALUES (1, 3)"); } @Test