Skip to content

Commit

Permalink
Show stats in EXPLAIN/EXPLAIN ANALYZE with partial aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
Dith3r authored and raunaqmorarka committed Mar 2, 2023
1 parent f7f390e commit a22f24c
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.cost;

import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.matching.Pattern;
import io.trino.sql.planner.Symbol;
Expand All @@ -25,8 +26,8 @@
import java.util.Map;
import java.util.Optional;

import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL;
import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE;
import static io.trino.sql.planner.plan.AggregationNode.Step.INTERMEDIATE;
import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL;
import static io.trino.sql.planner.plan.Patterns.aggregation;
import static java.lang.Math.min;
import static java.util.Objects.requireNonNull;
Expand All @@ -50,36 +51,35 @@ public Pattern<AggregationNode> getPattern()
@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(AggregationNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
if (node.getGroupingSetCount() != 1) {
if (node.getGroupingSetCount() != 1 || node.getStep() == INTERMEDIATE) {
return Optional.empty();
}

if (node.getStep() != SINGLE && node.getStep() != FINAL) {
return Optional.empty();
}
PlanNodeStatsEstimate estimate;

return Optional.of(groupBy(
statsProvider.getStats(node.getSource()),
node.getGroupingKeys(),
node.getAggregations()));
if (node.getStep() == PARTIAL) {
estimate = partialGroupBy(statsProvider.getStats(node.getSource()),
node.getGroupingKeys(),
node.getAggregations());
}
else {
estimate = groupBy(
statsProvider.getStats(node.getSource()),
node.getGroupingKeys(),
node.getAggregations());
}
return Optional.of(estimate);
}

public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, Collection<Symbol> groupBySymbols, Map<Symbol, Aggregation> aggregations)
{
// Used to estimate FINAL or SINGLE step aggregations
PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
if (groupBySymbols.isEmpty()) {
result.setOutputRowCount(1);
}
else {
for (Symbol groupBySymbol : groupBySymbols) {
SymbolStatsEstimate symbolStatistics = sourceStats.getSymbolStatistics(groupBySymbol);
result.addSymbolStatistics(groupBySymbol, symbolStatistics.mapNullsFraction(nullsFraction -> {
if (nullsFraction == 0.0) {
return 0.0;
}
return 1.0 / (symbolStatistics.getDistinctValuesCount() + 1);
}));
}
result.addSymbolStatistics(getGroupBySymbolsStatistics(sourceStats, groupBySymbols));
double rowsCount = getRowsCount(sourceStats, groupBySymbols);
result.setOutputRowCount(min(rowsCount, sourceStats.getOutputRowCount()));
}
Expand All @@ -101,6 +101,35 @@ public static double getRowsCount(PlanNodeStatsEstimate sourceStats, Collection<
return rowsCount;
}

/**
 * Builds the estimate reported for a PARTIAL aggregation step.
 *
 * <p>The output row count is copied unchanged from {@code sourceStats}, i.e. the partial
 * aggregation is pessimistically assumed to perform no reduction. Partial aggregations are
 * added after the CBO rules have run, so this estimate mainly exists to keep the numbers in
 * the EXPLAIN plan output easy to follow.
 *
 * @param sourceStats statistics of the aggregation's source
 * @param groupBySymbols grouping keys; their statistics come from {@code getGroupBySymbolsStatistics}
 * @param aggregations output symbol to aggregation; each one is estimated with {@code estimateAggregationStats}
 * @return estimate holding the source row count plus grouping-key and aggregation symbol statistics
 */
private static PlanNodeStatsEstimate partialGroupBy(PlanNodeStatsEstimate sourceStats, Collection<Symbol> groupBySymbols, Map<Symbol, Aggregation> aggregations)
{
    // Grouping-key statistics depend only on the source, so compute them up front.
    Map<Symbol, SymbolStatsEstimate> groupingKeyStats = getGroupBySymbolsStatistics(sourceStats, groupBySymbols);

    PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.builder();
    // No reduction assumed: forward the source row count as-is.
    estimate.setOutputRowCount(sourceStats.getOutputRowCount());
    estimate.addSymbolStatistics(groupingKeyStats);
    aggregations.forEach((outputSymbol, aggregation) ->
            estimate.addSymbolStatistics(outputSymbol, estimateAggregationStats(aggregation, sourceStats)));

    return estimate.build();
}

/**
 * Derives the statistics of every grouping key from the source statistics.
 *
 * <p>Each key keeps its source statistics except for the nulls fraction, which is remapped:
 * a fraction of exactly {@code 0.0} stays {@code 0.0}; any other value becomes
 * {@code 1 / (distinctValuesCount + 1)} (presumably modelling a single NULL group next to
 * the distinct non-null values - confirm against the stats model).
 *
 * @param sourceStats statistics of the aggregation's source
 * @param groupBySymbols grouping keys to look up in {@code sourceStats}
 * @return immutable map from grouping key to its adjusted statistics; note that
 *         {@code buildOrThrow()} rejects a key that was put more than once
 */
private static Map<Symbol, SymbolStatsEstimate> getGroupBySymbolsStatistics(PlanNodeStatsEstimate sourceStats, Collection<Symbol> groupBySymbols)
{
    ImmutableMap.Builder<Symbol, SymbolStatsEstimate> adjustedStats = ImmutableMap.builder();
    for (Symbol key : groupBySymbols) {
        SymbolStatsEstimate sourceEstimate = sourceStats.getSymbolStatistics(key);
        SymbolStatsEstimate keyEstimate = sourceEstimate.mapNullsFraction(
                fraction -> fraction == 0.0 ? 0.0 : 1.0 / (sourceEstimate.getDistinctValuesCount() + 1));
        adjustedStats.put(key, keyEstimate);
    }
    return adjustedStats.buildOrThrow();
}

private static SymbolStatsEstimate estimateAggregationStats(Aggregation aggregation, PlanNodeStatsEstimate sourceStats)
{
requireNonNull(aggregation, "aggregation is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,8 @@ public void testShowStatsWithIntersect()
assertQuery(
"SHOW STATS FOR ((SELECT nationkey FROM nation) INTERSECT (SELECT regionkey FROM region))",
"VALUES " +
" ('nationkey', null, null, null, null, null, null), " +
" (null, null, null, null, null, null, null)");
" ('nationkey', null, 22.5, 0.0, null, 0, 24), " +
" (null, null, null, null, 22.5, null, null)");
}

@Test
Expand All @@ -509,12 +509,11 @@ public void testShowStatsWithAggregation()
@Test
public void testShowStatsWithGroupBy()
{
// TODO calculate row count - https://github.com/trinodb/trino/issues/6323
assertQuery(
"SHOW STATS FOR (SELECT avg(totalprice) AS x FROM orders GROUP BY orderkey)",
"VALUES " +
" ('x', null, null, null, null, null, null), " +
" (null, null, null, null, null, null, null)");
" (null, null, null, null, 15000.0, null, null)");

assertQuery(
sessionWith(getSession(), PREFER_PARTIAL_AGGREGATION, "false"),
Expand All @@ -531,7 +530,7 @@ public void testShowStatsWithHaving()
"SHOW STATS FOR (SELECT count(nationkey) AS x FROM nation_partitioned GROUP BY regionkey HAVING regionkey > 0)",
"VALUES " +
" ('x', null, null, null, null, null, null), " +
" (null, null, null, null, null, null, null)");
" (null, null, null, null, 4.0, null, null)");

assertQuery(
sessionWith(getSession(), PREFER_PARTIAL_AGGREGATION, "false"),
Expand All @@ -544,20 +543,19 @@ public void testShowStatsWithHaving()
@Test
public void testShowStatsWithSelectDistinct()
{
// TODO calculate row count - https://github.com/trinodb/trino/issues/6323
assertQuery(
"SHOW STATS FOR (SELECT DISTINCT * FROM orders)",
"VALUES " +
" ('orderkey', null, null, null, null, null, null), " +
" ('custkey', null, null, null, null, null, null), " +
" ('orderstatus', null, null, null, null, null, null), " +
" ('totalprice', null, null, null, null, null, null), " +
" ('orderdate', null, null, null, null, null, null), " +
" ('orderpriority', null, null, null, null, null, null), " +
" ('clerk', null, null, null, null, null, null), " +
" ('shippriority', null, null, null, null, null, null), " +
" ('comment', null, null, null, null, null, null), " +
" (null, null, null, null, null, null, null)");
" ('orderkey', null, 15000.0, 0.0, null, '1', '60000'), " +
" ('custkey', null, 990.0, 0.0, null, '1', '1499'), " +
" ('orderstatus', 15000.0, 3.0, 0.0, null, null, null), " +
" ('totalprice', null, 15000.0, 0.0, null, '874.89', '466001.28'), " +
" ('orderdate', null, 2406.0, 0.0, null, '1992-01-01', '1998-08-02'), " +
" ('orderpriority', 126188.00000000001, 5.0, 0.0, null, null, null), " +
" ('clerk', 225000.0, 995.0, 0.0, null, null, null), " +
" ('shippriority', null, 1.0, 0.0, null, '0', '0'), " +
" ('comment', 727364.0, 15000.0, 0.0, null, null, null), " +
" (null, null, null, null, 15000.0, null, null)");

assertQuery(
sessionWith(getSession(), PREFER_PARTIAL_AGGREGATION, "false"),
Expand All @@ -572,12 +570,11 @@ public void testShowStatsWithSelectDistinct()
" ('comment', 727364, 15000, 0, null, null, null), " +
" (null, null, null, null, 15000, null, null)");

// TODO calculate row count - https://github.com/trinodb/trino/issues/6323
assertQuery(
"SHOW STATS FOR (SELECT DISTINCT regionkey FROM region)",
"VALUES " +
" ('regionkey', null, null, null, null, null, null), " +
" (null, null, null, null, null, null, null)");
" ('regionkey', null, 5.0, 0.0, null, 0, 4), " +
" (null, null, null, null, 5.0, null, null)");

assertQuery(
sessionWith(getSession(), PREFER_PARTIAL_AGGREGATION, "false"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.testng.annotations.Test;

import static io.trino.SystemSessionProperties.COLLECT_PLAN_STATISTICS_FOR_ALL_QUERIES;
import static io.trino.SystemSessionProperties.PREFER_PARTIAL_AGGREGATION;
import static io.trino.plugin.tpch.TpchConnectorFactory.TPCH_COLUMN_NAMING_PROPERTY;
import static io.trino.testing.assertions.Assert.assertEventually;
import static io.trino.testing.statistics.MetricComparisonStrategies.absoluteError;
Expand All @@ -43,8 +42,6 @@ public void setup()
{
DistributedQueryRunner runner = TpchQueryRunnerBuilder.builder()
.amendSession(builder -> builder
// We are not able to calculate stats for PARTIAL aggregations
.setSystemProperty(PREFER_PARTIAL_AGGREGATION, "false")
// Stats for non-EXPLAIN queries are not collected by default
.setSystemProperty(COLLECT_PLAN_STATISTICS_FOR_ALL_QUERIES, "true"))
.buildWithoutCatalogs();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.testng.annotations.Test;

import static io.trino.SystemSessionProperties.COLLECT_PLAN_STATISTICS_FOR_ALL_QUERIES;
import static io.trino.SystemSessionProperties.PREFER_PARTIAL_AGGREGATION;
import static io.trino.plugin.tpch.TpchConnectorFactory.TPCH_COLUMN_NAMING_PROPERTY;
import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME;
import static io.trino.testing.TestingSession.testSessionBuilder;
Expand All @@ -48,8 +47,6 @@ public void setUp()
Session defaultSession = testSessionBuilder()
.setCatalog("tpch")
.setSchema(TINY_SCHEMA_NAME)
// We are not able to calculate stats for PARTIAL aggregations
.setSystemProperty(PREFER_PARTIAL_AGGREGATION, "false")
// Stats for non-EXPLAIN queries are not collected by default
.setSystemProperty(COLLECT_PLAN_STATISTICS_FOR_ALL_QUERIES, "true")
.build();
Expand Down

0 comments on commit a22f24c

Please sign in to comment.