Skip to content

Commit

Permalink
Refactor DistinctAggregationStrategyChooser
Browse files Browse the repository at this point in the history
Make different distinct aggregation strategy choices
exclusive, so that order of optimizer rules does not matter.
  • Loading branch information
lukasz-stec authored and raunaqmorarka committed Jul 26, 2024
1 parent 185e071 commit 9e7c548
Show file tree
Hide file tree
Showing 5 changed files with 255 additions and 191 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsProvider;
import io.trino.cost.TaskCountEstimator;
import io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy;
import io.trino.sql.planner.plan.AggregationNode;

import static io.trino.SystemSessionProperties.distinctAggregationsStrategy;
import static io.trino.SystemSessionProperties.getTaskConcurrency;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.AUTOMATIC;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SINGLE_STEP;
import static io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct.canUseMarkDistinct;
import static io.trino.sql.planner.iterative.rule.OptimizeMixedDistinctAggregations.canUsePreAggregate;
import static java.lang.Double.NaN;
import static java.lang.Double.isNaN;
import static java.util.Objects.requireNonNull;
Expand All @@ -46,43 +54,54 @@ public static DistinctAggregationStrategyChooser createDistinctAggregationStrate

public boolean shouldAddMarkDistinct(AggregationNode aggregationNode, Session session, StatsProvider statsProvider)
{
return !canParallelizeSingleStepDistinctAggregation(aggregationNode, session, statsProvider, MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER);
return chooseMarkDistinctStrategy(aggregationNode, session, statsProvider) == MARK_DISTINCT;
}

public boolean shouldUsePreAggregate(AggregationNode aggregationNode, Session session, StatsProvider statsProvider)
{
if (canParallelizeSingleStepDistinctAggregation(aggregationNode, session, statsProvider, PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER)) {
return false;
}

// mark-distinct is better than pre-aggregate if the number of group-by keys is bigger than 2
// because group-by keys are added to every grouping set and this makes partial aggregation behaves badly
return aggregationNode.getGroupingKeys().size() <= 2;
return chooseMarkDistinctStrategy(aggregationNode, session, statsProvider) == PRE_AGGREGATE;
}

private boolean canParallelizeSingleStepDistinctAggregation(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, int maxOutputRowCountMultiplier)
private DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode aggregationNode, Session session, StatsProvider statsProvider)
{
if (aggregationNode.getGroupingKeys().isEmpty()) {
// global distinct aggregation is computed using a single thread. MarkDistinct will help parallelize the execution.
return false;
DistinctAggregationsStrategy distinctAggregationsStrategy = distinctAggregationsStrategy(session);
if (distinctAggregationsStrategy != AUTOMATIC) {
if (distinctAggregationsStrategy == MARK_DISTINCT && canUseMarkDistinct(aggregationNode)) {
return MARK_DISTINCT;
}
if (distinctAggregationsStrategy == PRE_AGGREGATE && canUsePreAggregate(aggregationNode)) {
return PRE_AGGREGATE;
}
// if the strategy is chosen by the session property but cannot be used, let's fall back to single-step
return SINGLE_STEP;
}

double numberOfDistinctValues = getMinDistinctValueCountEstimate(aggregationNode, statsProvider);
if (isNaN(numberOfDistinctValues)) {
// if the estimate is unknown, use MarkDistinct to avoid query failure
return false;
int maxNumberOfConcurrentThreadsForAggregation = getMaxNumberOfConcurrentThreadsForAggregation(session);

// use single_step if it can be parallelized
// small numberOfDistinctValues reduces the distinct aggregation parallelism, also because the partitioning may be skewed.
// this makes the query underutilize the cluster CPU and possibly concentrate memory on a few nodes.
// single_step alternatives should increase the parallelism at a cost of CPU.
if (!aggregationNode.getGroupingKeys().isEmpty() && // global distinct aggregation is computed using a single thread. Strategies other than single_step will help parallelize the execution.
!isNaN(numberOfDistinctValues) && // if the estimate is unknown, use alternatives to avoid query failure
(numberOfDistinctValues > PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * maxNumberOfConcurrentThreadsForAggregation ||
(numberOfDistinctValues > MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * maxNumberOfConcurrentThreadsForAggregation &&
// if the NDV and the number of grouping keys is small, pre-aggregate is faster than single_step at a cost of CPU
aggregationNode.getGroupingKeys().size() > 2))) {
return SINGLE_STEP;
}

int maxNumberOfConcurrentThreadsForAggregation = getMaxNumberOfConcurrentThreadsForAggregation(session);
if (numberOfDistinctValues <= maxOutputRowCountMultiplier * maxNumberOfConcurrentThreadsForAggregation) {
// small numberOfDistinctValues reduces the distinct aggregation parallelism, also because the partitioning may be skewed.
// This makes query to underutilize the cluster CPU but also to possibly concentrate memory on few nodes.
// MarkDistinct should increase the parallelism at a cost of CPU.
return false;
// mark-distinct is better than pre-aggregate if the number of group-by keys is bigger than 2
// because group-by keys are added to every grouping set and this makes partial aggregation behave badly
if (canUsePreAggregate(aggregationNode) && aggregationNode.getGroupingKeys().size() <= 2) {
return PRE_AGGREGATE;
}
else if (canUseMarkDistinct(aggregationNode)) {
return MARK_DISTINCT;
}

// can parallelize single-step, and single-step distinct is more efficient than alternatives
return true;
// if no strategy found, use single_step by default
return SINGLE_STEP;
}

private int getMaxNumberOfConcurrentThreadsForAggregation(Session session)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
Expand Down Expand Up @@ -68,12 +67,13 @@ public class MultipleDistinctAggregationToMarkDistinct
implements Rule<AggregationNode>
{
private static final Pattern<AggregationNode> PATTERN = aggregation()
.matching(
Predicates.and(
MultipleDistinctAggregationToMarkDistinct::hasNoDistinctWithFilterOrMask,
Predicates.or(
MultipleDistinctAggregationToMarkDistinct::hasMultipleDistincts,
MultipleDistinctAggregationToMarkDistinct::hasMixedDistinctAndNonDistincts)));
.matching(MultipleDistinctAggregationToMarkDistinct::canUseMarkDistinct);

/**
 * Tells whether the given aggregation satisfies the preconditions of this rule.
 * <p>
 * The node qualifies only when {@code hasNoDistinctWithFilterOrMask} holds for it and,
 * in addition, either {@code hasMultipleDistincts} or
 * {@code hasMixedDistinctAndNonDistincts} holds.
 *
 * @param aggregationNode the aggregation to inspect
 * @return {@code true} if the conditions above are met, {@code false} otherwise
 */
public static boolean canUseMarkDistinct(AggregationNode aggregationNode)
{
    // this check is a hard requirement, so bail out first when it fails
    if (!hasNoDistinctWithFilterOrMask(aggregationNode)) {
        return false;
    }
    if (hasMultipleDistincts(aggregationNode)) {
        return true;
    }
    return hasMixedDistinctAndNonDistincts(aggregationNode);
}

private static boolean hasNoDistinctWithFilterOrMask(AggregationNode aggregationNode)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -83,16 +82,18 @@ public class OptimizeMixedDistinctAggregations
private static final CatalogSchemaFunctionName APPROX_DISTINCT_NAME = builtinFunctionName("approx_distinct");

private static final Pattern<AggregationNode> PATTERN = aggregation()
.matching(Predicates.and(
Predicates.or(
// single distinct can be supported in this rule, but it is already supported by SingleDistinctAggregationToGroupBy, which produces simpler plans (without group-id)
OptimizeMixedDistinctAggregations::hasMultipleDistincts,
OptimizeMixedDistinctAggregations::hasMixedDistinctAndNonDistincts),
OptimizeMixedDistinctAggregations::allDistinctAggregationsHaveSingleArgument,
OptimizeMixedDistinctAggregations::noFilters,
OptimizeMixedDistinctAggregations::noMasks,
aggregation -> !aggregation.hasOrderings(),
aggregation -> aggregation.getStep().equals(SINGLE)));
.matching(OptimizeMixedDistinctAggregations::canUsePreAggregate);

/**
 * Tells whether the given aggregation satisfies the preconditions of this rule.
 * <p>
 * All of the following must hold, checked in this order:
 * {@code hasMultipleDistincts} or {@code hasMixedDistinctAndNonDistincts},
 * {@code allDistinctAggregationsHaveSingleArgument}, {@code noFilters}, {@code noMasks},
 * the node has no orderings, and the node's step equals {@code SINGLE}.
 *
 * @param aggregationNode the aggregation to inspect
 * @return {@code true} if every condition above is met, {@code false} otherwise
 */
public static boolean canUsePreAggregate(AggregationNode aggregationNode)
{
    // single distinct can be supported in this rule, but it is already supported by SingleDistinctAggregationToGroupBy, which produces simpler plans (without group-id)
    boolean hasSupportedDistinctShape = hasMultipleDistincts(aggregationNode) || hasMixedDistinctAndNonDistincts(aggregationNode);
    if (!hasSupportedDistinctShape) {
        return false;
    }
    if (!allDistinctAggregationsHaveSingleArgument(aggregationNode)) {
        return false;
    }
    if (!noFilters(aggregationNode) || !noMasks(aggregationNode)) {
        return false;
    }
    if (aggregationNode.hasOrderings()) {
        return false;
    }
    return aggregationNode.getStep().equals(SINGLE);
}

public static boolean hasMultipleDistincts(AggregationNode aggregationNode)
{
Expand Down Expand Up @@ -209,9 +210,9 @@ public Result apply(AggregationNode node, Captures captures, Context context)
Aggregation originalAggregation = entry.getValue();
if (originalAggregation.isDistinct()) {
// for the outer aggregation node, replace distinct aggregation with non-distinct aggregation with FILTER (WHERE group_id=X)
Symbol aggregationInput = Symbol.from(originalAggregation.getArguments().get(0));
Symbol aggregationInput = Symbol.from(originalAggregation.getArguments().getFirst());
Integer groupId = distinctAggregationArgumentToGroupIdMap.get(aggregationInput);
Symbol groupIdFilterSymbol = groupIdFilterSymbolByGroupId.computeIfAbsent(groupId, id -> {
Symbol groupIdFilterSymbol = groupIdFilterSymbolByGroupId.computeIfAbsent(groupId, _ -> {
Symbol filterSymbol = symbolAllocator.newSymbol("gid-filter-" + groupId, BOOLEAN);
groupIdFilters.put(filterSymbol, new Comparison(
EQUAL,
Expand Down
Loading

0 comments on commit 9e7c548

Please sign in to comment.