diff --git a/core/trino-main/src/main/java/io/trino/cost/AggregationStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/AggregationStatsRule.java index 4f6eea230c8a..7c97d313797c 100644 --- a/core/trino-main/src/main/java/io/trino/cost/AggregationStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/AggregationStatsRule.java @@ -13,6 +13,7 @@ */ package io.trino.cost; +import com.google.common.collect.ImmutableMap; import io.trino.Session; import io.trino.matching.Pattern; import io.trino.sql.planner.Symbol; @@ -25,8 +26,8 @@ import java.util.Map; import java.util.Optional; -import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL; -import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; +import static io.trino.sql.planner.plan.AggregationNode.Step.INTERMEDIATE; +import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL; import static io.trino.sql.planner.plan.Patterns.aggregation; import static java.lang.Math.min; import static java.util.Objects.requireNonNull; @@ -50,36 +51,35 @@ public Pattern getPattern() @Override protected Optional doCalculate(AggregationNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { - if (node.getGroupingSetCount() != 1) { + if (node.getGroupingSetCount() != 1 || node.getStep() == INTERMEDIATE) { return Optional.empty(); } - if (node.getStep() != SINGLE && node.getStep() != FINAL) { - return Optional.empty(); - } + PlanNodeStatsEstimate estimate; - return Optional.of(groupBy( - statsProvider.getStats(node.getSource()), - node.getGroupingKeys(), - node.getAggregations())); + if (node.getStep() == PARTIAL) { + estimate = partialGroupBy(statsProvider.getStats(node.getSource()), + node.getGroupingKeys(), + node.getAggregations()); + } + else { + estimate = groupBy( + statsProvider.getStats(node.getSource()), + node.getGroupingKeys(), + node.getAggregations()); + } + return Optional.of(estimate); } public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, Collection groupBySymbols, Map aggregations) { + // Used to estimate FINAL or SINGLE step aggregations PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder(); if (groupBySymbols.isEmpty()) { result.setOutputRowCount(1); } else { - for (Symbol groupBySymbol : groupBySymbols) { - SymbolStatsEstimate symbolStatistics = sourceStats.getSymbolStatistics(groupBySymbol); - result.addSymbolStatistics(groupBySymbol, symbolStatistics.mapNullsFraction(nullsFraction -> { - if (nullsFraction == 0.0) { - return 0.0; - } - return 1.0 / (symbolStatistics.getDistinctValuesCount() + 1); - })); - } + result.addSymbolStatistics(getGroupBySymbolsStatistics(sourceStats, groupBySymbols)); double rowsCount = getRowsCount(sourceStats, groupBySymbols); result.setOutputRowCount(min(rowsCount, sourceStats.getOutputRowCount())); } @@ -101,6 +101,35 @@ public static double getRowsCount(PlanNodeStatsEstimate sourceStats, Collection< return rowsCount; } + private static PlanNodeStatsEstimate partialGroupBy(PlanNodeStatsEstimate sourceStats, Collection groupBySymbols, Map aggregations) + { + // Pessimistic assumption of no reduction from PARTIAL aggregation, forwarding of the source statistics. This makes the CBO estimates in the EXPLAIN plan output easier to understand, + // even though partial aggregations are added after the CBO rules have been run. + PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder(); + result.setOutputRowCount(sourceStats.getOutputRowCount()); + result.addSymbolStatistics(getGroupBySymbolsStatistics(sourceStats, groupBySymbols)); + for (Map.Entry aggregationEntry : aggregations.entrySet()) { + result.addSymbolStatistics(aggregationEntry.getKey(), estimateAggregationStats(aggregationEntry.getValue(), sourceStats)); + } + + return result.build(); + } + + private static Map getGroupBySymbolsStatistics(PlanNodeStatsEstimate sourceStats, Collection groupBySymbols) + { + ImmutableMap.Builder symbolSymbolStatsEstimates = ImmutableMap.builder(); + for (Symbol groupBySymbol : groupBySymbols) { + SymbolStatsEstimate symbolStatistics = sourceStats.getSymbolStatistics(groupBySymbol); + symbolSymbolStatsEstimates.put(groupBySymbol, symbolStatistics.mapNullsFraction(nullsFraction -> { + if (nullsFraction == 0.0) { + return 0.0; + } + return 1.0 / (symbolStatistics.getDistinctValuesCount() + 1); + })); + } + return symbolSymbolStatsEstimates.buildOrThrow(); + } + private static SymbolStatsEstimate estimateAggregationStats(Aggregation aggregation, PlanNodeStatsEstimate sourceStats) { requireNonNull(aggregation, "aggregation is null"); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestShowStats.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestShowStats.java index 2b72d1732aca..471d45db9ff8 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestShowStats.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestShowStats.java @@ -485,8 +485,8 @@ public void testShowStatsWithIntersect() assertQuery( "SHOW STATS FOR ((SELECT nationkey FROM nation) INTERSECT (SELECT regionkey FROM region))", "VALUES " + - " ('nationkey', null, null, null, null, null, null), " + - " (null, null, null, null, null, null, null)"); + " ('nationkey', null, 22.5, 0.0, null, 0, 24), " + + " (null, null, null, null, 22.5, null, null)"); } @Test @@ -509,12 +509,11 @@ public void testShowStatsWithAggregation() @Test public void testShowStatsWithGroupBy() { - // TODO calculate row count - https://github.com/trinodb/trino/issues/6323 assertQuery( "SHOW STATS FOR (SELECT avg(totalprice) AS x FROM orders GROUP BY orderkey)", "VALUES " + " ('x', null, null, null, null, null, null), " + - " (null, null, null, null, null, null, null)"); + " (null, null, null, null, 15000.0, null, null)"); assertQuery( sessionWith(getSession(), PREFER_PARTIAL_AGGREGATION, "false"), @@ -531,7 +530,7 @@ public void testShowStatsWithHaving() "SHOW STATS FOR (SELECT count(nationkey) AS x FROM nation_partitioned GROUP BY regionkey HAVING regionkey > 0)", "VALUES " + " ('x', null, null, null, null, null, null), " + - " (null, null, null, null, null, null, null)"); + " (null, null, null, null, 4.0, null, null)"); assertQuery( sessionWith(getSession(), PREFER_PARTIAL_AGGREGATION, "false"), @@ -544,20 +543,19 @@ public void testShowStatsWithHaving() @Test public void testShowStatsWithSelectDistinct() { - // TODO calculate row count - https://github.com/trinodb/trino/issues/6323 assertQuery( "SHOW STATS FOR (SELECT DISTINCT * FROM orders)", "VALUES " + - " ('orderkey', null, null, null, null, null, null), " + - " ('custkey', null, null, null, null, null, null), " + - " ('orderstatus', null, null, null, null, null, null), " + - " ('totalprice', null, null, null, null, null, null), " + - " ('orderdate', null, null, null, null, null, null), " + - " ('orderpriority', null, null, null, null, null, null), " + - " ('clerk', null, null, null, null, null, null), " + - " ('shippriority', null, null, null, null, null, null), " + - " ('comment', null, null, null, null, null, null), " + - " (null, null, null, null, null, null, null)"); + " ('orderkey', null, 15000.0, 0.0, null, '1', '60000'), " + + " ('custkey', null, 990.0, 0.0, null, '1', '1499'), " + + " ('orderstatus', 15000.0, 3.0, 0.0, null, null, null), " + + " ('totalprice', null, 15000.0, 0.0, null, '874.89', '466001.28'), " + + " ('orderdate', null, 2406.0, 0.0, null, '1992-01-01', '1998-08-02'), " + + " ('orderpriority', 126188.00000000001, 5.0, 0.0, null, null, null), " + + " ('clerk', 225000.0, 995.0, 0.0, null, null, null), " + + " ('shippriority', null, 1.0, 0.0, null, '0', '0'), " + + " ('comment', 727364.0, 15000.0, 0.0, null, null, null), " + + " (null, null, null, null, 15000.0, null, null)"); assertQuery( sessionWith(getSession(), PREFER_PARTIAL_AGGREGATION, "false"), @@ -572,12 +570,11 @@ public void testShowStatsWithSelectDistinct() " ('comment', 727364, 15000, 0, null, null, null), " + " (null, null, null, null, 15000, null, null)"); - // TODO calculate row count - https://github.com/trinodb/trino/issues/6323 assertQuery( "SHOW STATS FOR (SELECT DISTINCT regionkey FROM region)", "VALUES " + - " ('regionkey', null, null, null, null, null, null), " + - " (null, null, null, null, null, null, null)"); + " ('regionkey', null, 5.0, 0.0, null, 0, 4), " + + " (null, null, null, null, 5.0, null, null)"); assertQuery( sessionWith(getSession(), PREFER_PARTIAL_AGGREGATION, "false"), diff --git a/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchDistributedStats.java b/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchDistributedStats.java index b4f207cb2ec9..08b2544f085e 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchDistributedStats.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchDistributedStats.java @@ -23,7 +23,6 @@ import org.testng.annotations.Test; import static io.trino.SystemSessionProperties.COLLECT_PLAN_STATISTICS_FOR_ALL_QUERIES; -import static io.trino.SystemSessionProperties.PREFER_PARTIAL_AGGREGATION; import static io.trino.plugin.tpch.TpchConnectorFactory.TPCH_COLUMN_NAMING_PROPERTY; import static io.trino.testing.assertions.Assert.assertEventually; import static io.trino.testing.statistics.MetricComparisonStrategies.absoluteError; @@ -43,8 +42,6 @@ public void setup() { DistributedQueryRunner runner = TpchQueryRunnerBuilder.builder() .amendSession(builder -> builder - // We are not able to calculate stats for PARTIAL aggregations - .setSystemProperty(PREFER_PARTIAL_AGGREGATION, "false") // Stats for non-EXPLAIN queries are not collected by default .setSystemProperty(COLLECT_PLAN_STATISTICS_FOR_ALL_QUERIES, "true")) .buildWithoutCatalogs(); diff --git a/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchLocalStats.java b/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchLocalStats.java index 2fd31e373b70..dc3e2f621ae3 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchLocalStats.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/tpch/TestTpchLocalStats.java @@ -24,7 +24,6 @@ import org.testng.annotations.Test; import static io.trino.SystemSessionProperties.COLLECT_PLAN_STATISTICS_FOR_ALL_QUERIES; -import static io.trino.SystemSessionProperties.PREFER_PARTIAL_AGGREGATION; import static io.trino.plugin.tpch.TpchConnectorFactory.TPCH_COLUMN_NAMING_PROPERTY; import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; import static io.trino.testing.TestingSession.testSessionBuilder; @@ -48,8 +47,6 @@ public void setUp() Session defaultSession = testSessionBuilder() .setCatalog("tpch") .setSchema(TINY_SCHEMA_NAME) - // We are not able to calculate stats for PARTIAL aggregations - .setSystemProperty(PREFER_PARTIAL_AGGREGATION, "false") // Stats for non-EXPLAIN queries are not collected by default .setSystemProperty(COLLECT_PLAN_STATISTICS_FOR_ALL_QUERIES, "true") .build();