Skip to content

Commit

Permalink
Enable MultipleDistinctAggregationsToSubqueries
Browse files Browse the repository at this point in the history
Make MultipleDistinctAggregationsToSubqueries fire when
distinct_aggregations_strategy=AUTOMATIC and we can be
confident based on stats that the rule will be beneficial.
Aggregation source is limited to table scan, filter,
and project.
  • Loading branch information
lukasz-stec authored and raunaqmorarka committed Jul 26, 2024
1 parent 9e7c548 commit eee8d4d
Show file tree
Hide file tree
Showing 10 changed files with 884 additions and 155 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,6 @@ public PlanOptimizers(
new RemoveRedundantExists(),
new RemoveRedundantWindow(),
new ImplementFilteredAggregations(),
new MultipleDistinctAggregationsToSubqueries(metadata),
new SingleDistinctAggregationToGroupBy(),
new MergeLimitWithDistinct(),
new PruneCountAggregationOverScalar(metadata),
Expand Down Expand Up @@ -684,10 +683,15 @@ public PlanOptimizers(
new RemoveRedundantIdentityProjections(),
new PushAggregationThroughOuterJoin(),
new ReplaceRedundantJoinWithSource(), // Run this after PredicatePushDown optimizer as it inlines filter constants
// Run this after PredicatePushDown and PushProjectionIntoTableScan as it uses stats, and those two rules may reduce the number of partitions
// and columns we need stats for thus reducing the overhead of reading statistics from the metastore.
new MultipleDistinctAggregationsToSubqueries(taskCountEstimator, metadata),
// Run SingleDistinctAggregationToGroupBy after MultipleDistinctAggregationsToSubqueries to ensure the single column distinct is optimized
new SingleDistinctAggregationToGroupBy(),
new OptimizeMixedDistinctAggregations(plannerContext, taskCountEstimator), // Run this after aggregation pushdown so that multiple distinct aggregations can be pushed into a connector
// It also is run before MultipleDistinctAggregationToMarkDistinct to take precedence if enabled
// It also is run before MultipleDistinctAggregationToMarkDistinct to take precedence if enabled
new ImplementFilteredAggregations(), // DistinctAggregationToGroupBy will add filters if fired
new MultipleDistinctAggregationToMarkDistinct(taskCountEstimator))), // Run this after aggregation pushdown so that multiple distinct aggregations can be pushed into a connector
new MultipleDistinctAggregationToMarkDistinct(taskCountEstimator, metadata))), // Run this after aggregation pushdown so that multiple distinct aggregations can be pushed into a connector
inlineProjections,
simplifyOptimizer, // Re-run the SimplifyExpressions to simplify any recomposed expressions from other optimizations
pushProjectionIntoTableScanOptimizer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,40 @@
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.Session;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsProvider;
import io.trino.cost.TaskCountEstimator;
import io.trino.metadata.Metadata;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnionNode;

import java.util.List;
import java.util.Set;

import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.SystemSessionProperties.distinctAggregationsStrategy;
import static io.trino.SystemSessionProperties.getTaskConcurrency;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.AUTOMATIC;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SINGLE_STEP;
import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES;
import static io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct.canUseMarkDistinct;
import static io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationsToSubqueries.isAggregationCandidateForSplittingToSubqueries;
import static io.trino.sql.planner.iterative.rule.OptimizeMixedDistinctAggregations.canUsePreAggregate;
import static io.trino.sql.planner.iterative.rule.OptimizeMixedDistinctAggregations.distinctAggregationsUniqueArgumentCount;
import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static java.lang.Double.NaN;
import static java.lang.Double.isNaN;
import static java.util.Objects.requireNonNull;
Expand All @@ -39,30 +58,38 @@ public class DistinctAggregationStrategyChooser
{
private static final int MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER = 8;
private static final int PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER = MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * 8;
private static final double MAX_JOIN_GROUPING_KEYS_SIZE = 100 * 1024 * 1024; // 100 MB

private final TaskCountEstimator taskCountEstimator;
private final Metadata metadata;

private DistinctAggregationStrategyChooser(TaskCountEstimator taskCountEstimator)
public DistinctAggregationStrategyChooser(TaskCountEstimator taskCountEstimator, Metadata metadata)
{
this.taskCountEstimator = requireNonNull(taskCountEstimator, "taskCountEstimator is null");
this.metadata = requireNonNull(metadata, "metadata is null");
}

public static DistinctAggregationStrategyChooser createDistinctAggregationStrategyChooser(TaskCountEstimator taskCountEstimator)
public static DistinctAggregationStrategyChooser createDistinctAggregationStrategyChooser(TaskCountEstimator taskCountEstimator, Metadata metadata)
{
return new DistinctAggregationStrategyChooser(taskCountEstimator);
return new DistinctAggregationStrategyChooser(taskCountEstimator, metadata);
}

public boolean shouldAddMarkDistinct(AggregationNode aggregationNode, Session session, StatsProvider statsProvider)
public boolean shouldAddMarkDistinct(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup)
{
return chooseMarkDistinctStrategy(aggregationNode, session, statsProvider) == MARK_DISTINCT;
return chooseMarkDistinctStrategy(aggregationNode, session, statsProvider, lookup) == MARK_DISTINCT;
}

public boolean shouldUsePreAggregate(AggregationNode aggregationNode, Session session, StatsProvider statsProvider)
public boolean shouldUsePreAggregate(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup)
{
return chooseMarkDistinctStrategy(aggregationNode, session, statsProvider) == PRE_AGGREGATE;
return chooseMarkDistinctStrategy(aggregationNode, session, statsProvider, lookup) == PRE_AGGREGATE;
}

private DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode aggregationNode, Session session, StatsProvider statsProvider)
/**
 * Tells whether the strategy chosen for the given aggregation is {@code SPLIT_TO_SUBQUERIES}.
 *
 * @return {@code true} if {@code chooseMarkDistinctStrategy} returns {@code SPLIT_TO_SUBQUERIES} for these arguments
 */
public boolean shouldSplitToSubqueries(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup)
{
    DistinctAggregationsStrategy chosenStrategy = chooseMarkDistinctStrategy(aggregationNode, session, statsProvider, lookup);
    return chosenStrategy == SPLIT_TO_SUBQUERIES;
}

private DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup)
{
DistinctAggregationsStrategy distinctAggregationsStrategy = distinctAggregationsStrategy(session);
if (distinctAggregationsStrategy != AUTOMATIC) {
Expand All @@ -72,6 +99,9 @@ private DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode
if (distinctAggregationsStrategy == PRE_AGGREGATE && canUsePreAggregate(aggregationNode)) {
return PRE_AGGREGATE;
}
if (distinctAggregationsStrategy == SPLIT_TO_SUBQUERIES && isAggregationCandidateForSplittingToSubqueries(aggregationNode) && isAggregationSourceSupportedForSubqueries(aggregationNode.getSource(), session, lookup)) {
return SPLIT_TO_SUBQUERIES;
}
// in case the strategy is chosen by the session property, but we cannot use it, let's fall back to single-step
return SINGLE_STEP;
}
Expand All @@ -91,6 +121,12 @@ private DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode
return SINGLE_STEP;
}

if (isAggregationCandidateForSplittingToSubqueries(aggregationNode) && shouldSplitAggregationToSubqueries(aggregationNode, session, statsProvider, lookup)) {
// for simple distinct aggregations on top of table scan it makes sense to split the aggregation into multiple subqueries,
// so they can be handled by the SingleDistinctAggregationToGroupBy and use other single column optimizations
return SPLIT_TO_SUBQUERIES;
}

// mark-distinct is better than pre-aggregate if the number of group-by keys is bigger than 2
// because group-by keys are added to every grouping set and this makes partial aggregation behave badly
if (canUsePreAggregate(aggregationNode) && aggregationNode.getGroupingKeys().size() <= 2) {
Expand Down Expand Up @@ -121,4 +157,103 @@ private double getMinDistinctValueCountEstimate(AggregationNode aggregationNode,
.map(symbol -> sourceStats.getSymbolStatistics(symbol).getDistinctValuesCount())
.max(Double::compareTo).orElse(NaN);
}

// Splitting into sub-queries scans the table more than once, which can degrade performance.
// The decision below is therefore deliberately conservative: split only when stats indicate a big benefit.
private boolean shouldSplitAggregationToSubqueries(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup)
{
    PlanNode source = aggregationNode.getSource();

    // The conditions are evaluated left to right and short-circuit, so later checks only run
    // once the earlier ones have passed:
    //  1. the source may consist only of table scan, union, filter and project nodes
    //  2. a union anywhere in the source is rejected, as supporting it with the automatic decision is complex
    //  3. a selective filter (see isSelective) is rejected, because the scan and filter can then be the main
    //     bottleneck and splitting would duplicate that effort
    //  4. the extra data read by the additional scans must not be too expensive
    boolean rejected = !isAggregationSourceSupportedForSubqueries(source, session, lookup)
            || searchFrom(source, lookup).whereIsInstanceOfAny(UnionNode.class).findFirst().isPresent()
            || searchFrom(source, lookup)
                    .where(node -> node instanceof FilterNode filterNode && isSelective(filterNode, statsProvider))
                    .matches()
            || isAdditionalReadOverheadTooExpensive(aggregationNode, statsProvider, lookup);
    if (rejected) {
        return false;
    }

    // a single global aggregation needs no further size check
    if (aggregationNode.hasSingleGlobalAggregation()) {
        return true;
    }

    PlanNodeStatsEstimate aggregationStats = statsProvider.getStats(aggregationNode);
    double groupingKeysBytes = aggregationStats.getOutputSizeInBytes(aggregationNode.getGroupingKeys());

    // Split only when the estimated size of the grouping keys is known and within the limit; otherwise both
    // calculating the aggregation multiple times and joining the results would be inefficient.
    return !isNaN(groupingKeysBytes) && groupingKeysBytes <= MAX_JOIN_GROUPING_KEYS_SIZE;
}

// Estimates whether splitting the aggregation into sub-queries would read too much additional data.
// Returns true when the estimated data read by all sub-queries exceeds 1.5x the data read by a single
// table scan, or when either of those estimates is NaN.
// The source of the aggregation is expected to contain a single table scan (findOnlyElement is used to fetch it),
// so callers should rule out unions before calling this method.
private static boolean isAdditionalReadOverheadTooExpensive(AggregationNode aggregationNode, StatsProvider statsProvider, Lookup lookup)
{
    // symbols referenced directly as arguments of the distinct aggregations
    Set<Symbol> distinctInputs = aggregationNode.getAggregations()
            .values().stream()
            .filter(AggregationNode.Aggregation::isDistinct)
            .flatMap(aggregation -> aggregation.getArguments().stream())
            .filter(Reference.class::isInstance)
            .map(Symbol::from)
            .collect(toImmutableSet());

    TableScanNode tableScanNode = (TableScanNode) searchFrom(aggregationNode.getSource(), lookup).whereIsInstanceOfAny(TableScanNode.class).findOnlyElement();
    // every table scan output that is not a distinct aggregation input (e.g. group-by columns)
    Set<Symbol> additionalColumns = Sets.difference(ImmutableSet.copyOf(tableScanNode.getOutputSymbols()), distinctInputs);

    // look the table scan stats up once and reuse them for both size estimates
    PlanNodeStatsEstimate tableScanStats = statsProvider.getStats(tableScanNode);
    double singleTableScanDataSize = tableScanStats.getOutputSizeInBytes(tableScanNode.getOutputSymbols());
    double additionalColumnsDataSize = tableScanStats.getOutputSizeInBytes(additionalColumns);

    // The additional (non-distinct) columns are counted once per sub-query, i.e. N times where N is the number
    // of sub-queries. The distinct inputs are counted once in total.
    long subqueryCount = distinctAggregationsUniqueArgumentCount(aggregationNode);
    double distinctInputDataSize = singleTableScanDataSize - additionalColumnsDataSize;
    double subqueriesTotalDataSize = additionalColumnsDataSize * subqueryCount + distinctInputDataSize;

    return isNaN(subqueriesTotalDataSize) ||
            isNaN(singleTableScanDataSize) ||
            // we would read more than 50% more data
            subqueriesTotalDataSize / singleTableScanDataSize > 1.5;
}

// A filter counts as selective when its estimated output row count is less than half of the estimated
// output row count of its source. If the ratio is NaN the comparison is false, so the filter is
// reported as not selective.
private static boolean isSelective(FilterNode filterNode, StatsProvider statsProvider)
{
    double rowsAfterFilter = statsProvider.getStats(filterNode).getOutputRowCount();
    double rowsBeforeFilter = statsProvider.getStats(filterNode.getSource()).getOutputRowCount();
    double keptFraction = rowsAfterFilter / rowsBeforeFilter;
    return keptFraction < 0.5;
}

// The aggregation source qualifies only if it is built exclusively from table scan, union, filter and project nodes.
// Every node type accepted here must also be supported by PlanCopier.copyPlan.
// On top of that, the source must contain at least one table scan, and every table scan must be allowed by the
// connector metadata to be split into multiple sub-queries (i.e. reading single columns is efficient there).
private boolean isAggregationSourceSupportedForSubqueries(PlanNode source, Session session, Lookup lookup)
{
    boolean containsUnsupportedNode = searchFrom(source, lookup)
            .where(node -> !(node instanceof TableScanNode)
                    && !(node instanceof FilterNode)
                    && !(node instanceof ProjectNode)
                    && !(node instanceof UnionNode))
            .findFirst()
            .isPresent();
    if (containsUnsupportedNode) {
        return false;
    }

    List<PlanNode> scans = searchFrom(source, lookup)
            .whereIsInstanceOfAny(TableScanNode.class)
            .findAll();

    // at least one table scan is expected
    if (scans.isEmpty()) {
        return false;
    }

    // stop at the first table whose connector does not allow splitting the read
    for (PlanNode scan : scans) {
        if (!metadata.allowSplittingReadIntoMultipleSubQueries(session, ((TableScanNode) scan).getTable())) {
            return false;
        }
    }
    return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.trino.cost.TaskCountEstimator;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
Expand Down Expand Up @@ -105,9 +106,9 @@ private static boolean hasMixedDistinctAndNonDistincts(AggregationNode aggregati

private final DistinctAggregationStrategyChooser distinctAggregationStrategyChooser;

public MultipleDistinctAggregationToMarkDistinct(TaskCountEstimator taskCountEstimator)
public MultipleDistinctAggregationToMarkDistinct(TaskCountEstimator taskCountEstimator, Metadata metadata)
{
this.distinctAggregationStrategyChooser = createDistinctAggregationStrategyChooser(taskCountEstimator);
this.distinctAggregationStrategyChooser = createDistinctAggregationStrategyChooser(taskCountEstimator, metadata);
}

@Override
Expand All @@ -121,7 +122,7 @@ public Result apply(AggregationNode parent, Captures captures, Context context)
{
DistinctAggregationsStrategy distinctAggregationsStrategy = distinctAggregationsStrategy(context.getSession());
if (!(distinctAggregationsStrategy.equals(MARK_DISTINCT) ||
(distinctAggregationsStrategy.equals(AUTOMATIC) && distinctAggregationStrategyChooser.shouldAddMarkDistinct(parent, context.getSession(), context.getStatsProvider())))) {
(distinctAggregationsStrategy.equals(AUTOMATIC) && distinctAggregationStrategyChooser.shouldAddMarkDistinct(parent, context.getSession(), context.getStatsProvider(), context.getLookup())))) {
return Result.empty();
}

Expand Down
Loading

0 comments on commit eee8d4d

Please sign in to comment.