Calculate Iceberg NDV with a Theta sketch #14290

Merged (11 commits, Sep 29, 2022)
@@ -31,7 +31,6 @@
 import io.trino.sql.planner.plan.StatisticAggregations;
 import io.trino.sql.planner.plan.StatisticAggregationsDescriptor;
 import io.trino.sql.tree.QualifiedName;
-import io.trino.sql.tree.SymbolReference;
 
 import java.util.List;
 import java.util.Map;
@@ -122,34 +121,34 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMeta
     private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticType statisticType, Symbol input, Type inputType)
     {
         return switch (statisticType) {
-            case MIN_VALUE -> createAggregation(QualifiedName.of("min"), input.toSymbolReference(), inputType);
-            case MAX_VALUE -> createAggregation(QualifiedName.of("max"), input.toSymbolReference(), inputType);
-            case NUMBER_OF_DISTINCT_VALUES -> createAggregation(QualifiedName.of("approx_distinct"), input.toSymbolReference(), inputType);
+            case MIN_VALUE -> createAggregation(QualifiedName.of("min"), input, inputType);
+            case MAX_VALUE -> createAggregation(QualifiedName.of("max"), input, inputType);
+            case NUMBER_OF_DISTINCT_VALUES -> createAggregation(QualifiedName.of("approx_distinct"), input, inputType);
             case NUMBER_OF_DISTINCT_VALUES_SUMMARY ->
                     // we use $approx_set here and not approx_set because latter is not defined for all types supported by Trino
-                    createAggregation(QualifiedName.of("$approx_set"), input.toSymbolReference(), inputType);
-            case NUMBER_OF_NON_NULL_VALUES -> createAggregation(QualifiedName.of("count"), input.toSymbolReference(), inputType);
-            case NUMBER_OF_TRUE_VALUES -> createAggregation(QualifiedName.of("count_if"), input.toSymbolReference(), BOOLEAN);
-            case TOTAL_SIZE_IN_BYTES -> createAggregation(QualifiedName.of(SumDataSizeForStats.NAME), input.toSymbolReference(), inputType);
-            case MAX_VALUE_SIZE_IN_BYTES -> createAggregation(QualifiedName.of(MaxDataSizeForStats.NAME), input.toSymbolReference(), inputType);
+                    createAggregation(QualifiedName.of("$approx_set"), input, inputType);
+            case NUMBER_OF_NON_NULL_VALUES -> createAggregation(QualifiedName.of("count"), input, inputType);
+            case NUMBER_OF_TRUE_VALUES -> createAggregation(QualifiedName.of("count_if"), input, BOOLEAN);
+            case TOTAL_SIZE_IN_BYTES -> createAggregation(QualifiedName.of(SumDataSizeForStats.NAME), input, inputType);
+            case MAX_VALUE_SIZE_IN_BYTES -> createAggregation(QualifiedName.of(MaxDataSizeForStats.NAME), input, inputType);
         };
     }
 
     private ColumnStatisticsAggregation createColumnAggregation(FunctionName aggregation, Symbol input, Type inputType)
     {
         checkArgument(aggregation.getCatalogSchema().isEmpty(), "Catalog/schema name not supported");
-        return createAggregation(QualifiedName.of(aggregation.getName()), input.toSymbolReference(), inputType);
+        return createAggregation(QualifiedName.of(aggregation.getName()), input, inputType);
     }
 
-    private ColumnStatisticsAggregation createAggregation(QualifiedName functionName, SymbolReference input, Type inputType)
+    private ColumnStatisticsAggregation createAggregation(QualifiedName functionName, Symbol input, Type inputType)
     {
         ResolvedFunction resolvedFunction = metadata.resolveFunction(session, functionName, fromTypes(inputType));
         Type resolvedType = getOnlyElement(resolvedFunction.getSignature().getArgumentTypes());
         verify(resolvedType.equals(inputType), "resolved function input type does not match the input type: %s != %s", resolvedType, inputType);
         return new ColumnStatisticsAggregation(
                 new AggregationNode.Aggregation(
                         resolvedFunction,
-                        ImmutableList.of(input),
+                        ImmutableList.of(input.toSymbolReference()),
                         false,
                         Optional.empty(),
                         Optional.empty(),

@@ -91,6 +91,7 @@ public class MaterializedResult
 
     private final List<MaterializedRow> rows;
     private final List<Type> types;
+    private final List<String> columnNames;
     private final Map<String, String> setSessionProperties;
     private final Set<String> resetSessionProperties;
     private final Optional<String> updateType;
@@ -100,12 +101,13 @@ public class MaterializedResult
 
     public MaterializedResult(List<MaterializedRow> rows, List<? extends Type> types)
     {
-        this(rows, types, ImmutableMap.of(), ImmutableSet.of(), Optional.empty(), OptionalLong.empty(), ImmutableList.of(), Optional.empty());
+        this(rows, types, ImmutableList.of(), ImmutableMap.of(), ImmutableSet.of(), Optional.empty(), OptionalLong.empty(), ImmutableList.of(), Optional.empty());
     }
 
     public MaterializedResult(
             List<MaterializedRow> rows,
             List<? extends Type> types,
+            List<String> columnNames,
             Map<String, String> setSessionProperties,
             Set<String> resetSessionProperties,
             Optional<String> updateType,
@@ -115,6 +117,7 @@ public MaterializedResult(
     {
         this.rows = ImmutableList.copyOf(requireNonNull(rows, "rows is null"));
         this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
+        this.columnNames = ImmutableList.copyOf(requireNonNull(columnNames, "columnNames is null"));
         this.setSessionProperties = ImmutableMap.copyOf(requireNonNull(setSessionProperties, "setSessionProperties is null"));
         this.resetSessionProperties = ImmutableSet.copyOf(requireNonNull(resetSessionProperties, "resetSessionProperties is null"));
         this.updateType = requireNonNull(updateType, "updateType is null");
@@ -144,6 +147,12 @@ public List<Type> getTypes()
         return types;
     }
 
+    public List<String> getColumnNames()
+    {
+        checkState(!columnNames.isEmpty(), "Column names are unknown");
+        return columnNames;
+    }
+
     public Map<String, String> getSetSessionProperties()
     {
         return setSessionProperties;
@@ -362,6 +371,7 @@ public MaterializedResult toTestTypes()
                         .map(MaterializedResult::convertToTestTypes)
                         .collect(toImmutableList()),
                 types,
+                columnNames,
                 setSessionProperties,
                 resetSessionProperties,
                 updateType,

@@ -313,6 +313,8 @@ private QueryAssert(
         this.skipResultsCorrectnessCheckForPushdown = skipResultsCorrectnessCheckForPushdown;
     }
 
+    // TODO for better readability, replace this with `exceptColumns(String... columnNamesToExclude)` leveraging MaterializedResult.getColumnNames
+    @Deprecated
     public QueryAssert projected(int... columns)
    {
        return new QueryAssert(
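
The TODO above hints at the intended follow-up. A rough sketch of the name-to-position mapping such an exceptColumns helper would need (the class and method below are hypothetical, not part of this PR), assuming column names are populated on the MaterializedResult:

    import com.google.common.collect.ImmutableSet;

    import java.util.List;
    import java.util.Set;
    import java.util.stream.IntStream;

    final class ProjectionPositions
    {
        private ProjectionPositions() {}

        // Hypothetical helper: turn "all columns except these names" into the
        // positional indexes that QueryAssert.projected(int...) understands.
        static int[] positionsExcluding(List<String> columnNames, String... columnNamesToExclude)
        {
            Set<String> excluded = ImmutableSet.copyOf(columnNamesToExclude);
            return IntStream.range(0, columnNames.size())
                    .filter(i -> !excluded.contains(columnNames.get(i)))
                    .toArray();
        }
    }

With that, excluding "comment" from a result with columns (orderkey, comment) would reduce to projected(0).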

plugin/trino-iceberg/pom.xml (12 additions, 0 deletions)

@@ -165,6 +165,18 @@
             <artifactId>failsafe</artifactId>
         </dependency>
 
+        <dependency>
+            <groupId>org.apache.datasketches</groupId>
+            <artifactId>datasketches-java</artifactId>
+            <version>3.3.0</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.datasketches</groupId>
+            <artifactId>datasketches-memory</artifactId>
+            <version>2.1.0</version>
+        </dependency>
+
         <dependency>
             <groupId>org.apache.iceberg</groupId>
             <artifactId>iceberg-api</artifactId>
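
The two new dependencies bring in Apache DataSketches, whose Theta sketch backs the new NDV aggregation. A minimal standalone sketch of the library's flow, independent of any Trino code (value ranges and names are illustrative only):

    import org.apache.datasketches.theta.CompactSketch;
    import org.apache.datasketches.theta.SetOperation;
    import org.apache.datasketches.theta.Union;
    import org.apache.datasketches.theta.UpdateSketch;

    public class ThetaSketchExample
    {
        public static void main(String[] args)
        {
            // One sketch per "split"; duplicates within and across sketches
            // do not inflate the distinct count.
            UpdateSketch left = UpdateSketch.builder().build();
            UpdateSketch right = UpdateSketch.builder().build();
            for (long i = 0; i < 60_000; i++) {
                left.update(i);
            }
            for (long i = 40_000; i < 100_000; i++) {
                right.update(i);
            }

            // Merging is what a partial-aggregation combine step boils down to.
            Union union = SetOperation.builder().buildUnion();
            union.union(left);
            union.union(right);

            CompactSketch merged = union.getResult();
            System.out.printf("estimated NDV: %.0f%n", merged.getEstimate()); // roughly 100,000
        }
    }

datasketches-memory supplies the Memory abstraction the sketches serialize through; the example after finishStatisticsCollection below shows it in use.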

@@ -14,31 +14,16 @@
 package io.trino.plugin.iceberg;
 
 import com.google.common.base.VerifyException;
-import io.airlift.slice.Slice;
 import io.trino.spi.predicate.Domain;
 import io.trino.spi.predicate.Range;
 import io.trino.spi.predicate.TupleDomain;
 import io.trino.spi.type.ArrayType;
-import io.trino.spi.type.BigintType;
-import io.trino.spi.type.BooleanType;
-import io.trino.spi.type.DateType;
-import io.trino.spi.type.DecimalType;
-import io.trino.spi.type.DoubleType;
-import io.trino.spi.type.Int128;
-import io.trino.spi.type.IntegerType;
-import io.trino.spi.type.LongTimestampWithTimeZone;
 import io.trino.spi.type.MapType;
-import io.trino.spi.type.RealType;
 import io.trino.spi.type.RowType;
 import io.trino.spi.type.Type;
-import io.trino.spi.type.UuidType;
-import io.trino.spi.type.VarbinaryType;
-import io.trino.spi.type.VarcharType;
 import org.apache.iceberg.expressions.Expression;
 import org.apache.iceberg.expressions.Expressions;
 
-import java.math.BigDecimal;
-import java.nio.ByteBuffer;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.List;
@@ -48,16 +33,8 @@
 
 import static com.google.common.base.Preconditions.checkArgument;
 import static io.trino.plugin.iceberg.IcebergMetadataColumn.isMetadataColumnId;
-import static io.trino.plugin.iceberg.util.Timestamps.timestampTzToMicros;
-import static io.trino.spi.type.TimeType.TIME_MICROS;
-import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS;
-import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS;
-import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_MICROSECOND;
-import static io.trino.spi.type.UuidType.trinoUuidToJavaUuid;
-import static java.lang.Float.intBitsToFloat;
-import static java.lang.Math.toIntExact;
+import static io.trino.plugin.iceberg.IcebergTypes.convertTrinoValueToIceberg;
 import static java.lang.String.format;
-import static java.util.Objects.requireNonNull;
 import static org.apache.iceberg.expressions.Expressions.alwaysFalse;
 import static org.apache.iceberg.expressions.Expressions.alwaysTrue;
 import static org.apache.iceberg.expressions.Expressions.equal;
@@ -117,7 +94,7 @@ private static Expression toIcebergExpression(String columnName, Type type, Doma
         List<Expression> rangeExpressions = new ArrayList<>();
         for (Range range : orderedRanges) {
             if (range.isSingleValue()) {
-                icebergValues.add(getIcebergLiteralValue(type, range.getLowBoundedValue()));
+                icebergValues.add(convertTrinoValueToIceberg(type, range.getLowBoundedValue()));
             }
             else {
                 rangeExpressions.add(toIcebergExpression(columnName, range));
@@ -137,13 +114,13 @@ private static Expression toIcebergExpression(String columnName, Range range)
         Type type = range.getType();
 
         if (range.isSingleValue()) {
-            Object icebergValue = getIcebergLiteralValue(type, range.getSingleValue());
+            Object icebergValue = convertTrinoValueToIceberg(type, range.getSingleValue());
             return equal(columnName, icebergValue);
         }
 
         List<Expression> conjuncts = new ArrayList<>(2);
         if (!range.isLowUnbounded()) {
-            Object icebergLow = getIcebergLiteralValue(type, range.getLowBoundedValue());
+            Object icebergLow = convertTrinoValueToIceberg(type, range.getLowBoundedValue());
             Expression lowBound;
             if (range.isLowInclusive()) {
                 lowBound = greaterThanOrEqual(columnName, icebergLow);
@@ -155,7 +132,7 @@ private static Expression toIcebergExpression(String columnName, Range range)
         }
 
         if (!range.isHighUnbounded()) {
-            Object icebergHigh = getIcebergLiteralValue(type, range.getHighBoundedValue());
+            Object icebergHigh = convertTrinoValueToIceberg(type, range.getHighBoundedValue());
             Expression highBound;
             if (range.isHighInclusive()) {
                 highBound = lessThanOrEqual(columnName, icebergHigh);
@@ -169,68 +146,6 @@ private static Expression toIcebergExpression(String columnName, Range range)
         return and(conjuncts);
     }
 
-    private static Object getIcebergLiteralValue(Type type, Object trinoNativeValue)
-    {
-        requireNonNull(trinoNativeValue, "trinoNativeValue is null");
-
-        if (type instanceof BooleanType) {
-            return (boolean) trinoNativeValue;
-        }
-
-        if (type instanceof IntegerType) {
-            return toIntExact((long) trinoNativeValue);
-        }
-
-        if (type instanceof BigintType) {
-            return (long) trinoNativeValue;
-        }
-
-        if (type instanceof RealType) {
-            return intBitsToFloat(toIntExact((long) trinoNativeValue));
-        }
-
-        if (type instanceof DoubleType) {
-            return (double) trinoNativeValue;
-        }
-
-        if (type instanceof DateType) {
-            return toIntExact(((Long) trinoNativeValue));
-        }
-
-        if (type.equals(TIME_MICROS)) {
-            return ((long) trinoNativeValue) / PICOSECONDS_PER_MICROSECOND;
-        }
-
-        if (type.equals(TIMESTAMP_MICROS)) {
-            return (long) trinoNativeValue;
-        }
-
-        if (type.equals(TIMESTAMP_TZ_MICROS)) {
-            return timestampTzToMicros((LongTimestampWithTimeZone) trinoNativeValue);
-        }
-
-        if (type instanceof VarcharType) {
-            return ((Slice) trinoNativeValue).toStringUtf8();
-        }
-
-        if (type instanceof VarbinaryType) {
-            return ByteBuffer.wrap(((Slice) trinoNativeValue).getBytes());
-        }
-
-        if (type instanceof UuidType) {
-            return trinoUuidToJavaUuid(((Slice) trinoNativeValue));
-        }
-
-        if (type instanceof DecimalType decimalType) {
-            if (decimalType.isShort()) {
-                return BigDecimal.valueOf((long) trinoNativeValue).movePointLeft(decimalType.getScale());
-            }
-            return new BigDecimal(((Int128) trinoNativeValue).toBigInteger(), decimalType.getScale());
-        }
-
-        throw new UnsupportedOperationException("Unsupported type: " + type);
-    }
-
     private static Expression and(List<Expression> expressions)
     {
         if (expressions.isEmpty()) {
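
For reference, the expression factories used above (equal, greaterThanOrEqual, and so on) come straight from iceberg-api. A standalone sketch of what a bounded range such as 5 < x <= 10 becomes on the Iceberg side (the column name and bounds are illustrative):

    import org.apache.iceberg.expressions.Expression;
    import org.apache.iceberg.expressions.Expressions;

    public class IcebergExpressionExample
    {
        public static void main(String[] args)
        {
            // Equivalent of a Trino Range with an exclusive low bound and an
            // inclusive high bound, after value conversion.
            Expression lowBound = Expressions.greaterThan("orderkey", 5L);
            Expression highBound = Expressions.lessThanOrEqual("orderkey", 10L);
            Expression range = Expressions.and(lowBound, highBound);
            System.out.println(range); // e.g. (ref(name="orderkey") > 5 and ref(name="orderkey") <= 10)
        }
    }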

@@ -34,6 +34,8 @@
 import io.trino.plugin.hive.HiveApplyProjectionUtil;
 import io.trino.plugin.hive.HiveApplyProjectionUtil.ProjectedColumnRepresentation;
 import io.trino.plugin.hive.HiveWrittenPartitions;
+import io.trino.plugin.iceberg.aggregation.DataSketchStateSerializer;
+import io.trino.plugin.iceberg.aggregation.IcebergThetaSketchForStats;
 import io.trino.plugin.iceberg.catalog.TrinoCatalog;
 import io.trino.plugin.iceberg.procedure.IcebergDropExtendedStatsHandle;
 import io.trino.plugin.iceberg.procedure.IcebergExpireSnapshotsHandle;
@@ -94,6 +96,7 @@
 import io.trino.spi.type.TimestampWithTimeZoneType;
 import io.trino.spi.type.TypeManager;
 import io.trino.spi.type.TypeOperators;
+import org.apache.datasketches.theta.CompactSketch;
 import org.apache.iceberg.AppendFiles;
 import org.apache.iceberg.BaseTable;
 import org.apache.iceberg.DataFile;
@@ -229,7 +232,6 @@
 import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
 import static io.trino.spi.connector.RetryMode.NO_RETRIES;
 import static io.trino.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW;
-import static io.trino.spi.predicate.Utils.blockToNativeValue;
 import static io.trino.spi.type.BigintType.BIGINT;
 import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc;
 import static io.trino.spi.type.UuidType.UUID;
@@ -264,7 +266,7 @@ public class IcebergMetadata
     public static final String ORC_BLOOM_FILTER_FPP_KEY = "orc.bloom.filter.fpp";
 
     private static final String NUMBER_OF_DISTINCT_VALUES_NAME = "NUMBER_OF_DISTINCT_VALUES";
-    private static final FunctionName NUMBER_OF_DISTINCT_VALUES_FUNCTION = new FunctionName("approx_distinct");
+    private static final FunctionName NUMBER_OF_DISTINCT_VALUES_FUNCTION = new FunctionName(IcebergThetaSketchForStats.NAME);
 
     private final TypeManager typeManager;
     private final TypeOperators typeOperators;
@@ -1468,8 +1470,9 @@ public ConnectorAnalyzeMetadata getStatisticsCollectionMetadata(ConnectorSession
         }
 
         ConnectorTableMetadata tableMetadata = getTableMetadata(session, handle);
-        Set<String> allDataColumnNames = tableMetadata.getColumns().stream()
+        Set<String> allScalarColumnNames = tableMetadata.getColumns().stream()
                 .filter(column -> !column.isHidden())
+                .filter(column -> column.getType().getTypeParameters().isEmpty()) // is scalar type
                 .map(ColumnMetadata::getName)
                 .collect(toImmutableSet());
 
@@ -1479,18 +1482,17 @@ public ConnectorAnalyzeMetadata getStatisticsCollectionMetadata(ConnectorSession
                     if (columnNames.isEmpty()) {
                         throw new TrinoException(INVALID_ANALYZE_PROPERTY, "Cannot specify empty list of columns for analysis");
                     }
-                    if (!allDataColumnNames.containsAll(columnNames)) {
+                    if (!allScalarColumnNames.containsAll(columnNames)) {
                         throw new TrinoException(
                                 INVALID_ANALYZE_PROPERTY,
-                                format("Invalid columns specified for analysis: %s", Sets.difference(columnNames, allDataColumnNames)));
+                                format("Invalid columns specified for analysis: %s", Sets.difference(columnNames, allScalarColumnNames)));
                     }
                     return columnNames;
                 })
-                .orElse(allDataColumnNames);
+                .orElse(allScalarColumnNames);
 
         Set<ColumnStatisticMetadata> columnStatistics = tableMetadata.getColumns().stream()
                 .filter(column -> analyzeColumnNames.contains(column.getName()))
-                // TODO: add support for NDV summary/sketch, but using Theta sketch, not HLL; see https://github.com/apache/iceberg-docs/pull/69
                 .map(column -> new ColumnStatisticMetadata(column.getName(), NUMBER_OF_DISTINCT_VALUES_NAME, NUMBER_OF_DISTINCT_VALUES_FUNCTION))
                 .collect(toImmutableSet());
 
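The new filter leans on the fact that parameterized types report their components through Type.getTypeParameters(), so a scalar column is exactly one whose type has an empty parameter list. A small illustration against the Trino SPI (the class name is just for the example):

    import io.trino.spi.type.ArrayType;

    import static io.trino.spi.type.BigintType.BIGINT;

    public class ScalarTypeCheckExample
    {
        public static void main(String[] args)
        {
            // bigint has no type parameters, so it stays eligible for ANALYZE.
            System.out.println(BIGINT.getTypeParameters().isEmpty()); // true

            // array(bigint) reports its element type, so it is filtered out.
            System.out.println(new ArrayType(BIGINT).getTypeParameters().isEmpty()); // false
        }
    }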
@@ -1537,12 +1539,13 @@ public void finishStatisticsCollection(ConnectorSession session, ConnectorTableH
         for (Map.Entry<ColumnStatisticMetadata, Block> entry : computedStatistic.getColumnStatistics().entrySet()) {
             ColumnStatisticMetadata statisticMetadata = entry.getKey();
             if (statisticMetadata.getConnectorAggregationId().equals(NUMBER_OF_DISTINCT_VALUES_NAME)) {
-                long ndv = (long) blockToNativeValue(BIGINT, entry.getValue());
                 Integer columnId = verifyNotNull(
                         columnNameToId.get(statisticMetadata.getColumnName()),
                         "Column not found in table: [%s]",
                         statisticMetadata.getColumnName());
-                updateProperties.set(TRINO_STATS_NDV_FORMAT.formatted(columnId), Long.toString(ndv));
+                CompactSketch sketch = DataSketchStateSerializer.deserialize(entry.getValue(), 0);
+                // TODO: store whole sketch to support updates, see also https://github.com/apache/iceberg-docs/pull/69
+                updateProperties.set(TRINO_STATS_NDV_FORMAT.formatted(columnId), Long.toString((long) sketch.getEstimate()));
             }
             else {
                 throw new UnsupportedOperationException("Unsupported statistic: " + statisticMetadata);
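
The aggregation now ships a serialized Theta sketch instead of a ready-made bigint, so the finisher deserializes before reading an estimate. DataSketchStateSerializer's internals are not shown in this diff; the standalone equivalent below assumes a plain compact-sketch serialization:

    import org.apache.datasketches.memory.Memory;
    import org.apache.datasketches.theta.Sketch;
    import org.apache.datasketches.theta.UpdateSketch;

    public class NdvFromSerializedSketch
    {
        public static void main(String[] args)
        {
            // Stand-in for the bytes carried in the aggregation's output Block.
            UpdateSketch source = UpdateSketch.builder().build();
            for (long i = 0; i < 10_000; i++) {
                source.update(i);
            }
            byte[] serialized = source.compact().toByteArray();

            // What the finisher effectively does: reconstruct the sketch, then
            // round the estimate down to a long for the table property value.
            Sketch sketch = Sketch.wrap(Memory.wrap(serialized));
            long ndv = (long) sketch.getEstimate();
            System.out.println(ndv); // approximately 10,000
        }
    }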

@@ -14,9 +14,13 @@
 package io.trino.plugin.iceberg;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import io.trino.plugin.iceberg.aggregation.IcebergThetaSketchForStats;
 import io.trino.spi.Plugin;
 import io.trino.spi.connector.ConnectorFactory;
 
+import java.util.Set;
+
 public class IcebergPlugin
         implements Plugin
 {
@@ -25,4 +29,10 @@ public Iterable<ConnectorFactory> getConnectorFactories()
     {
         return ImmutableList.of(new IcebergConnectorFactory());
     }
+
+    @Override
+    public Set<Class<?>> getFunctions()
+    {
+        return ImmutableSet.of(IcebergThetaSketchForStats.class);
+    }
 }
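
Plugin.getFunctions is how a plugin contributes SQL functions outside any particular connector instance; Trino inspects the returned classes for aggregation annotations. The Theta sketch aggregation itself is only named in this diff, so the skeleton below is a generic, illustrative annotation-driven aggregation (a trivial counter), not the actual IcebergThetaSketchForStats:

    import io.trino.spi.block.BlockBuilder;
    import io.trino.spi.function.AccumulatorState;
    import io.trino.spi.function.AggregationFunction;
    import io.trino.spi.function.AggregationState;
    import io.trino.spi.function.CombineFunction;
    import io.trino.spi.function.InputFunction;
    import io.trino.spi.function.OutputFunction;
    import io.trino.spi.function.SqlType;
    import io.trino.spi.type.StandardTypes;

    import static io.trino.spi.type.BigintType.BIGINT;

    @AggregationFunction("example_value_count")
    public final class ExampleValueCountAggregation
    {
        private ExampleValueCountAggregation() {}

        // Trino generates the backing implementation from this state interface.
        public interface CountState
                extends AccumulatorState
        {
            long getCount();

            void setCount(long count);
        }

        @InputFunction
        public static void input(@AggregationState CountState state, @SqlType(StandardTypes.BIGINT) long value)
        {
            state.setCount(state.getCount() + 1);
        }

        @CombineFunction
        public static void combine(@AggregationState CountState state, @AggregationState CountState other)
        {
            state.setCount(state.getCount() + other.getCount());
        }

        @OutputFunction(StandardTypes.BIGINT)
        public static void output(@AggregationState CountState state, BlockBuilder out)
        {
            BIGINT.writeLong(out, state.getCount());
        }
    }

The real implementation would follow the same input/combine/output shape, with an UpdateSketch-backed state and DataSketchStateSerializer handling the state's (de)serialization.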