Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a new Optimizer Rule to Remove Distinct Aggregation for Queries containing Grouped Aggregation Operators #23087

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@
import io.trino.sql.planner.iterative.rule.RemoveEmptyUnionBranches;
import io.trino.sql.planner.iterative.rule.RemoveFullSample;
import io.trino.sql.planner.iterative.rule.RemoveRedundantDateTrunc;
import io.trino.sql.planner.iterative.rule.RemoveRedundantDistinctAggregation;
import io.trino.sql.planner.iterative.rule.RemoveRedundantDistinctLimit;
import io.trino.sql.planner.iterative.rule.RemoveRedundantEnforceSingleRowNode;
import io.trino.sql.planner.iterative.rule.RemoveRedundantExists;
Expand Down Expand Up @@ -461,7 +462,8 @@ public PlanOptimizers(
new PruneOrderByInAggregation(metadata),
new RewriteSpatialPartitioningAggregation(plannerContext),
new SimplifyCountOverConstant(plannerContext),
new PreAggregateCaseAggregations(plannerContext)))
new PreAggregateCaseAggregations(plannerContext),
new RemoveRedundantDistinctAggregation()))
.build()),
// MergeUnion and related projection pruning rules must run before limit pushdown rules, otherwise
// an intermediate limit node will prevent unions from being merged later on
Expand Down Expand Up @@ -596,7 +598,8 @@ public PlanOptimizers(
new RemoveEmptyExceptBranches(),
new PushFilterIntoValues(plannerContext), // must run after de-correlation
new ReplaceJoinOverConstantWithProject(),
new TransformFilteringSemiJoinToInnerJoin())), // must run after PredicatePushDown
new TransformFilteringSemiJoinToInnerJoin(), // must run after PredicatePushDown
new RemoveRedundantDistinctAggregation())), // must also be run after TransformFilteringSemiJoinToInnerJoin
new IterativeOptimizer(
plannerContext,
ruleStats,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.iterative.rule;

import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;

import java.util.HashSet;
import java.util.Set;

import static io.trino.sql.planner.plan.Patterns.aggregation;

/**
* Removes DISTINCT only aggregation, when the input source is already distinct over a subset of
* the grouping keys as a result of another aggregation.
*
* Given:
* <pre>
* - Aggregate[keys = [a, max]]
* - Aggregate[keys = [a]]
* max := max(b)
* </pre>
* <p>
* Produces:
* <pre>
* - Aggregate[keys = [a]]
* max := max(b)
* </pre>
*/
public class RemoveRedundantDistinctAggregation
pgandhi999 marked this conversation as resolved.
Show resolved Hide resolved
implements Rule<AggregationNode>
{
private static final Pattern<AggregationNode> PATTERN = aggregation()
.matching(RemoveRedundantDistinctAggregation::isDistinctOnlyAggregation);

@Override
public Pattern<AggregationNode> getPattern()
{
return PATTERN;
}

@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context)
{
Lookup lookup = context.getLookup();
if (isDistinctOverGroupingKeys(lookup.resolve(aggregationNode.getSource()), lookup, new HashSet<>(aggregationNode.getGroupingKeys()))) {
return Result.ofPlanNode(aggregationNode.getSource());
}
else {
return Result.empty();
}
}

private static boolean isDistinctOnlyAggregation(AggregationNode node)
{
return node.producesDistinctRows() && node.getGroupingSetCount() == 1;
}

private static boolean isDistinctOverGroupingKeys(PlanNode node, Lookup lookup, Set<Symbol> parentSymbols)
{
return switch (node) {
case AggregationNode aggregationNode ->
aggregationNode.getGroupingSets().getGroupingSetCount() == 1 && parentSymbols.containsAll(aggregationNode.getGroupingSets().getGroupingKeys());

// Project nodes introduce new symbols for computed expressions, and therefore end up preserving distinctness
// between the distinct aggregation and the child aggregation nodes so long as all child aggregation keys
// remain present (without transformation by the project) in the distinct aggregation grouping keys
case ProjectNode projectNode ->
isDistinctOverGroupingKeys(lookup.resolve(projectNode.getSource()), lookup, translateProjectReferences(projectNode, parentSymbols));

// Filter nodes end up preserving distinctness over the input source
case FilterNode filterNode ->
isDistinctOverGroupingKeys(lookup.resolve(filterNode.getSource()), lookup, parentSymbols);
case null, default -> false;
};
}

private static Set<Symbol> translateProjectReferences(ProjectNode projectNode, Set<Symbol> groupingKeys)
{
Set<Symbol> translated = new HashSet<>();
Assignments assignments = projectNode.getAssignments();
for (Symbol parentSymbol : groupingKeys) {
Expression expression = assignments.get(parentSymbol);
if (expression instanceof Reference reference) {
translated.add(Symbol.from(reference));
}
}
return translated;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.function.OperatorType;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import org.junit.jupiter.api.Test;

import java.util.Optional;

import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN;
import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation;
import static io.trino.sql.planner.assertions.PlanMatchPattern.filter;
import static io.trino.sql.planner.assertions.PlanMatchPattern.project;
import static io.trino.sql.planner.assertions.PlanMatchPattern.singleGroupingSet;
import static io.trino.sql.planner.assertions.PlanMatchPattern.values;
import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE;

public class TestRemoveRedundantDistinctAggregation
extends BaseRuleTest
{
private static final ResolvedFunction ADD_INTEGER = new TestingFunctionResolution().resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER));

@Test
public void testRemoveDistinctAggregationForGroupedAggregation()
{
tester().assertThat(new RemoveRedundantDistinctAggregation())
.on(TestRemoveRedundantDistinctAggregation::distinctWithGroupBy)
.matches(
aggregation(
singleGroupingSet("value"),
ImmutableMap.of(),
Optional.empty(),
SINGLE,
values("value")));
}

@Test
public void testRemoveDistinctAggregationForGroupedAggregationDoesNotFireOnGroupingKeyNotPresentInDistinct()
{
tester().assertThat(new RemoveRedundantDistinctAggregation())
.on(TestRemoveRedundantDistinctAggregation::distinctWithGroupByWithNonMatchingKeySubset)
.doesNotFire();
}

@Test
public void testRemoveDistinctAggregationForGroupedAggregationWithProjectAndFilter()
{
tester().assertThat(new RemoveRedundantDistinctAggregation())
.on(TestRemoveRedundantDistinctAggregation::distinctWithFilterWithProjectWithGroupBy)
.matches(
filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)),
project(aggregation(singleGroupingSet("value"),
ImmutableMap.of(),
Optional.empty(),
SINGLE,
values("value")))));
}

@Test
public void testRemoveDistinctAggregationForGroupedAggregationDoesNotFireForNonIdentityProjectionsOnChildGroupingKeys()
{
tester().assertThat(new RemoveRedundantDistinctAggregation())
.on(TestRemoveRedundantDistinctAggregation::distinctWithFilterWithNonIdentityProjectOnChildGroupingKeys)
.doesNotFire();
}

@Test
public void testRemoveDistinctAggregationForGroupedAggregationWithSymbolReference()
{
tester().assertThat(new RemoveRedundantDistinctAggregation())
.on(TestRemoveRedundantDistinctAggregation::distinctWithSymbolReferenceWithGroupBy)
.matches(
filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)),
project(aggregation(singleGroupingSet("value"),
ImmutableMap.of(),
Optional.empty(),
SINGLE,
values("value")))));
}

private static PlanNode distinctWithGroupBy(PlanBuilder p)
{
Symbol value = p.symbol("value", INTEGER);
return p.aggregation(b -> b
.singleGroupingSet(value)
.source(p.aggregation(builder -> builder
.singleGroupingSet(value)
.source(p.values(value)))));
}

private static PlanNode distinctWithGroupByWithNonMatchingKeySubset(PlanBuilder p)
{
Symbol value = p.symbol("value", INTEGER);
Symbol value2 = p.symbol("value2", INTEGER);
return p.aggregation(b -> b
.singleGroupingSet(value)
.source(p.aggregation(builder -> builder
.singleGroupingSet(value, value2)
.source(p.values(value, value2)))));
}

private static PlanNode distinctWithFilterWithProjectWithGroupBy(PlanBuilder p)
{
Symbol value = p.symbol("value", INTEGER);
return p.aggregation(b -> b
.singleGroupingSet(value)
.source(p.filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)),
p.project(Assignments.builder().putIdentity(value).build(),
p.aggregation(builder -> builder
.singleGroupingSet(value)
.source(p.values(value)))))));
}

private static PlanNode distinctWithFilterWithNonIdentityProjectOnChildGroupingKeys(PlanBuilder p)
{
Symbol value = p.symbol("value", INTEGER);
return p.aggregation(b -> b
.singleGroupingSet(value)
.source(p.filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)),
p.project(Assignments.builder().put(value, new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 2L)))).build(),
p.aggregation(builder -> builder
.singleGroupingSet(value)
.source(p.values(value)))))));
}

private static PlanNode distinctWithSymbolReferenceWithGroupBy(PlanBuilder p)
{
Symbol value = p.symbol("value", INTEGER);
Symbol value2 = p.symbol("value2", INTEGER);
return p.aggregation(b -> b
.singleGroupingSet(value2)
.source(p.filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)),
p.project(Assignments.builder().put(value2, value.toSymbolReference()).build(),
p.aggregation(builder -> builder
.singleGroupingSet(value)
.source(p.values(value)))))));
}
}
Loading