Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve estimation of row count from partition samples #11333

Merged
merged 1 commit into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import java.math.BigDecimal;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
Expand All @@ -61,6 +62,7 @@
import java.util.Set;
import java.util.stream.DoubleStream;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
Expand Down Expand Up @@ -392,14 +394,11 @@ private static TableStatistics getTableStatistics(

checkArgument(!partitions.isEmpty(), "partitions is empty");

OptionalDouble optionalAverageRowsPerPartition = calculateAverageRowsPerPartition(statistics.values());
if (optionalAverageRowsPerPartition.isEmpty()) {
Optional<PartitionsRowCount> optionalRowCount = calculatePartitionsRowCount(statistics.values(), partitions.size());
if (optionalRowCount.isEmpty()) {
return TableStatistics.empty();
}
double averageRowsPerPartition = optionalAverageRowsPerPartition.getAsDouble();
verify(averageRowsPerPartition >= 0, "averageRowsPerPartition must be greater than or equal to zero");
int queriedPartitionsCount = partitions.size();
double rowCount = averageRowsPerPartition * queriedPartitionsCount;
double rowCount = optionalRowCount.get().getRowCount();

TableStatistics.Builder result = TableStatistics.builder();
result.setRowCount(Estimate.of(rowCount));
Expand All @@ -409,6 +408,7 @@ private static TableStatistics getTableStatistics(
Type columnType = columnTypes.get(columnName);
ColumnStatistics columnStatistics;
if (columnHandle.isPartitionKey()) {
double averageRowsPerPartition = optionalRowCount.get().getAverageRowsPerPartition();
columnStatistics = createPartitionColumnStatistics(columnHandle, columnType, partitions, statistics, averageRowsPerPartition, rowCount);
}
else {
Expand All @@ -420,15 +420,98 @@ private static TableStatistics getTableStatistics(
}

@VisibleForTesting
static OptionalDouble calculateAverageRowsPerPartition(Collection<PartitionStatistics> statistics)
static Optional<PartitionsRowCount> calculatePartitionsRowCount(Collection<PartitionStatistics> statistics, int queriedPartitionsCount)
{
return statistics.stream()
long[] rowCounts = statistics.stream()
.map(PartitionStatistics::getBasicStatistics)
.map(HiveBasicStatistics::getRowCount)
.filter(OptionalLong::isPresent)
.mapToLong(OptionalLong::getAsLong)
.peek(count -> verify(count >= 0, "count must be greater than or equal to zero"))
.average();
.toArray();
int sampleSize = statistics.size();
// Sample contains all the queried partitions, estimate avg normally
if (rowCounts.length <= 2 || queriedPartitionsCount == sampleSize) {
OptionalDouble averageRowsPerPartitionOptional = Arrays.stream(rowCounts).average();
if (averageRowsPerPartitionOptional.isEmpty()) {
return Optional.empty();
}
double averageRowsPerPartition = averageRowsPerPartitionOptional.getAsDouble();
return Optional.of(new PartitionsRowCount(averageRowsPerPartition, averageRowsPerPartition * queriedPartitionsCount));
}

// Some partitions (e.g. __HIVE_DEFAULT_PARTITION__) may be outliers in terms of row count.
sopel39 marked this conversation as resolved.
Show resolved Hide resolved
// Excluding the min and max rowCount values from averageRowsPerPartition calculation helps to reduce the
// possibility of errors in the extrapolated rowCount due to a couple of outliers.
int minIndex = 0;
int maxIndex = 0;
long rowCountSum = rowCounts[0];
for (int index = 1; index < rowCounts.length; index++) {
if (rowCounts[index] < rowCounts[minIndex]) {
minIndex = index;
}
else if (rowCounts[index] > rowCounts[maxIndex]) {
maxIndex = index;
}
rowCountSum += rowCounts[index];
}
double averageWithoutOutliers = ((double) (rowCountSum - rowCounts[minIndex] - rowCounts[maxIndex])) / (rowCounts.length - 2);
double rowCount = (averageWithoutOutliers * (queriedPartitionsCount - 2)) + rowCounts[minIndex] + rowCounts[maxIndex];
raunaqmorarka marked this conversation as resolved.
Show resolved Hide resolved
return Optional.of(new PartitionsRowCount(averageWithoutOutliers, rowCount));
}

@VisibleForTesting
static class PartitionsRowCount
{
private final double averageRowsPerPartition;
private final double rowCount;

PartitionsRowCount(double averageRowsPerPartition, double rowCount)
{
verify(averageRowsPerPartition >= 0, "averageRowsPerPartition must be greater than or equal to zero");
raunaqmorarka marked this conversation as resolved.
Show resolved Hide resolved
verify(rowCount >= 0, "rowCount must be greater than or equal to zero");
this.averageRowsPerPartition = averageRowsPerPartition;
this.rowCount = rowCount;
}

private double getAverageRowsPerPartition()
{
return averageRowsPerPartition;
}

private double getRowCount()
{
return rowCount;
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
PartitionsRowCount that = (PartitionsRowCount) o;
return Double.compare(that.averageRowsPerPartition, averageRowsPerPartition) == 0
&& Double.compare(that.rowCount, rowCount) == 0;
}

@Override
public int hashCode()
{
return Objects.hash(averageRowsPerPartition, rowCount);
}

@Override
public String toString()
{
return toStringHelper(this)
.add("averageRowsPerPartition", averageRowsPerPartition)
.add("rowCount", rowCount)
.toString();
}
}

private static ColumnStatistics createPartitionColumnStatistics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,14 @@
import static io.trino.plugin.hive.metastore.HiveColumnStatistics.createDecimalColumnStatistics;
import static io.trino.plugin.hive.metastore.HiveColumnStatistics.createDoubleColumnStatistics;
import static io.trino.plugin.hive.metastore.HiveColumnStatistics.createIntegerColumnStatistics;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateAverageRowsPerPartition;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.PartitionsRowCount;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDataSize;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDataSizeForPartitioningKey;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDistinctPartitionKeys;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDistinctValuesCount;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateNullsFraction;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateNullsFractionForPartitioningKey;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculatePartitionsRowCount;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateRange;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateRangeForPartitioningKey;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.convertPartitionValueToDouble;
Expand All @@ -82,6 +83,7 @@
import static io.trino.spi.type.VarcharType.VARCHAR;
import static java.lang.Double.NaN;
import static java.lang.String.format;
import static java.util.Collections.nCopies;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertEquals;
Expand Down Expand Up @@ -238,15 +240,34 @@ public void testValidatePartitionStatistics()
}

@Test
public void testCalculateAverageRowsPerPartition()
{
assertThat(calculateAverageRowsPerPartition(ImmutableList.of())).isEmpty();
assertThat(calculateAverageRowsPerPartition(ImmutableList.of(PartitionStatistics.empty()))).isEmpty();
assertThat(calculateAverageRowsPerPartition(ImmutableList.of(PartitionStatistics.empty(), PartitionStatistics.empty()))).isEmpty();
assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10))), OptionalDouble.of(10));
assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), PartitionStatistics.empty())), OptionalDouble.of(10));
assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), rowsCount(20))), OptionalDouble.of(15));
assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), rowsCount(20), PartitionStatistics.empty())), OptionalDouble.of(15));
public void testCalculatePartitionsRowCount()
{
assertThat(calculatePartitionsRowCount(ImmutableList.of(), 0)).isEmpty();
assertThat(calculatePartitionsRowCount(ImmutableList.of(PartitionStatistics.empty()), 1)).isEmpty();
assertThat(calculatePartitionsRowCount(ImmutableList.of(PartitionStatistics.empty(), PartitionStatistics.empty()), 2)).isEmpty();
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10)), 1))
.isEqualTo(Optional.of(new PartitionsRowCount(10, 10)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10)), 2))
.isEqualTo(Optional.of(new PartitionsRowCount(10, 20)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), PartitionStatistics.empty()), 2))
.isEqualTo(Optional.of(new PartitionsRowCount(10, 20)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20)), 2))
.isEqualTo(Optional.of(new PartitionsRowCount(15, 30)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20)), 3))
.isEqualTo(Optional.of(new PartitionsRowCount(15, 45)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20), PartitionStatistics.empty()), 3))
.isEqualTo(Optional.of(new PartitionsRowCount(15, 45)));

assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(100), rowsCount(1000)), 3))
.isEqualTo(Optional.of(new PartitionsRowCount((10 + 100 + 1000) / 3.0, 10 + 100 + 1000)));
// Exclude outliers from average row count
raunaqmorarka marked this conversation as resolved.
Show resolved Hide resolved
assertThat(calculatePartitionsRowCount(ImmutableList.<PartitionStatistics>builder()
.addAll(nCopies(10, rowsCount(100)))
.add(rowsCount(1))
.add(rowsCount(1000))
.build(),
50))
.isEqualTo(Optional.of(new PartitionsRowCount(100, (100 * 48) + 1 + 1000)));
}

@Test
Expand Down