Skip to content

Commit

Permalink
Convert decimal avg aggregation to annotated function
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 27, 2022
1 parent 230cf45 commit 5a76c23
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import io.trino.operator.aggregation.CountAggregation;
import io.trino.operator.aggregation.CountColumn;
import io.trino.operator.aggregation.CountIfAggregation;
import io.trino.operator.aggregation.DecimalAverageAggregation;
import io.trino.operator.aggregation.DecimalSumAggregation;
import io.trino.operator.aggregation.DefaultApproximateCountDistinctAggregation;
import io.trino.operator.aggregation.DoubleCorrelationAggregation;
Expand Down Expand Up @@ -270,7 +271,6 @@
import io.trino.type.setdigest.SetDigestFunctions;
import io.trino.type.setdigest.SetDigestOperators;

import static io.trino.operator.aggregation.DecimalAverageAggregation.DECIMAL_AVERAGE_AGGREGATION;
import static io.trino.operator.aggregation.ReduceAggregationFunction.REDUCE_AGG;
import static io.trino.operator.scalar.ArrayConcatFunction.ARRAY_CONCAT_FUNCTION;
import static io.trino.operator.scalar.ArrayConstructor.ARRAY_CONSTRUCTOR;
Expand Down Expand Up @@ -560,7 +560,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.function(DECIMAL_TO_DECIMAL_CAST)
.function(castVarcharToRe2JRegexp(featuresConfig.getRe2JDfaStatesLimit(), featuresConfig.getRe2JDfaRetries()))
.function(castCharToRe2JRegexp(featuresConfig.getRe2JDfaStatesLimit(), featuresConfig.getRe2JDfaRetries()))
.function(DECIMAL_AVERAGE_AGGREGATION)
.aggregates(DecimalAverageAggregation.class)
.aggregates(DecimalSumAggregation.class)
.function(DECIMAL_MOD_FUNCTION)
.functions(ARRAY_TRANSFORM_FUNCTION, ARRAY_REDUCE_FUNCTION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,106 +14,50 @@
package io.trino.operator.aggregation;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionMetadata;
import io.trino.metadata.Signature;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateSerializer;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.BlockIndex;
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.Description;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.LiteralParameters;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;

import java.lang.invoke.MethodHandle;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static io.trino.spi.type.Decimals.overflows;
import static io.trino.spi.type.Decimals.writeShortDecimal;
import static io.trino.spi.type.Int128Math.addWithOverflow;
import static io.trino.spi.type.Int128Math.divideRoundUp;
import static io.trino.spi.type.TypeSignatureParameter.typeVariable;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.util.Reflection.methodHandle;
import static java.math.BigDecimal.ROUND_HALF_UP;

public class DecimalAverageAggregation
extends SqlAggregationFunction
@AggregationFunction("avg")
@Description("Calculates the average value")
public final class DecimalAverageAggregation
{
public static final DecimalAverageAggregation DECIMAL_AVERAGE_AGGREGATION = new DecimalAverageAggregation();

private static final String NAME = "avg";
private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "inputShortDecimal", LongDecimalWithOverflowAndLongState.class, Block.class, int.class);
private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "inputLongDecimal", LongDecimalWithOverflowAndLongState.class, Block.class, int.class);

private static final MethodHandle SHORT_DECIMAL_OUTPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "outputShortDecimal", DecimalType.class, LongDecimalWithOverflowAndLongState.class, BlockBuilder.class);
private static final MethodHandle LONG_DECIMAL_OUTPUT_FUNCTION = methodHandle(DecimalAverageAggregation.class, "outputLongDecimal", DecimalType.class, LongDecimalWithOverflowAndLongState.class, BlockBuilder.class);

private static final MethodHandle COMBINE_FUNCTION = methodHandle(DecimalAverageAggregation.class, "combine", LongDecimalWithOverflowAndLongState.class, LongDecimalWithOverflowAndLongState.class);

private static final BigInteger TWO = new BigInteger("2");
private static final BigInteger OVERFLOW_MULTIPLIER = TWO.pow(128);

public DecimalAverageAggregation()
{
super(
FunctionMetadata.aggregateBuilder()
.signature(Signature.builder()
.name(NAME)
.returnType(new TypeSignature("decimal", typeVariable("p"), typeVariable("s")))
.argumentType(new TypeSignature("decimal", typeVariable("p"), typeVariable("s")))
.build())
.description("Calculates the average value")
.build(),
AggregationFunctionMetadata.builder()
.intermediateType(VARBINARY)
.build());
}
private DecimalAverageAggregation() {}

@Override
public AggregationMetadata specialize(BoundSignature boundSignature)
{
Type type = getOnlyElement(boundSignature.getArgumentTypes());
checkArgument(type instanceof DecimalType, "type must be Decimal");
MethodHandle inputFunction;
MethodHandle outputFunction;
Class<LongDecimalWithOverflowAndLongState> stateInterface = LongDecimalWithOverflowAndLongState.class;
LongDecimalWithOverflowAndLongStateSerializer stateSerializer = new LongDecimalWithOverflowAndLongStateSerializer();

if (((DecimalType) type).isShort()) {
inputFunction = SHORT_DECIMAL_INPUT_FUNCTION;
outputFunction = SHORT_DECIMAL_OUTPUT_FUNCTION;
}
else {
inputFunction = LONG_DECIMAL_INPUT_FUNCTION;
outputFunction = LONG_DECIMAL_OUTPUT_FUNCTION;
}
outputFunction = outputFunction.bindTo(type);

return new AggregationMetadata(
inputFunction,
Optional.empty(),
Optional.of(COMBINE_FUNCTION),
outputFunction,
ImmutableList.of(new AccumulatorStateDescriptor<>(
stateInterface,
stateSerializer,
new LongDecimalWithOverflowAndLongStateFactory())));
}

public static void inputShortDecimal(LongDecimalWithOverflowAndLongState state, Block block, int position)
@InputFunction
@LiteralParameters({"p", "s"})
public static void inputShortDecimal(
@AggregationState LongDecimalWithOverflowAndLongState state,
@BlockPosition @SqlType(value = "decimal(p, s)", nativeContainerType = long.class) Block block,
@BlockIndex int position)
{
state.addLong(1); // row counter

Expand All @@ -136,7 +80,12 @@ public static void inputShortDecimal(LongDecimalWithOverflowAndLongState state,
state.addOverflow(overflow);
}

public static void inputLongDecimal(LongDecimalWithOverflowAndLongState state, Block block, int position)
@InputFunction
@LiteralParameters({"p", "s"})
public static void inputLongDecimal(
@AggregationState LongDecimalWithOverflowAndLongState state,
@BlockPosition @SqlType(value = "decimal(p, s)", nativeContainerType = Int128.class) Block block,
@BlockIndex int position)
{
state.addLong(1); // row counter

Expand All @@ -159,7 +108,8 @@ public static void inputLongDecimal(LongDecimalWithOverflowAndLongState state, B
state.addOverflow(overflow);
}

public static void combine(LongDecimalWithOverflowAndLongState state, LongDecimalWithOverflowAndLongState otherState)
@CombineFunction
public static void combine(@AggregationState LongDecimalWithOverflowAndLongState state, @AggregationState LongDecimalWithOverflowAndLongState otherState)
{
state.addLong(otherState.getLong()); // row counter

Expand Down Expand Up @@ -187,23 +137,23 @@ public static void combine(LongDecimalWithOverflowAndLongState state, LongDecima
}
}

public static void outputShortDecimal(DecimalType type, LongDecimalWithOverflowAndLongState state, BlockBuilder out)
@OutputFunction("decimal(p,s)")
public static void outputShortDecimal(
@TypeParameter("decimal(p,s)") Type type,
@AggregationState LongDecimalWithOverflowAndLongState state,
BlockBuilder out)
{
DecimalType decimalType = (DecimalType) type;
if (state.getLong() == 0) {
out.appendNull();
return;
}
else {
writeShortDecimal(out, average(state, type).toLongExact());
}
}

public static void outputLongDecimal(DecimalType type, LongDecimalWithOverflowAndLongState state, BlockBuilder out)
{
if (state.getLong() == 0) {
out.appendNull();
Int128 average = average(state, decimalType);
if (decimalType.isShort()) {
writeShortDecimal(out, average.toLongExact());
}
else {
type.writeObject(out, average(state, type));
type.writeObject(out, average);
}
}

Expand Down

0 comments on commit 5a76c23

Please sign in to comment.