-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove redundant distinct aggregations
- Loading branch information
1 parent
64e5394
commit 81bc98b
Showing
3 changed files
with
279 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
111 changes: 111 additions & 0 deletions
111
...src/main/java/io/trino/sql/planner/iterative/rule/RemoveRedundantDistinctAggregation.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package io.trino.sql.planner.iterative.rule; | ||
|
||
import io.trino.matching.Captures; | ||
import io.trino.matching.Pattern; | ||
import io.trino.sql.ir.Expression; | ||
import io.trino.sql.ir.Reference; | ||
import io.trino.sql.planner.Symbol; | ||
import io.trino.sql.planner.iterative.Lookup; | ||
import io.trino.sql.planner.iterative.Rule; | ||
import io.trino.sql.planner.plan.AggregationNode; | ||
import io.trino.sql.planner.plan.Assignments; | ||
import io.trino.sql.planner.plan.FilterNode; | ||
import io.trino.sql.planner.plan.PlanNode; | ||
import io.trino.sql.planner.plan.ProjectNode; | ||
|
||
import java.util.HashSet; | ||
import java.util.Set; | ||
|
||
import static io.trino.sql.planner.plan.Patterns.aggregation; | ||
|
||
/** | ||
* Removes DISTINCT only aggregation, when the input source is already distinct over a subset of | ||
* the grouping keys as a result of another aggregation. | ||
* | ||
* Given: | ||
* <pre> | ||
* - Aggregate[keys = [a, max]] | ||
* - Aggregate[keys = [a]] | ||
* max := max(b) | ||
* </pre> | ||
* <p> | ||
* Produces: | ||
* <pre> | ||
* - Aggregate[keys = [a]] | ||
* max := max(b) | ||
* </pre> | ||
*/ | ||
public class RemoveRedundantDistinctAggregation | ||
implements Rule<AggregationNode> | ||
{ | ||
private static final Pattern<AggregationNode> PATTERN = aggregation() | ||
.matching(RemoveRedundantDistinctAggregation::isDistinctOnlyAggregation); | ||
|
||
@Override | ||
public Pattern<AggregationNode> getPattern() | ||
{ | ||
return PATTERN; | ||
} | ||
|
||
@Override | ||
public Result apply(AggregationNode aggregationNode, Captures captures, Context context) | ||
{ | ||
Lookup lookup = context.getLookup(); | ||
if (isDistinctOverGroupingKeys(lookup.resolve(aggregationNode.getSource()), lookup, new HashSet<>(aggregationNode.getGroupingKeys()))) { | ||
return Result.ofPlanNode(aggregationNode.getSource()); | ||
} | ||
else { | ||
return Result.empty(); | ||
} | ||
} | ||
|
||
private static boolean isDistinctOnlyAggregation(AggregationNode node) | ||
{ | ||
return node.producesDistinctRows() && node.getGroupingSetCount() == 1; | ||
} | ||
|
||
private static boolean isDistinctOverGroupingKeys(PlanNode node, Lookup lookup, Set<Symbol> parentSymbols) | ||
{ | ||
return switch (node) { | ||
case AggregationNode aggregationNode -> | ||
aggregationNode.getGroupingSets().getGroupingSetCount() == 1 && parentSymbols.containsAll(aggregationNode.getGroupingSets().getGroupingKeys()); | ||
|
||
// Project nodes introduce new symbols for computed expressions, and therefore end up preserving distinctness | ||
// between the distinct aggregation and the child aggregation nodes so long as all child aggregation keys | ||
// remain present (without transformation by the project) in the distinct aggregation grouping keys | ||
case ProjectNode projectNode -> | ||
isDistinctOverGroupingKeys(lookup.resolve(projectNode.getSource()), lookup, translateProjectReferences(projectNode, parentSymbols)); | ||
|
||
// Filter nodes end up preserving distinctness over the input source | ||
case FilterNode filterNode -> | ||
isDistinctOverGroupingKeys(lookup.resolve(filterNode.getSource()), lookup, parentSymbols); | ||
case null, default -> false; | ||
}; | ||
} | ||
|
||
private static Set<Symbol> translateProjectReferences(ProjectNode projectNode, Set<Symbol> groupingKeys) | ||
{ | ||
Set<Symbol> translated = new HashSet<>(); | ||
Assignments assignments = projectNode.getAssignments(); | ||
for (Symbol parentSymbol : groupingKeys) { | ||
Expression expression = assignments.get(parentSymbol); | ||
if (expression instanceof Reference reference) { | ||
translated.add(Symbol.from(reference)); | ||
} | ||
} | ||
return translated; | ||
} | ||
} |
163 changes: 163 additions & 0 deletions
163
...test/java/io/trino/sql/planner/iterative/rule/TestRemoveRedundantDistinctAggregation.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package io.trino.sql.planner.iterative.rule; | ||
|
||
import com.google.common.collect.ImmutableList; | ||
import com.google.common.collect.ImmutableMap; | ||
import io.trino.metadata.ResolvedFunction; | ||
import io.trino.metadata.TestingFunctionResolution; | ||
import io.trino.spi.function.OperatorType; | ||
import io.trino.sql.ir.Call; | ||
import io.trino.sql.ir.Comparison; | ||
import io.trino.sql.ir.Constant; | ||
import io.trino.sql.ir.Reference; | ||
import io.trino.sql.planner.Symbol; | ||
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; | ||
import io.trino.sql.planner.iterative.rule.test.PlanBuilder; | ||
import io.trino.sql.planner.plan.Assignments; | ||
import io.trino.sql.planner.plan.PlanNode; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import java.util.Optional; | ||
|
||
import static io.trino.spi.type.IntegerType.INTEGER; | ||
import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; | ||
import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; | ||
import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; | ||
import static io.trino.sql.planner.assertions.PlanMatchPattern.project; | ||
import static io.trino.sql.planner.assertions.PlanMatchPattern.singleGroupingSet; | ||
import static io.trino.sql.planner.assertions.PlanMatchPattern.values; | ||
import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; | ||
|
||
public class TestRemoveRedundantDistinctAggregation | ||
extends BaseRuleTest | ||
{ | ||
private static final ResolvedFunction ADD_INTEGER = new TestingFunctionResolution().resolveOperator(OperatorType.ADD, ImmutableList.of(INTEGER, INTEGER)); | ||
|
||
@Test | ||
public void testRemoveDistinctAggregationForGroupedAggregation() | ||
{ | ||
tester().assertThat(new RemoveRedundantDistinctAggregation()) | ||
.on(TestRemoveRedundantDistinctAggregation::distinctWithGroupBy) | ||
.matches( | ||
aggregation( | ||
singleGroupingSet("value"), | ||
ImmutableMap.of(), | ||
Optional.empty(), | ||
SINGLE, | ||
values("value"))); | ||
} | ||
|
||
@Test | ||
public void testRemoveDistinctAggregationForGroupedAggregationDoesNotFireOnGroupingKeyNotPresentInDistinct() | ||
{ | ||
tester().assertThat(new RemoveRedundantDistinctAggregation()) | ||
.on(TestRemoveRedundantDistinctAggregation::distinctWithGroupByWithNonMatchingKeySubset) | ||
.doesNotFire(); | ||
} | ||
|
||
@Test | ||
public void testRemoveDistinctAggregationForGroupedAggregationWithProjectAndFilter() | ||
{ | ||
tester().assertThat(new RemoveRedundantDistinctAggregation()) | ||
.on(TestRemoveRedundantDistinctAggregation::distinctWithFilterWithProjectWithGroupBy) | ||
.matches( | ||
filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)), | ||
project(aggregation(singleGroupingSet("value"), | ||
ImmutableMap.of(), | ||
Optional.empty(), | ||
SINGLE, | ||
values("value"))))); | ||
} | ||
|
||
@Test | ||
public void testRemoveDistinctAggregationForGroupedAggregationDoesNotFireForNonIdentityProjectionsOnChildGroupingKeys() | ||
{ | ||
tester().assertThat(new RemoveRedundantDistinctAggregation()) | ||
.on(TestRemoveRedundantDistinctAggregation::distinctWithFilterWithNonIdentityProjectOnChildGroupingKeys) | ||
.doesNotFire(); | ||
} | ||
|
||
@Test | ||
public void testRemoveDistinctAggregationForGroupedAggregationWithSymbolReference() | ||
{ | ||
tester().assertThat(new RemoveRedundantDistinctAggregation()) | ||
.on(TestRemoveRedundantDistinctAggregation::distinctWithSymbolReferenceWithGroupBy) | ||
.matches( | ||
filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)), | ||
project(aggregation(singleGroupingSet("value"), | ||
ImmutableMap.of(), | ||
Optional.empty(), | ||
SINGLE, | ||
values("value"))))); | ||
} | ||
|
||
private static PlanNode distinctWithGroupBy(PlanBuilder p) | ||
{ | ||
Symbol value = p.symbol("value", INTEGER); | ||
return p.aggregation(b -> b | ||
.singleGroupingSet(value) | ||
.source(p.aggregation(builder -> builder | ||
.singleGroupingSet(value) | ||
.source(p.values(value))))); | ||
} | ||
|
||
private static PlanNode distinctWithGroupByWithNonMatchingKeySubset(PlanBuilder p) | ||
{ | ||
Symbol value = p.symbol("value", INTEGER); | ||
Symbol value2 = p.symbol("value2", INTEGER); | ||
return p.aggregation(b -> b | ||
.singleGroupingSet(value) | ||
.source(p.aggregation(builder -> builder | ||
.singleGroupingSet(value, value2) | ||
.source(p.values(value, value2))))); | ||
} | ||
|
||
private static PlanNode distinctWithFilterWithProjectWithGroupBy(PlanBuilder p) | ||
{ | ||
Symbol value = p.symbol("value", INTEGER); | ||
return p.aggregation(b -> b | ||
.singleGroupingSet(value) | ||
.source(p.filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)), | ||
p.project(Assignments.builder().putIdentity(value).build(), | ||
p.aggregation(builder -> builder | ||
.singleGroupingSet(value) | ||
.source(p.values(value))))))); | ||
} | ||
|
||
private static PlanNode distinctWithFilterWithNonIdentityProjectOnChildGroupingKeys(PlanBuilder p) | ||
{ | ||
Symbol value = p.symbol("value", INTEGER); | ||
return p.aggregation(b -> b | ||
.singleGroupingSet(value) | ||
.source(p.filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)), | ||
p.project(Assignments.builder().put(value, new Call(ADD_INTEGER, ImmutableList.of(new Reference(INTEGER, "x"), new Constant(INTEGER, 2L)))).build(), | ||
p.aggregation(builder -> builder | ||
.singleGroupingSet(value) | ||
.source(p.values(value))))))); | ||
} | ||
|
||
private static PlanNode distinctWithSymbolReferenceWithGroupBy(PlanBuilder p) | ||
{ | ||
Symbol value = p.symbol("value", INTEGER); | ||
Symbol value2 = p.symbol("value2", INTEGER); | ||
return p.aggregation(b -> b | ||
.singleGroupingSet(value2) | ||
.source(p.filter(new Comparison(GREATER_THAN, new Constant(INTEGER, 6L), new Constant(INTEGER, 5L)), | ||
p.project(Assignments.builder().put(value2, value.toSymbolReference()).build(), | ||
p.aggregation(builder -> builder | ||
.singleGroupingSet(value) | ||
.source(p.values(value))))))); | ||
} | ||
} |