From 81bc98b77a0faff5b9b94e983dec14fa8e2bcb2d Mon Sep 17 00:00:00 2001 From: Parth Gandhi Date: Tue, 20 Aug 2024 18:02:18 -0400 Subject: [PATCH] Remove redundant distinct aggregations --- .../io/trino/sql/planner/PlanOptimizers.java | 7 +- .../RemoveRedundantDistinctAggregation.java | 111 ++++++++++++ ...estRemoveRedundantDistinctAggregation.java | 163 ++++++++++++++++++ 3 files changed, 279 insertions(+), 2 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDistinctAggregation.java create mode 100644 core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantDistinctAggregation.java diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index 76327fb3e8f1..d5528f541fb2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -196,6 +196,7 @@ import io.trino.sql.planner.iterative.rule.RemoveEmptyUnionBranches; import io.trino.sql.planner.iterative.rule.RemoveFullSample; import io.trino.sql.planner.iterative.rule.RemoveRedundantDateTrunc; +import io.trino.sql.planner.iterative.rule.RemoveRedundantDistinctAggregation; import io.trino.sql.planner.iterative.rule.RemoveRedundantDistinctLimit; import io.trino.sql.planner.iterative.rule.RemoveRedundantEnforceSingleRowNode; import io.trino.sql.planner.iterative.rule.RemoveRedundantExists; @@ -461,7 +462,8 @@ public PlanOptimizers( new PruneOrderByInAggregation(metadata), new RewriteSpatialPartitioningAggregation(plannerContext), new SimplifyCountOverConstant(plannerContext), - new PreAggregateCaseAggregations(plannerContext))) + new PreAggregateCaseAggregations(plannerContext), + new RemoveRedundantDistinctAggregation())) .build()), // MergeUnion and related projection pruning rules must run before limit pushdown rules, otherwise // an intermediate limit node will prevent unions from being merged later on @@ -596,7 +598,8 @@ public PlanOptimizers( new RemoveEmptyExceptBranches(), new PushFilterIntoValues(plannerContext), // must run after de-correlation new ReplaceJoinOverConstantWithProject(), - new TransformFilteringSemiJoinToInnerJoin())), // must run after PredicatePushDown + new TransformFilteringSemiJoinToInnerJoin(), // must run after PredicatePushDown + new RemoveRedundantDistinctAggregation())), // must also be run after TransformFilteringSemiJoinToInnerJoin new IterativeOptimizer( plannerContext, ruleStats, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDistinctAggregation.java new file mode 100644 index 000000000000..4b51c9973c26 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDistinctAggregation.java @@ -0,0 +1,111 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.iterative.rule; + +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.sql.ir.Expression; +import io.trino.sql.ir.Reference; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.iterative.Lookup; +import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.planner.plan.AggregationNode; +import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.FilterNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.ProjectNode; + +import java.util.HashSet; +import java.util.Set; + +import static io.trino.sql.planner.plan.Patterns.aggregation; + +/** + * Removes DISTINCT only aggregation, when the input source is already distinct over a subset of + * the grouping keys as a result of another aggregation. + * + * Given: + *
+ * - Aggregate[keys = [a, max]]
+ *   - Aggregate[keys = [a]]
+ *     max := max(b)
+ * 
+ *

+ * Produces: + *

+ *   - Aggregate[keys = [a]]
+ *     max := max(b)
+ * 
+ */ +public class RemoveRedundantDistinctAggregation + implements Rule +{ + private static final Pattern PATTERN = aggregation() + .matching(RemoveRedundantDistinctAggregation::isDistinctOnlyAggregation); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(AggregationNode aggregationNode, Captures captures, Context context) + { + Lookup lookup = context.getLookup(); + if (isDistinctOverGroupingKeys(lookup.resolve(aggregationNode.getSource()), lookup, new HashSet<>(aggregationNode.getGroupingKeys()))) { + return Result.ofPlanNode(aggregationNode.getSource()); + } + else { + return Result.empty(); + } + } + + private static boolean isDistinctOnlyAggregation(AggregationNode node) + { + return node.producesDistinctRows() && node.getGroupingSetCount() == 1; + } + + private static boolean isDistinctOverGroupingKeys(PlanNode node, Lookup lookup, Set parentSymbols) + { + return switch (node) { + case AggregationNode aggregationNode -> + aggregationNode.getGroupingSets().getGroupingSetCount() == 1 && parentSymbols.containsAll(aggregationNode.getGroupingSets().getGroupingKeys()); + + // Project nodes introduce new symbols for computed expressions, and therefore end up preserving distinctness + // between the distinct aggregation and the child aggregation nodes so long as all child aggregation keys + // remain present (without transformation by the project) in the distinct aggregation grouping keys + case ProjectNode projectNode -> + isDistinctOverGroupingKeys(lookup.resolve(projectNode.getSource()), lookup, translateProjectReferences(projectNode, parentSymbols)); + + // Filter nodes end up preserving distinctness over the input source + case FilterNode filterNode -> + isDistinctOverGroupingKeys(lookup.resolve(filterNode.getSource()), lookup, parentSymbols); + case null, default -> false; + }; + } + + private static Set translateProjectReferences(ProjectNode projectNode, Set groupingKeys) + { + Set translated = new HashSet<>(); + Assignments assignments = projectNode.getAssignments(); + for (Symbol parentSymbol : groupingKeys) { + Expression expression = assignments.get(parentSymbol); + if (expression instanceof Reference reference) { + translated.add(Symbol.from(reference)); + } + } + return translated; + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantDistinctAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantDistinctAggregation.java new file mode 100644 index 000000000000..b4d4c4258c26 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantDistinctAggregation.java @@ -0,0 +1,163 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TestingFunctionResolution; +import io.trino.spi.function.OperatorType; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Constant; +import io.trino.sql.ir.Reference; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; +import io.trino.sql.planner.iterative.rule.test.PlanBuilder; +import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.PlanNode; +import org.junit.jupiter.api.Test; + +import java.util.Optional; + +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; +import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; +import static io.trino.sql.planner.assertions.PlanMatchPattern.project; +import static io.trino.sql.planner.assertions.PlanMatchPattern.singleGroupingSet; +import static io.trino.sql.planner.assertions.PlanMatchPattern.values; +import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; + +public class TestRemoveRedundantDistinctAggregation + extends BaseRuleTest +{ + private static final ResolvedFunction ADD_INTEGER = new TestingFunctionResolution().resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER)); + + @Test + public void testRemoveDistinctAggregationForGroupedAggregation() + { + tester().assertThat(new RemoveRedundantDistinctAggregation()) + .on(TestRemoveRedundantDistinctAggregation::distinctWithGroupBy) + .matches( + aggregation( + singleGroupingSet("value"), + ImmutableMap.of(), + Optional.empty(), + SINGLE, + values("value"))); + } + + @Test + public void testRemoveDistinctAggregationForGroupedAggregationDoesNotFireOnGroupingKeyNotPresentInDistinct() + { + tester().assertThat(new RemoveRedundantDistinctAggregation()) + .on(TestRemoveRedundantDistinctAggregation::distinctWithGroupByWithNonMatchingKeySubset) + .doesNotFire(); + } + + @Test + public void testRemoveDistinctAggregationForGroupedAggregationWithProjectAndFilter() + { + tester().assertThat(new RemoveRedundantDistinctAggregation()) + .on(TestRemoveRedundantDistinctAggregation::distinctWithFilterWithProjectWithGroupBy) + .matches( + filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)), + project(aggregation(singleGroupingSet("value"), + ImmutableMap.of(), + Optional.empty(), + SINGLE, + values("value"))))); + } + + @Test + public void testRemoveDistinctAggregationForGroupedAggregationDoesNotFireForNonIdentityProjectionsOnChildGroupingKeys() + { + tester().assertThat(new RemoveRedundantDistinctAggregation()) + .on(TestRemoveRedundantDistinctAggregation::distinctWithFilterWithNonIdentityProjectOnChildGroupingKeys) + .doesNotFire(); + } + + @Test + public void testRemoveDistinctAggregationForGroupedAggregationWithSymbolReference() + { + tester().assertThat(new RemoveRedundantDistinctAggregation()) + .on(TestRemoveRedundantDistinctAggregation::distinctWithSymbolReferenceWithGroupBy) + .matches( + filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)), + project(aggregation(singleGroupingSet("value"), + ImmutableMap.of(), + Optional.empty(), + SINGLE, + values("value"))))); + } + + private static PlanNode distinctWithGroupBy(PlanBuilder p) + { + Symbol value = p.symbol("value", INTEGER); + return p.aggregation(b -> b + .singleGroupingSet(value) + .source(p.aggregation(builder -> builder + .singleGroupingSet(value) + .source(p.values(value))))); + } + + private static PlanNode distinctWithGroupByWithNonMatchingKeySubset(PlanBuilder p) + { + Symbol value = p.symbol("value", INTEGER); + Symbol value2 = p.symbol("value2", INTEGER); + return p.aggregation(b -> b + .singleGroupingSet(value) + .source(p.aggregation(builder -> builder + .singleGroupingSet(value, value2) + .source(p.values(value, value2))))); + } + + private static PlanNode distinctWithFilterWithProjectWithGroupBy(PlanBuilder p) + { + Symbol value = p.symbol("value", INTEGER); + return p.aggregation(b -> b + .singleGroupingSet(value) + .source(p.filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)), + p.project(Assignments.builder().putIdentity(value).build(), + p.aggregation(builder -> builder + .singleGroupingSet(value) + .source(p.values(value))))))); + } + + private static PlanNode distinctWithFilterWithNonIdentityProjectOnChildGroupingKeys(PlanBuilder p) + { + Symbol value = p.symbol("value", INTEGER); + return p.aggregation(b -> b + .singleGroupingSet(value) + .source(p.filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)), + p.project(Assignments.builder().put(value, new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 2L)))).build(), + p.aggregation(builder -> builder + .singleGroupingSet(value) + .source(p.values(value))))))); + } + + private static PlanNode distinctWithSymbolReferenceWithGroupBy(PlanBuilder p) + { + Symbol value = p.symbol("value", INTEGER); + Symbol value2 = p.symbol("value2", INTEGER); + return p.aggregation(b -> b + .singleGroupingSet(value2) + .source(p.filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)), + p.project(Assignments.builder().put(value2, value.toSymbolReference()).build(), + p.aggregation(builder -> builder + .singleGroupingSet(value) + .source(p.values(value))))))); + } +}