diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java index 59ead81e287d..86d392787565 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java @@ -58,7 +58,8 @@ /** * This rule decorrelates a correlated subquery of INNER correlated join with: * - single grouped aggregation, or - * - grouped aggregation over distinct operator (grouped aggregation with no aggregation assignments) + * - grouped aggregation over distinct operator (grouped aggregation with no aggregation assignments), + * in case when the distinct operator cannot be de-correlated by PlanNodeDecorrelator *
* In the case of single aggregation, it transforms: *
@@ -141,19 +142,24 @@ public PatterngetPattern() @Override public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) { - // if there is another aggregation below the AggregationNode, handle both PlanNode source = captures.get(SOURCE); + + // if we fail to decorrelate the nested plan, and it contains a distinct operator, we can extract and special-handle the distinct operator AggregationNode distinct = null; - if (isDistinctOperator(source)) { - distinct = (AggregationNode) source; - source = distinct.getSource(); - } // decorrelate nested plan PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup()); Optional decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation()); if (decorrelatedSource.isEmpty()) { - return Result.empty(); + // we failed to decorrelate the nested plan, so check if we can extract a distinct operator from the nested plan + if (isDistinctOperator(source)) { + distinct = (AggregationNode) source; + source = distinct.getSource(); + decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation()); + } + if (decorrelatedSource.isEmpty()) { + return Result.empty(); + } } source = decorrelatedSource.get().getNode(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java index d87919629260..bcfafb67889d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java @@ -53,7 +53,8 @@ /** * This rule decorrelates a correlated subquery of INNER correlated join with: * - single grouped aggregation, or - * - grouped aggregation over distinct operator (grouped aggregation with no aggregation assignments) + * - grouped aggregation over distinct operator (grouped aggregation with no aggregation assignments), + * in case when the distinct operator cannot be de-correlated by PlanNodeDecorrelator * It is similar to TransformCorrelatedGroupedAggregationWithProjection rule, but does not support projection over aggregation in the subquery * * In the case of single aggregation, it transforms: @@ -132,19 +133,24 @@ public Pattern
getPattern() @Override public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) { - // if there is another aggregation below the AggregationNode, handle both PlanNode source = captures.get(SOURCE); + + // if we fail to decorrelate the nested plan, and it contains a distinct operator, we can extract and special-handle the distinct operator AggregationNode distinct = null; - if (isDistinctOperator(source)) { - distinct = (AggregationNode) source; - source = distinct.getSource(); - } // decorrelate nested plan PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup()); Optional decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation()); if (decorrelatedSource.isEmpty()) { - return Result.empty(); + // we failed to decorrelate the nested plan, so check if we can extract a distinct operator from the nested plan + if (isDistinctOperator(source)) { + distinct = (AggregationNode) source; + source = distinct.getSource(); + decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation()); + } + if (decorrelatedSource.isEmpty()) { + return Result.empty(); + } } source = decorrelatedSource.get().getNode(); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index 0766a04d4280..2deb244eb10c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -995,6 +995,38 @@ public void testCorrelatedDistinctAggregationRewriteToLeftOuterJoin() anyTree(node(ValuesNode.class)))))); } + @Test + public void testCorrelatedDistinctGropuedAggregationRewriteToLeftOuterJoin() + { + assertPlan( + "SELECT (SELECT count(DISTINCT o.orderkey) FROM orders o WHERE c.custkey = o.custkey GROUP BY o.orderstatus), c.custkey FROM customer c", + output( + project(filter( + "(CASE \"is_distinct\" WHEN true THEN true ELSE CAST(fail(28, 'Scalar sub-query has returned multiple rows') AS boolean) END)", + project(markDistinct( + "is_distinct", + ImmutableList.of("unique"), + join( + LEFT, + ImmutableList.of(equiJoinClause("c_custkey", "o_custkey")), + project(assignUniqueId( + "unique", + tableScan("customer", ImmutableMap.of("c_custkey", "custkey")))), + project(aggregation( + singleGroupingSet("o_orderstatus", "o_custkey"), + ImmutableMap.of(Optional.of("count"), functionCall("count", ImmutableList.of("o_orderkey"))), + Optional.empty(), + SINGLE, + project(aggregation( + singleGroupingSet("o_orderstatus", "o_orderkey", "o_custkey"), + ImmutableMap.of(), + Optional.empty(), + FINAL, + anyTree(tableScan( + "orders", + ImmutableMap.of("o_orderkey", "orderkey", "o_orderstatus", "orderstatus", "o_custkey", "custkey")))))))))))))); + } + @Test public void testRemovesTrivialFilters() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java index f300abb305e1..42e9157123c9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithProjection.java @@ -146,4 +146,50 @@ public void rewritesOnSubqueryWithDistinct() "true", values("a", "b"))))))); } + + @Test + public void rewritesOnSubqueryWithDecorrelatableDistinct() + { + // distinct aggregation can be decorrelated in the subquery by PlanNodeDecorrelator + // because the correlated predicate is equality comparison + tester().assertThat(new TransformCorrelatedGroupedAggregationWithProjection(tester().getPlannerContext())) + .on(p -> p.correlatedJoin( + ImmutableList.of(p.symbol("corr")), + p.values(p.symbol("corr")), + INNER, + PlanBuilder.expression("true"), + p.project( + Assignments.of(p.symbol("expr_sum"), PlanBuilder.expression("sum + 1"), p.symbol("expr_count"), PlanBuilder.expression("count - 1")), + p.aggregation(outerBuilder -> outerBuilder + .singleGroupingSet(p.symbol("a")) + .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("count"), PlanBuilder.expression("count()"), ImmutableList.of()) + .source(p.aggregation(innerBuilder -> innerBuilder + .singleGroupingSet(p.symbol("a")) + .source(p.filter( + PlanBuilder.expression("b = corr"), + p.values(p.symbol("a"), p.symbol("b")))))))))) + .matches( + project(ImmutableMap.of("corr", expression("corr"), "expr_sum", expression("sum_agg + 1"), "expr_count", expression("count_agg - 1")), + aggregation( + singleGroupingSet("corr", "unique", "a"), + ImmutableMap.of(Optional.of("sum_agg"), functionCall("sum", ImmutableList.of("a")), Optional.of("count_agg"), functionCall("count", ImmutableList.of())), + Optional.empty(), + SINGLE, + join( + JoinNode.Type.INNER, + ImmutableList.of(), + Optional.of("b = corr"), + assignUniqueId( + "unique", + values("corr")), + aggregation( + singleGroupingSet("a", "b"), + ImmutableMap.of(), + Optional.empty(), + SINGLE, + filter( + "true", + values("a", "b"))))))); + } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithoutProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithoutProjection.java index cc80b967b671..33c3c9194b87 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithoutProjection.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGroupedAggregationWithoutProjection.java @@ -139,4 +139,48 @@ public void rewritesOnSubqueryWithDistinct() "true", values("a", "b"))))))); } + + @Test + public void rewritesOnSubqueryWithDecorrelatableDistinct() + { + // distinct aggregation can be decorrelated in the subquery by PlanNodeDecorrelator + // because the correlated predicate is equality comparison + tester().assertThat(new TransformCorrelatedGroupedAggregationWithoutProjection(tester().getPlannerContext())) + .on(p -> p.correlatedJoin( + ImmutableList.of(p.symbol("corr")), + p.values(p.symbol("corr")), + INNER, + PlanBuilder.expression("true"), + p.aggregation(outerBuilder -> outerBuilder + .singleGroupingSet(p.symbol("a")) + .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("count"), PlanBuilder.expression("count()"), ImmutableList.of()) + .source(p.aggregation(innerBuilder -> innerBuilder + .singleGroupingSet(p.symbol("a")) + .source(p.filter( + PlanBuilder.expression("b = corr"), + p.values(p.symbol("a"), p.symbol("b"))))))))) + .matches( + project(ImmutableMap.of("corr", expression("corr"), "sum_agg", expression("sum_agg"), "count_agg", expression("count_agg")), + aggregation( + singleGroupingSet("corr", "unique", "a"), + ImmutableMap.of(Optional.of("sum_agg"), functionCall("sum", ImmutableList.of("a")), Optional.of("count_agg"), functionCall("count", ImmutableList.of())), + Optional.empty(), + SINGLE, + join( + JoinNode.Type.INNER, + ImmutableList.of(), + Optional.of("b = corr"), + assignUniqueId( + "unique", + values("corr")), + aggregation( + singleGroupingSet("a", "b"), + ImmutableMap.of(), + Optional.empty(), + SINGLE, + filter( + "true", + values("a", "b"))))))); + } }