From 2520a5a5f516900ce61ff9ae178dc7848a0a01ec Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 25 Feb 2022 11:57:06 +0100 Subject: [PATCH] Add stats rule for DistinctLimitNode --- .../io/trino/cost/DistinctLimitStatsRule.java | 54 +++++++++++++++++++ .../io/trino/cost/StatsCalculatorModule.java | 1 + .../io/trino/plugin/hive/TestShowStats.java | 4 +- 3 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/cost/DistinctLimitStatsRule.java diff --git a/core/trino-main/src/main/java/io/trino/cost/DistinctLimitStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/DistinctLimitStatsRule.java new file mode 100644 index 000000000000..54875383819e --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cost/DistinctLimitStatsRule.java @@ -0,0 +1,54 @@ +/* + * Copyright Starburst Data, Inc. All rights reserved. + * + * THIS IS UNPUBLISHED PROPRIETARY SOURCE CODE OF STARBURST DATA. + * The copyright notice above does not evidence any + * actual or intended publication of such source code. + * + * Redistribution of this material is strictly prohibited. 
+ */
+package io.trino.cost;
+
+import com.google.common.collect.ImmutableMap;
+import io.trino.Session;
+import io.trino.matching.Pattern;
+import io.trino.sql.planner.TypeProvider;
+import io.trino.sql.planner.iterative.Lookup;
+import io.trino.sql.planner.plan.DistinctLimitNode;
+
+import java.util.Optional;
+
+import static io.trino.sql.planner.plan.Patterns.distinctLimit;
+import static java.lang.Math.min;
+
+/**
+ * Stats rule for {@link DistinctLimitNode}.
+ * <p>
+ * For a non-partial node, the source estimate is grouped by the node's distinct
+ * symbols through {@link AggregationStatsRule#groupBy} and the resulting output
+ * row count is capped at the node's limit. No estimate is produced for partial nodes.
+ */
+public class DistinctLimitStatsRule
+        extends SimpleStatsRule<DistinctLimitNode>
+{
+    private static final Pattern<DistinctLimitNode> PATTERN = distinctLimit();
+
+    /**
+     * @param normalizer passed unchanged to the {@link SimpleStatsRule} constructor
+     */
+    public DistinctLimitStatsRule(StatsNormalizer normalizer)
+    {
+        super(normalizer);
+    }
+
+    /**
+     * @return the shared pattern obtained from {@code Patterns.distinctLimit()}
+     */
+    @Override
+    public Pattern<DistinctLimitNode> getPattern()
+    {
+        return PATTERN;
+    }
+
+    /**
+     * Estimates the output of a distinct-limit node.
+     *
+     * @param node the node to estimate
+     * @param statsProvider supplies the estimate of the node's source
+     * @param lookup not used by this rule
+     * @param session not used by this rule
+     * @param types not used by this rule
+     * @return empty when {@code node.isPartial()}; otherwise the source estimate grouped by
+     * the distinct symbols, with its output row count capped at {@code node.getLimit()}
+     */
+    @Override
+    protected Optional<PlanNodeStatsEstimate> doCalculate(DistinctLimitNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types)
+    {
+        if (node.isPartial()) {
+            return Optional.empty();
+        }
+
+        // Grouping by the distinct symbols with no aggregations yields the stats of the de-duplicated rows.
+        PlanNodeStatsEstimate distinctStats = AggregationStatsRule.groupBy(
+                statsProvider.getStats(node.getSource()),
+                node.getDistinctSymbols(),
+                ImmutableMap.of());
+        // The node never emits more rows than its limit.
+        PlanNodeStatsEstimate distinctLimitStats = distinctStats.mapOutputRowCount(rowCount -> min(rowCount, node.getLimit()));
+        return Optional.of(distinctLimitStats);
+    }
+}
diff --git a/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java b/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java index 2a7c502854ca..91fe704c934a 100644 --- a/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java +++ b/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java @@ -70,6 +70,7 @@ public List> get() rules.add(new FilterStatsRule(normalizer, filterStatsCalculator)); rules.add(new ValuesStatsRule(plannerContext)); rules.add(new LimitStatsRule(normalizer)); + rules.add(new DistinctLimitStatsRule(normalizer)); rules.add(new TopNStatsRule(normalizer)); rules.add(new EnforceSingleRowStatsRule(normalizer)); rules.add(new 
ProjectStatsRule(scalarStatsCalculator, normalizer)); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestShowStats.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestShowStats.java index 707c7187c35e..539263829a71 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestShowStats.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestShowStats.java @@ -602,8 +602,8 @@ public void testShowStatsWithDistinctLimit() sessionWith(getSession(), USE_PARTIAL_DISTINCT_LIMIT, "false"), "SHOW STATS FOR (SELECT DISTINCT regionkey FROM nation LIMIT 3)", "VALUES " + - " ('regionkey', null, null, null, null, null, null), " + - " (null, null, null, null, null, null, null)"); + " ('regionkey', null, 3, 0, null, 0, 4), " + + " (null, null, null, null, 3, null, null)"); } @Test