Skip to content

Commit

Permalink
Convert checksum aggregation to annotated function
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 27, 2022
1 parent 8f709f9 commit a5cdc4d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.functions(DECIMAL_TO_SMALLINT_SATURATED_FLOOR_CAST, SMALLINT_TO_DECIMAL_SATURATED_FLOOR_CAST)
.functions(DECIMAL_TO_TINYINT_SATURATED_FLOOR_CAST, TINYINT_TO_DECIMAL_SATURATED_FLOOR_CAST)
.function(new Histogram(blockTypeOperators))
.function(new ChecksumAggregationFunction(blockTypeOperators))
.aggregates(ChecksumAggregationFunction.class)
.function(ARBITRARY_AGGREGATION)
.functions(GREATEST, LEAST)
.functions(MAX_BY, MIN_BY, new MaxByNAggregationFunction(blockTypeOperators), new MinByNAggregationFunction(blockTypeOperators))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,96 +14,75 @@
package io.trino.operator.aggregation;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionMetadata;
import io.trino.metadata.Signature;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.trino.operator.aggregation.state.NullableLongState;
import io.trino.operator.aggregation.state.StateCompiler;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.TypeSignature;
import io.trino.type.BlockTypeOperators;
import io.trino.type.BlockTypeOperators.BlockPositionXxHash64;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.BlockIndex;
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.Convention;
import io.trino.spi.function.Description;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.OperatorDependency;
import io.trino.spi.function.OperatorType;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;

import java.lang.invoke.MethodHandle;
import java.util.Optional;

import static io.airlift.slice.Slices.wrappedLongArray;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.util.Reflection.methodHandle;
import static java.util.Objects.requireNonNull;

public class ChecksumAggregationFunction
extends SqlAggregationFunction
@AggregationFunction("checksum")
@Description("Checksum of the given values")
public final class ChecksumAggregationFunction
{
@VisibleForTesting
public static final long PRIME64 = 0x9E3779B185EBCA87L;
private static final String NAME = "checksum";
private static final MethodHandle OUTPUT_FUNCTION = methodHandle(ChecksumAggregationFunction.class, "output", NullableLongState.class, BlockBuilder.class);
private static final MethodHandle INPUT_FUNCTION = methodHandle(ChecksumAggregationFunction.class, "input", BlockPositionXxHash64.class, NullableLongState.class, Block.class, int.class);
private static final MethodHandle COMBINE_FUNCTION = methodHandle(ChecksumAggregationFunction.class, "combine", NullableLongState.class, NullableLongState.class);

private final BlockTypeOperators blockTypeOperators;
private ChecksumAggregationFunction() {}

public ChecksumAggregationFunction(BlockTypeOperators blockTypeOperators)
{
super(
FunctionMetadata.aggregateBuilder()
.signature(Signature.builder()
.name(NAME)
.comparableTypeParameter("T")
.returnType(VARBINARY)
.argumentType(new TypeSignature("T"))
.build())
.argumentNullability(true)
.description("Checksum of the given values")
.build(),
AggregationFunctionMetadata.builder()
.intermediateType(BIGINT)
.build());
this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null");
}

@Override
public AggregationMetadata specialize(BoundSignature boundSignature)
{
BlockPositionXxHash64 xxHash64Operator = blockTypeOperators.getXxHash64Operator(boundSignature.getArgumentTypes().get(0));
AccumulatorStateSerializer<NullableLongState> stateSerializer = StateCompiler.generateStateSerializer(NullableLongState.class);
return new AggregationMetadata(
INPUT_FUNCTION.bindTo(xxHash64Operator),
Optional.empty(),
Optional.of(COMBINE_FUNCTION),
OUTPUT_FUNCTION,
ImmutableList.of(new AccumulatorStateDescriptor<>(
NullableLongState.class,
stateSerializer,
StateCompiler.generateStateFactory(NullableLongState.class))));
}

public static void input(BlockPositionXxHash64 xxHash64Operator, NullableLongState state, Block block, int position)
@InputFunction
@TypeParameter("T")
public static void input(
@OperatorDependency(
operator = OperatorType.XX_HASH_64,
argumentTypes = "T",
convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL))
MethodHandle xxHash64Operator,
@AggregationState NullableLongState state,
@NullablePosition @BlockPosition @SqlType("T") Block block,
@BlockIndex int position)
throws Throwable
{
state.setNull(false);
if (block.isNull(position)) {
state.setValue(state.getValue() + PRIME64);
}
else {
state.setValue(state.getValue() + xxHash64Operator.xxHash64(block, position) * PRIME64);
long valueHash = (long) xxHash64Operator.invokeExact(block, position);
state.setValue(state.getValue() + valueHash * PRIME64);
}
}

public static void combine(NullableLongState state, NullableLongState otherState)
@CombineFunction
public static void combine(
@AggregationState NullableLongState state,
@AggregationState NullableLongState otherState)
{
state.setNull(state.isNull() && otherState.isNull());
state.setValue(state.getValue() + otherState.getValue());
}

public static void output(NullableLongState state, BlockBuilder out)
@OutputFunction("VARBINARY")
public static void output(
@AggregationState NullableLongState state,
BlockBuilder out)
{
if (state.isNull()) {
out.appendNull();
Expand Down

0 comments on commit a5cdc4d

Please sign in to comment.