Skip to content

Commit

Permalink
Convert qdigest_agg and merge aggregations to annotated functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 27, 2022
1 parent fbf075e commit ff24f6a
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 201 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@
import io.trino.operator.aggregation.MergeQuantileDigestFunction;
import io.trino.operator.aggregation.MergeTDigestAggregation;
import io.trino.operator.aggregation.MinNAggregationFunction;
import io.trino.operator.aggregation.QuantileDigestAggregationFunction.BigintQuantileDigestAggregationFunction;
import io.trino.operator.aggregation.QuantileDigestAggregationFunction.DoubleQuantileDigestAggregationFunction;
import io.trino.operator.aggregation.QuantileDigestAggregationFunction.RealQuantileDigestAggregationFunction;
import io.trino.operator.aggregation.RealCorrelationAggregation;
import io.trino.operator.aggregation.RealCovarianceAggregation;
import io.trino.operator.aggregation.RealGeometricMeanAggregations;
Expand Down Expand Up @@ -264,9 +267,6 @@
import static io.trino.operator.aggregation.DecimalSumAggregation.DECIMAL_SUM_AGGREGATION;
import static io.trino.operator.aggregation.MaxAggregationFunction.MAX_AGGREGATION;
import static io.trino.operator.aggregation.MinAggregationFunction.MIN_AGGREGATION;
import static io.trino.operator.aggregation.QuantileDigestAggregationFunction.QDIGEST_AGG;
import static io.trino.operator.aggregation.QuantileDigestAggregationFunction.QDIGEST_AGG_WITH_WEIGHT;
import static io.trino.operator.aggregation.QuantileDigestAggregationFunction.QDIGEST_AGG_WITH_WEIGHT_AND_ERROR;
import static io.trino.operator.aggregation.RealAverageAggregation.REAL_AVERAGE_AGGREGATION;
import static io.trino.operator.aggregation.ReduceAggregationFunction.REDUCE_AGG;
import static io.trino.operator.aggregation.arrayagg.ArrayAggregationFunction.ARRAY_AGG;
Expand Down Expand Up @@ -402,8 +402,10 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.aggregates(ApproximateSetAggregation.class)
.aggregates(ApproximateSetGenericAggregation.class)
.aggregates(TDigestAggregationFunction.class)
.functions(QDIGEST_AGG, QDIGEST_AGG_WITH_WEIGHT, QDIGEST_AGG_WITH_WEIGHT_AND_ERROR)
.function(MergeQuantileDigestFunction.MERGE)
.aggregates(DoubleQuantileDigestAggregationFunction.class)
.aggregates(RealQuantileDigestAggregationFunction.class)
.aggregates(BigintQuantileDigestAggregationFunction.class)
.aggregates(MergeQuantileDigestFunction.class)
.aggregates(MergeTDigestAggregation.class)
.aggregates(DoubleHistogramAggregation.class)
.aggregates(RealHistogramAggregation.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,90 +13,46 @@
*/
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.airlift.stats.QuantileDigest;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionMetadata;
import io.trino.metadata.Signature;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.state.QuantileDigestState;
import io.trino.operator.aggregation.state.QuantileDigestStateFactory;
import io.trino.operator.aggregation.state.QuantileDigestStateSerializer;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.BlockIndex;
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.Description;
import io.trino.spi.function.InputFunction;
import io.trino.spi.type.QuantileDigestType;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;

import java.lang.invoke.MethodHandle;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import static io.trino.spi.type.StandardTypes.QDIGEST;
import static io.trino.spi.type.TypeSignature.parametricType;
import static io.trino.util.MoreMath.nearlyEqual;
import static io.trino.util.Reflection.methodHandle;

@AggregationFunction("merge")
@AggregationFunction(value = "merge", isOrderSensitive = true)
@Description("Merges the input quantile digests into a single quantile digest")
public final class MergeQuantileDigestFunction
extends SqlAggregationFunction
{
public static final MergeQuantileDigestFunction MERGE = new MergeQuantileDigestFunction();
public static final String NAME = "merge";
private static final MethodHandle INPUT_FUNCTION = methodHandle(MergeQuantileDigestFunction.class, "input", Type.class, QuantileDigestState.class, Block.class, int.class);
private static final MethodHandle COMBINE_FUNCTION = methodHandle(MergeQuantileDigestFunction.class, "combine", QuantileDigestState.class, QuantileDigestState.class);
private static final MethodHandle OUTPUT_FUNCTION = methodHandle(MergeQuantileDigestFunction.class, "output", QuantileDigestStateSerializer.class, QuantileDigestState.class, BlockBuilder.class);
private static final double COMPARISON_EPSILON = 1E-6;

public MergeQuantileDigestFunction()
{
super(
FunctionMetadata.aggregateBuilder()
.signature(Signature.builder()
.name(NAME)
.comparableTypeParameter("T")
.returnType(parametricType(QDIGEST, new TypeSignature("T")))
.argumentType(parametricType(QDIGEST, new TypeSignature("T")))
.build())
.description("Merges the input quantile digests into a single quantile digest")
.build(),
AggregationFunctionMetadata.builder()
.orderSensitive()
.intermediateType(parametricType(QDIGEST, new TypeSignature("T")))
.build());
}

@Override
public AggregationMetadata specialize(BoundSignature boundSignature)
{
QuantileDigestType outputType = (QuantileDigestType) boundSignature.getReturnType();
Type valueType = outputType.getValueType();
QuantileDigestStateSerializer stateSerializer = new QuantileDigestStateSerializer(valueType);
private MergeQuantileDigestFunction() {}

return new AggregationMetadata(
INPUT_FUNCTION.bindTo(outputType),
Optional.empty(),
Optional.of(COMBINE_FUNCTION),
OUTPUT_FUNCTION.bindTo(stateSerializer),
ImmutableList.of(new AccumulatorStateDescriptor<>(
QuantileDigestState.class,
stateSerializer,
new QuantileDigestStateFactory())));
}
private static final double COMPARISON_EPSILON = 1.0E-6;

@InputFunction
public static void input(Type type, QuantileDigestState state, Block value, int index)
@TypeParameter("V")
public static void input(
@TypeParameter("V") Type type,
@AggregationState QuantileDigestState state,
@BlockPosition @SqlType("V") Block value,
@BlockIndex int index)
{
merge(state, new QuantileDigest(type.getSlice(value, index)));
}

@CombineFunction
public static void combine(QuantileDigestState state, QuantileDigestState otherState)
public static void combine(@AggregationState QuantileDigestState state, @AggregationState QuantileDigestState otherState)
{
merge(state, otherState.getQuantileDigest());
}
Expand All @@ -122,8 +78,17 @@ private static void merge(QuantileDigestState state, QuantileDigest input)
}
}

public static void output(QuantileDigestStateSerializer serializer, QuantileDigestState state, BlockBuilder out)
@OutputFunction("qdigest(V)")
public static void output(
@TypeParameter("V") Type type,
@AggregationState QuantileDigestState state,
BlockBuilder out)
{
serializer.serialize(state, out);
if (state.getQuantileDigest() == null) {
out.appendNull();
}
else {
type.writeSlice(out, state.getQuantileDigest().serialize());
}
}
}
Loading

0 comments on commit ff24f6a

Please sign in to comment.