Skip to content

Commit

Permalink
Convert avg(REAL) aggregation to annotated function
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 27, 2022
1 parent 1ca7383 commit 39dfb23
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import io.trino.operator.aggregation.QuantileDigestAggregationFunction.BigintQuantileDigestAggregationFunction;
import io.trino.operator.aggregation.QuantileDigestAggregationFunction.DoubleQuantileDigestAggregationFunction;
import io.trino.operator.aggregation.QuantileDigestAggregationFunction.RealQuantileDigestAggregationFunction;
import io.trino.operator.aggregation.RealAverageAggregation;
import io.trino.operator.aggregation.RealCorrelationAggregation;
import io.trino.operator.aggregation.RealCovarianceAggregation;
import io.trino.operator.aggregation.RealGeometricMeanAggregations;
Expand Down Expand Up @@ -269,7 +270,6 @@

import static io.trino.operator.aggregation.DecimalAverageAggregation.DECIMAL_AVERAGE_AGGREGATION;
import static io.trino.operator.aggregation.DecimalSumAggregation.DECIMAL_SUM_AGGREGATION;
import static io.trino.operator.aggregation.RealAverageAggregation.REAL_AVERAGE_AGGREGATION;
import static io.trino.operator.aggregation.ReduceAggregationFunction.REDUCE_AGG;
import static io.trino.operator.aggregation.arrayagg.ArrayAggregationFunction.ARRAY_AGG;
import static io.trino.operator.scalar.ArrayConcatFunction.ARRAY_CONCAT_FUNCTION;
Expand Down Expand Up @@ -393,7 +393,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.aggregates(IntervalDayToSecondSumAggregation.class)
.aggregates(IntervalYearToMonthSumAggregation.class)
.aggregates(AverageAggregations.class)
.function(REAL_AVERAGE_AGGREGATION)
.aggregates(RealAverageAggregation.class)
.aggregates(IntervalDayToSecondAverageAggregation.class)
.aggregates(IntervalYearToMonthAverageAggregation.class)
.aggregates(GeometricMeanAggregations.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,106 +13,64 @@
*/
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionMetadata;
import io.trino.metadata.Signature;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.trino.operator.aggregation.state.DoubleState;
import io.trino.operator.aggregation.state.LongState;
import io.trino.operator.aggregation.state.StateCompiler;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.Description;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.RemoveInputFunction;
import io.trino.spi.function.SqlType;

import java.lang.invoke.MethodHandle;
import java.util.Optional;

import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.util.Reflection.methodHandle;
import static java.lang.Float.floatToIntBits;
import static java.lang.Float.intBitsToFloat;

public class RealAverageAggregation
extends SqlAggregationFunction
@AggregationFunction("avg")
@Description("Returns the average value of the argument")
public final class RealAverageAggregation
{
public static final RealAverageAggregation REAL_AVERAGE_AGGREGATION = new RealAverageAggregation();
private static final String NAME = "avg";

private static final MethodHandle INPUT_FUNCTION = methodHandle(RealAverageAggregation.class, "input", LongState.class, DoubleState.class, long.class);
private static final MethodHandle REMOVE_INPUT_FUNCTION = methodHandle(RealAverageAggregation.class, "removeInput", LongState.class, DoubleState.class, long.class);
private static final MethodHandle COMBINE_FUNCTION = methodHandle(RealAverageAggregation.class, "combine", LongState.class, DoubleState.class, LongState.class, DoubleState.class);
private static final MethodHandle OUTPUT_FUNCTION = methodHandle(RealAverageAggregation.class, "output", LongState.class, DoubleState.class, BlockBuilder.class);

protected RealAverageAggregation()
{
super(
FunctionMetadata.aggregateBuilder()
.signature(Signature.builder()
.name(NAME)
.returnType(REAL)
.argumentType(REAL)
.build())
.description("Returns the average value of the argument")
.build(),
AggregationFunctionMetadata.builder()
.intermediateType(BIGINT)
.intermediateType(DOUBLE)
.build());
}

@Override
public AggregationMetadata specialize(BoundSignature boundSignature)
{
Class<LongState> longStateInterface = LongState.class;
Class<DoubleState> doubleStateInterface = DoubleState.class;
AccumulatorStateSerializer<LongState> longStateSerializer = StateCompiler.generateStateSerializer(longStateInterface);
AccumulatorStateSerializer<DoubleState> doubleStateSerializer = StateCompiler.generateStateSerializer(doubleStateInterface);

MethodHandle inputFunction = normalizeInputMethod(INPUT_FUNCTION, boundSignature, STATE, STATE, INPUT_CHANNEL);
MethodHandle removeFunction = normalizeInputMethod(REMOVE_INPUT_FUNCTION, boundSignature, STATE, STATE, INPUT_CHANNEL);

return new AggregationMetadata(
inputFunction,
Optional.of(removeFunction),
Optional.of(COMBINE_FUNCTION),
OUTPUT_FUNCTION,
ImmutableList.of(
new AccumulatorStateDescriptor<>(
longStateInterface,
longStateSerializer,
StateCompiler.generateStateFactory(longStateInterface)),
new AccumulatorStateDescriptor<>(
doubleStateInterface,
doubleStateSerializer,
StateCompiler.generateStateFactory(doubleStateInterface))));
}
private RealAverageAggregation() {}

public static void input(LongState count, DoubleState sum, long value)
@InputFunction
public static void input(
@AggregationState LongState count,
@AggregationState DoubleState sum,
@SqlType("REAL") long value)
{
count.setValue(count.getValue() + 1);
sum.setValue(sum.getValue() + intBitsToFloat((int) value));
}

public static void removeInput(LongState count, DoubleState sum, long value)
@RemoveInputFunction
public static void removeInput(
@AggregationState LongState count,
@AggregationState DoubleState sum,
@SqlType("REAL") long value)
{
count.setValue(count.getValue() - 1);
sum.setValue(sum.getValue() - intBitsToFloat((int) value));
}

public static void combine(LongState count, DoubleState sum, LongState otherCount, DoubleState otherSum)
@CombineFunction
public static void combine(
@AggregationState LongState count,
@AggregationState DoubleState sum,
@AggregationState LongState otherCount,
@AggregationState DoubleState otherSum)
{
count.setValue(count.getValue() + otherCount.getValue());
sum.setValue(sum.getValue() + otherSum.getValue());
}

public static void output(LongState count, DoubleState sum, BlockBuilder out)
@OutputFunction("REAL")
public static void output(
@AggregationState LongState count,
@AggregationState DoubleState sum,
BlockBuilder out)
{
if (count.getValue() == 0) {
out.appendNull();
Expand Down

0 comments on commit 39dfb23

Please sign in to comment.