Skip to content

Commit

Permalink
Remove redundant distinct aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
pgandhi999 committed Sep 10, 2024
1 parent 64e5394 commit 81bc98b
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@
import io.trino.sql.planner.iterative.rule.RemoveEmptyUnionBranches;
import io.trino.sql.planner.iterative.rule.RemoveFullSample;
import io.trino.sql.planner.iterative.rule.RemoveRedundantDateTrunc;
import io.trino.sql.planner.iterative.rule.RemoveRedundantDistinctAggregation;
import io.trino.sql.planner.iterative.rule.RemoveRedundantDistinctLimit;
import io.trino.sql.planner.iterative.rule.RemoveRedundantEnforceSingleRowNode;
import io.trino.sql.planner.iterative.rule.RemoveRedundantExists;
Expand Down Expand Up @@ -461,7 +462,8 @@ public PlanOptimizers(
new PruneOrderByInAggregation(metadata),
new RewriteSpatialPartitioningAggregation(plannerContext),
new SimplifyCountOverConstant(plannerContext),
new PreAggregateCaseAggregations(plannerContext)))
new PreAggregateCaseAggregations(plannerContext),
new RemoveRedundantDistinctAggregation()))
.build()),
// MergeUnion and related projection pruning rules must run before limit pushdown rules, otherwise
// an intermediate limit node will prevent unions from being merged later on
Expand Down Expand Up @@ -596,7 +598,8 @@ public PlanOptimizers(
new RemoveEmptyExceptBranches(),
new PushFilterIntoValues(plannerContext), // must run after de-correlation
new ReplaceJoinOverConstantWithProject(),
new TransformFilteringSemiJoinToInnerJoin())), // must run after PredicatePushDown
new TransformFilteringSemiJoinToInnerJoin(), // must run after PredicatePushDown
new RemoveRedundantDistinctAggregation())), // must also be run after TransformFilteringSemiJoinToInnerJoin
new IterativeOptimizer(
plannerContext,
ruleStats,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.iterative.rule;

import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;

import java.util.HashSet;
import java.util.Set;

import static io.trino.sql.planner.plan.Patterns.aggregation;

/**
 * Drops an aggregation that only de-duplicates its input when that input is already
 * distinct over a subset of the aggregation's grouping keys, because another
 * aggregation sits underneath it (possibly below projections and filters).
 * <p>
 * Given:
 * <pre>
 * - Aggregate[keys = [a, max]]
 *     - Aggregate[keys = [a]]
 *         max := max(b)
 * </pre>
 * <p>
 * Produces:
 * <pre>
 * - Aggregate[keys = [a]]
 *     max := max(b)
 * </pre>
 */
public class RemoveRedundantDistinctAggregation
        implements Rule<AggregationNode>
{
    // Only aggregations accepted by isDistinctOnlyAggregation are considered by this rule
    private static final Pattern<AggregationNode> PATTERN = aggregation()
            .matching(RemoveRedundantDistinctAggregation::isDistinctOnlyAggregation);

    @Override
    public Pattern<AggregationNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * Replaces the matched aggregation with its own source when the source is found to be
     * distinct over the aggregation's grouping keys; otherwise leaves the plan untouched.
     */
    @Override
    public Result apply(AggregationNode aggregationNode, Captures captures, Context context)
    {
        Lookup lookup = context.getLookup();
        PlanNode resolvedSource = lookup.resolve(aggregationNode.getSource());
        Set<Symbol> distinctKeys = new HashSet<>(aggregationNode.getGroupingKeys());

        if (!isDistinctOverGroupingKeys(resolvedSource, lookup, distinctKeys)) {
            return Result.empty();
        }
        // The (unresolved) source is returned as-is; the distinct aggregation simply disappears
        return Result.ofPlanNode(aggregationNode.getSource());
    }

    /**
     * Accepts aggregations that report producing distinct rows and have exactly one grouping set.
     */
    private static boolean isDistinctOnlyAggregation(AggregationNode node)
    {
        if (!node.producesDistinctRows()) {
            return false;
        }
        return node.getGroupingSetCount() == 1;
    }

    /**
     * Walks down from {@code node} through project and filter nodes until an aggregation
     * (or any other node type) is reached.
     *
     * @param node the already-resolved node to start from
     * @param lookup used to resolve each source on the way down
     * @param parentSymbols grouping keys of the distinct aggregation, expressed in terms of {@code node}'s outputs
     * @return {@code true} only if an aggregation with a single grouping set is reached and all of
     * its grouping keys are among the (translated) parent symbols
     */
    private static boolean isDistinctOverGroupingKeys(PlanNode node, Lookup lookup, Set<Symbol> parentSymbols)
    {
        PlanNode current = node;
        Set<Symbol> requiredSymbols = parentSymbols;

        while (true) {
            switch (current) {
                case AggregationNode childAggregation -> {
                    return childAggregation.getGroupingSets().getGroupingSetCount() == 1
                            && requiredSymbols.containsAll(childAggregation.getGroupingSets().getGroupingKeys());
                }

                // Project nodes introduce new symbols for computed expressions, so distinctness is preserved
                // between the distinct aggregation and the child aggregation as long as every child grouping
                // key is still present, untransformed by the project, among the distinct aggregation's keys
                case ProjectNode projectNode -> {
                    PlanNode projectSource = lookup.resolve(projectNode.getSource());
                    requiredSymbols = translateProjectReferences(projectNode, requiredSymbols);
                    current = projectSource;
                }

                // Filter nodes only remove rows, so the input's distinctness carries over
                case FilterNode filterNode -> current = lookup.resolve(filterNode.getSource());

                case null, default -> {
                    return false;
                }
            }
        }
    }

    /**
     * Maps grouping keys across a projection: for every key whose assigned expression is a plain
     * {@link Reference}, the referenced symbol is collected. Keys whose assigned expression is not a
     * {@link Reference} contribute nothing to the result.
     */
    private static Set<Symbol> translateProjectReferences(ProjectNode projectNode, Set<Symbol> groupingKeys)
    {
        Assignments assignments = projectNode.getAssignments();
        Set<Symbol> sourceSymbols = new HashSet<>();
        for (Symbol groupingKey : groupingKeys) {
            Expression assigned = assignments.get(groupingKey);
            if (!(assigned instanceof Reference reference)) {
                continue;
            }
            sourceSymbols.add(Symbol.from(reference));
        }
        return sourceSymbols;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.function.OperatorType;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import org.junit.jupiter.api.Test;

import java.util.Optional;

import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN;
import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation;
import static io.trino.sql.planner.assertions.PlanMatchPattern.filter;
import static io.trino.sql.planner.assertions.PlanMatchPattern.project;
import static io.trino.sql.planner.assertions.PlanMatchPattern.singleGroupingSet;
import static io.trino.sql.planner.assertions.PlanMatchPattern.values;
import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE;

/**
 * Rule tests for {@link RemoveRedundantDistinctAggregation}. Each test builds a small plan
 * through one of the private fixture methods below and checks either the rewritten plan
 * shape or that the rule does not fire.
 */
public class TestRemoveRedundantDistinctAggregation
        extends BaseRuleTest
{
    private static final ResolvedFunction ADD_INTEGER = new TestingFunctionResolution().resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER));

    /**
     * Distinct aggregation directly over an aggregation with the same key: only the inner aggregation remains.
     */
    @Test
    public void testRemoveDistinctAggregationForGroupedAggregation()
    {
        tester().assertThat(new RemoveRedundantDistinctAggregation())
                .on(TestRemoveRedundantDistinctAggregation::distinctWithGroupBy)
                .matches(
                        aggregation(
                                singleGroupingSet("value"),
                                ImmutableMap.of(),
                                Optional.empty(),
                                SINGLE,
                                values("value")));
    }

    /**
     * The child groups by a key that the distinct aggregation does not group by, so the rule must not fire.
     */
    @Test
    public void testRemoveDistinctAggregationForGroupedAggregationDoesNotFireOnGroupingKeyNotPresentInDistinct()
    {
        tester().assertThat(new RemoveRedundantDistinctAggregation())
                .on(TestRemoveRedundantDistinctAggregation::distinctWithGroupByWithNonMatchingKeySubset)
                .doesNotFire();
    }

    /**
     * A filter and an identity projection between the two aggregations do not prevent the rewrite.
     */
    @Test
    public void testRemoveDistinctAggregationForGroupedAggregationWithProjectAndFilter()
    {
        tester().assertThat(new RemoveRedundantDistinctAggregation())
                .on(TestRemoveRedundantDistinctAggregation::distinctWithFilterWithProjectWithGroupBy)
                .matches(
                        filter(
                                constantPredicate(),
                                project(
                                        aggregation(
                                                singleGroupingSet("value"),
                                                ImmutableMap.of(),
                                                Optional.empty(),
                                                SINGLE,
                                                values("value")))));
    }

    /**
     * When the projection assigns a computed expression to the distinct key, the rule must not fire.
     */
    @Test
    public void testRemoveDistinctAggregationForGroupedAggregationDoesNotFireForNonIdentityProjectionsOnChildGroupingKeys()
    {
        tester().assertThat(new RemoveRedundantDistinctAggregation())
                .on(TestRemoveRedundantDistinctAggregation::distinctWithFilterWithNonIdentityProjectOnChildGroupingKeys)
                .doesNotFire();
    }

    /**
     * The projection renames the child key through a symbol reference; the rewrite still applies.
     */
    @Test
    public void testRemoveDistinctAggregationForGroupedAggregationWithSymbolReference()
    {
        tester().assertThat(new RemoveRedundantDistinctAggregation())
                .on(TestRemoveRedundantDistinctAggregation::distinctWithSymbolReferenceWithGroupBy)
                .matches(
                        filter(
                                constantPredicate(),
                                project(
                                        aggregation(
                                                singleGroupingSet("value"),
                                                ImmutableMap.of(),
                                                Optional.empty(),
                                                SINGLE,
                                                values("value")))));
    }

    // Fixture: Aggregate[value] over Aggregate[value] over Values[value]
    private static PlanNode distinctWithGroupBy(PlanBuilder p)
    {
        Symbol key = p.symbol("value", INTEGER);
        return p.aggregation(distinct -> distinct
                .singleGroupingSet(key)
                .source(groupedBy(p, key)));
    }

    // Fixture: Aggregate[value] over Aggregate[value, value2] over Values[value, value2]
    private static PlanNode distinctWithGroupByWithNonMatchingKeySubset(PlanBuilder p)
    {
        Symbol key = p.symbol("value", INTEGER);
        Symbol extraKey = p.symbol("value2", INTEGER);
        return p.aggregation(distinct -> distinct
                .singleGroupingSet(key)
                .source(p.aggregation(grouped -> grouped
                        .singleGroupingSet(key, extraKey)
                        .source(p.values(key, extraKey)))));
    }

    // Fixture: Aggregate[value] over Filter over identity Project over Aggregate[value]
    private static PlanNode distinctWithFilterWithProjectWithGroupBy(PlanBuilder p)
    {
        Symbol key = p.symbol("value", INTEGER);
        return p.aggregation(distinct -> distinct
                .singleGroupingSet(key)
                .source(p.filter(
                        constantPredicate(),
                        p.project(
                                Assignments.builder().putIdentity(key).build(),
                                groupedBy(p, key)))));
    }

    // Fixture: Aggregate[value] over Filter over Project[value := x + 2] over Aggregate[value]
    // NOTE(review): the projected expression references "x", which the child aggregation in this
    // fixture does not output; the rule only inspects the expression's kind, but confirm this is intended
    private static PlanNode distinctWithFilterWithNonIdentityProjectOnChildGroupingKeys(PlanBuilder p)
    {
        Symbol key = p.symbol("value", INTEGER);
        return p.aggregation(distinct -> distinct
                .singleGroupingSet(key)
                .source(p.filter(
                        constantPredicate(),
                        p.project(
                                Assignments.builder()
                                        .put(key, new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 2L))))
                                        .build(),
                                groupedBy(p, key)))));
    }

    // Fixture: Aggregate[value2] over Filter over Project[value2 := value] over Aggregate[value]
    private static PlanNode distinctWithSymbolReferenceWithGroupBy(PlanBuilder p)
    {
        Symbol key = p.symbol("value", INTEGER);
        Symbol renamedKey = p.symbol("value2", INTEGER);
        return p.aggregation(distinct -> distinct
                .singleGroupingSet(renamedKey)
                .source(p.filter(
                        constantPredicate(),
                        p.project(
                                Assignments.builder().put(renamedKey, key.toSymbolReference()).build(),
                                groupedBy(p, key)))));
    }

    // The constant comparison 6 > 5 used as the filter predicate in fixtures and expected plans
    private static Comparison constantPredicate()
    {
        return new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L));
    }

    // Child aggregation grouping by a single key directly over a values node producing that key
    private static PlanNode groupedBy(PlanBuilder p, Symbol key)
    {
        return p.aggregation(grouped -> grouped
                .singleGroupingSet(key)
                .source(p.values(key)));
    }
}

0 comments on commit 81bc98b

Please sign in to comment.