Skip to content

Commit

Permalink
Improve estimation of row count from partition samples
Browse files Browse the repository at this point in the history
Reduce the possiblity of estimation errors in averageRowsPerPartition
and rowCount due to a couple of outliers by excluding the
min and max rowCount values from the calculation of
avg rows per partition.
  • Loading branch information
raunaqmorarka committed Mar 8, 2022
1 parent 98b794a commit d3ea6a9
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import java.math.BigDecimal;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
Expand All @@ -61,6 +62,7 @@
import java.util.Set;
import java.util.stream.DoubleStream;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
Expand Down Expand Up @@ -392,14 +394,11 @@ private static TableStatistics getTableStatistics(

checkArgument(!partitions.isEmpty(), "partitions is empty");

OptionalDouble optionalAverageRowsPerPartition = calculateAverageRowsPerPartition(statistics.values());
if (optionalAverageRowsPerPartition.isEmpty()) {
Optional<PartitionsRowCount> optionalRowCount = calculatePartitionsRowCount(statistics.values(), partitions.size());
if (optionalRowCount.isEmpty()) {
return TableStatistics.empty();
}
double averageRowsPerPartition = optionalAverageRowsPerPartition.getAsDouble();
verify(averageRowsPerPartition >= 0, "averageRowsPerPartition must be greater than or equal to zero");
int queriedPartitionsCount = partitions.size();
double rowCount = averageRowsPerPartition * queriedPartitionsCount;
double rowCount = optionalRowCount.get().getRowCount();

TableStatistics.Builder result = TableStatistics.builder();
result.setRowCount(Estimate.of(rowCount));
Expand All @@ -409,6 +408,7 @@ private static TableStatistics getTableStatistics(
Type columnType = columnTypes.get(columnName);
ColumnStatistics columnStatistics;
if (columnHandle.isPartitionKey()) {
double averageRowsPerPartition = optionalRowCount.get().getAverageRowsPerPartition();
columnStatistics = createPartitionColumnStatistics(columnHandle, columnType, partitions, statistics, averageRowsPerPartition, rowCount);
}
else {
Expand All @@ -420,15 +420,98 @@ private static TableStatistics getTableStatistics(
}

@VisibleForTesting
static OptionalDouble calculateAverageRowsPerPartition(Collection<PartitionStatistics> statistics)
static Optional<PartitionsRowCount> calculatePartitionsRowCount(Collection<PartitionStatistics> statistics, int queriedPartitionsCount)
{
return statistics.stream()
long[] rowCounts = statistics.stream()
.map(PartitionStatistics::getBasicStatistics)
.map(HiveBasicStatistics::getRowCount)
.filter(OptionalLong::isPresent)
.mapToLong(OptionalLong::getAsLong)
.peek(count -> verify(count >= 0, "count must be greater than or equal to zero"))
.average();
.toArray();
int sampleSize = statistics.size();
// Sample contains all the queried partitions, estimate avg normally
if (rowCounts.length <= 2 || queriedPartitionsCount == sampleSize) {
OptionalDouble averageRowsPerPartitionOptional = Arrays.stream(rowCounts).average();
if (averageRowsPerPartitionOptional.isEmpty()) {
return Optional.empty();
}
double averageRowsPerPartition = averageRowsPerPartitionOptional.getAsDouble();
return Optional.of(new PartitionsRowCount(averageRowsPerPartition, averageRowsPerPartition * queriedPartitionsCount));
}

// Some partitions (e.g. __HIVE_DEFAULT_PARTITION__) may be outliers in terms of row count.
// Excluding the min and max rowCount values from averageRowsPerPartition calculation helps to reduce the
// possibility of errors in the extrapolated rowCount due to a couple of outliers.
int minIndex = 0;
int maxIndex = 0;
long rowCountSum = rowCounts[0];
for (int index = 1; index < rowCounts.length; index++) {
if (rowCounts[index] < rowCounts[minIndex]) {
minIndex = index;
}
else if (rowCounts[index] > rowCounts[maxIndex]) {
maxIndex = index;
}
rowCountSum += rowCounts[index];
}
double averageWithoutOutliers = ((double) (rowCountSum - rowCounts[minIndex] - rowCounts[maxIndex])) / (rowCounts.length - 2);
double rowCount = (averageWithoutOutliers * (queriedPartitionsCount - 2)) + rowCounts[minIndex] + rowCounts[maxIndex];
return Optional.of(new PartitionsRowCount(averageWithoutOutliers, rowCount));
}

@VisibleForTesting
static class PartitionsRowCount
{
private final double averageRowsPerPartition;
private final double rowCount;

PartitionsRowCount(double averageRowsPerPartition, double rowCount)
{
verify(averageRowsPerPartition >= 0, "averageRowsPerPartition must be greater than or equal to zero");
verify(rowCount >= 0, "rowCount must be greater than or equal to zero");
this.averageRowsPerPartition = averageRowsPerPartition;
this.rowCount = rowCount;
}

private double getAverageRowsPerPartition()
{
return averageRowsPerPartition;
}

private double getRowCount()
{
return rowCount;
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
PartitionsRowCount that = (PartitionsRowCount) o;
return Double.compare(that.averageRowsPerPartition, averageRowsPerPartition) == 0
&& Double.compare(that.rowCount, rowCount) == 0;
}

@Override
public int hashCode()
{
return Objects.hash(averageRowsPerPartition, rowCount);
}

@Override
public String toString()
{
return toStringHelper(this)
.add("averageRowsPerPartition", averageRowsPerPartition)
.add("rowCount", rowCount)
.toString();
}
}

private static ColumnStatistics createPartitionColumnStatistics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,14 @@
import static io.trino.plugin.hive.metastore.HiveColumnStatistics.createDecimalColumnStatistics;
import static io.trino.plugin.hive.metastore.HiveColumnStatistics.createDoubleColumnStatistics;
import static io.trino.plugin.hive.metastore.HiveColumnStatistics.createIntegerColumnStatistics;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateAverageRowsPerPartition;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.PartitionsRowCount;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDataSize;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDataSizeForPartitioningKey;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDistinctPartitionKeys;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDistinctValuesCount;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateNullsFraction;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateNullsFractionForPartitioningKey;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculatePartitionsRowCount;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateRange;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateRangeForPartitioningKey;
import static io.trino.plugin.hive.statistics.MetastoreHiveStatisticsProvider.convertPartitionValueToDouble;
Expand All @@ -82,6 +83,7 @@
import static io.trino.spi.type.VarcharType.VARCHAR;
import static java.lang.Double.NaN;
import static java.lang.String.format;
import static java.util.Collections.nCopies;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertEquals;
Expand Down Expand Up @@ -238,15 +240,34 @@ public void testValidatePartitionStatistics()
}

@Test
public void testCalculateAverageRowsPerPartition()
{
assertThat(calculateAverageRowsPerPartition(ImmutableList.of())).isEmpty();
assertThat(calculateAverageRowsPerPartition(ImmutableList.of(PartitionStatistics.empty()))).isEmpty();
assertThat(calculateAverageRowsPerPartition(ImmutableList.of(PartitionStatistics.empty(), PartitionStatistics.empty()))).isEmpty();
assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10))), OptionalDouble.of(10));
assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), PartitionStatistics.empty())), OptionalDouble.of(10));
assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), rowsCount(20))), OptionalDouble.of(15));
assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), rowsCount(20), PartitionStatistics.empty())), OptionalDouble.of(15));
public void testCalculatePartitionsRowCount()
{
assertThat(calculatePartitionsRowCount(ImmutableList.of(), 0)).isEmpty();
assertThat(calculatePartitionsRowCount(ImmutableList.of(PartitionStatistics.empty()), 1)).isEmpty();
assertThat(calculatePartitionsRowCount(ImmutableList.of(PartitionStatistics.empty(), PartitionStatistics.empty()), 2)).isEmpty();
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10)), 1))
.isEqualTo(Optional.of(new PartitionsRowCount(10, 10)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10)), 2))
.isEqualTo(Optional.of(new PartitionsRowCount(10, 20)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), PartitionStatistics.empty()), 2))
.isEqualTo(Optional.of(new PartitionsRowCount(10, 20)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20)), 2))
.isEqualTo(Optional.of(new PartitionsRowCount(15, 30)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20)), 3))
.isEqualTo(Optional.of(new PartitionsRowCount(15, 45)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20), PartitionStatistics.empty()), 3))
.isEqualTo(Optional.of(new PartitionsRowCount(15, 45)));

assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(100), rowsCount(1000)), 3))
.isEqualTo(Optional.of(new PartitionsRowCount((10 + 100 + 1000) / 3.0, 10 + 100 + 1000)));
// Exclude outliers from average row count
assertThat(calculatePartitionsRowCount(ImmutableList.<PartitionStatistics>builder()
.addAll(nCopies(10, rowsCount(100)))
.add(rowsCount(1))
.add(rowsCount(1000))
.build(),
50))
.isEqualTo(Optional.of(new PartitionsRowCount(100, (100 * 48) + 1 + 1000)));
}

@Test
Expand Down

0 comments on commit d3ea6a9

Please sign in to comment.