Skip to content

Commit

Permalink
Convert arbitrary aggregation to annotated function
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 27, 2022
1 parent d358c57 commit 15511fe
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 178 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.trino.operator.aggregation.ApproximateRealPercentileArrayAggregations;
import io.trino.operator.aggregation.ApproximateSetAggregation;
import io.trino.operator.aggregation.ApproximateSetGenericAggregation;
import io.trino.operator.aggregation.ArbitraryAggregationFunction;
import io.trino.operator.aggregation.AverageAggregations;
import io.trino.operator.aggregation.BigintApproximateMostFrequent;
import io.trino.operator.aggregation.BitwiseAndAggregation;
Expand Down Expand Up @@ -262,7 +263,6 @@
import io.trino.type.setdigest.SetDigestFunctions;
import io.trino.type.setdigest.SetDigestOperators;

import static io.trino.operator.aggregation.ArbitraryAggregationFunction.ARBITRARY_AGGREGATION;
import static io.trino.operator.aggregation.DecimalAverageAggregation.DECIMAL_AVERAGE_AGGREGATION;
import static io.trino.operator.aggregation.DecimalSumAggregation.DECIMAL_SUM_AGGREGATION;
import static io.trino.operator.aggregation.MaxAggregationFunction.MAX_AGGREGATION;
Expand Down Expand Up @@ -542,7 +542,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.functions(DECIMAL_TO_TINYINT_SATURATED_FLOOR_CAST, TINYINT_TO_DECIMAL_SATURATED_FLOOR_CAST)
.function(new Histogram(blockTypeOperators))
.aggregates(ChecksumAggregationFunction.class)
.function(ARBITRARY_AGGREGATION)
.aggregates(ArbitraryAggregationFunction.class)
.functions(GREATEST, LEAST)
.functions(MAX_BY, MIN_BY, new MaxByNAggregationFunction(blockTypeOperators), new MinByNAggregationFunction(blockTypeOperators))
.functions(MAX_AGGREGATION, MIN_AGGREGATION, new MaxNAggregationFunction(blockTypeOperators), new MinNAggregationFunction(blockTypeOperators))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,192 +13,53 @@
*/
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionMetadata;
import io.trino.metadata.Signature;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.trino.operator.aggregation.state.BlockPositionState;
import io.trino.operator.aggregation.state.BlockPositionStateSerializer;
import io.trino.operator.aggregation.state.GenericBooleanState;
import io.trino.operator.aggregation.state.GenericBooleanStateSerializer;
import io.trino.operator.aggregation.state.GenericDoubleState;
import io.trino.operator.aggregation.state.GenericDoubleStateSerializer;
import io.trino.operator.aggregation.state.GenericLongState;
import io.trino.operator.aggregation.state.GenericLongStateSerializer;
import io.trino.operator.aggregation.state.StateCompiler;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;

import java.lang.invoke.MethodHandle;
import java.util.Optional;

import static io.trino.util.Reflection.methodHandle;

public class ArbitraryAggregationFunction
extends SqlAggregationFunction
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.BlockIndex;
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.Description;
import io.trino.spi.function.InOut;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;

@AggregationFunction("arbitrary")
@Description("Return an arbitrary non-null input value")
public final class ArbitraryAggregationFunction
{
public static final ArbitraryAggregationFunction ARBITRARY_AGGREGATION = new ArbitraryAggregationFunction();
private static final String NAME = "arbitrary";

private static final MethodHandle LONG_INPUT_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "input", Type.class, GenericLongState.class, Block.class, int.class);
private static final MethodHandle DOUBLE_INPUT_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "input", Type.class, GenericDoubleState.class, Block.class, int.class);
private static final MethodHandle BOOLEAN_INPUT_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "input", Type.class, GenericBooleanState.class, Block.class, int.class);
private static final MethodHandle BLOCK_POSITION_INPUT_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "input", Type.class, BlockPositionState.class, Block.class, int.class);

private static final MethodHandle LONG_OUTPUT_FUNCTION = methodHandle(GenericLongState.class, "write", Type.class, GenericLongState.class, BlockBuilder.class);
private static final MethodHandle DOUBLE_OUTPUT_FUNCTION = methodHandle(GenericDoubleState.class, "write", Type.class, GenericDoubleState.class, BlockBuilder.class);
private static final MethodHandle BOOLEAN_OUTPUT_FUNCTION = methodHandle(GenericBooleanState.class, "write", Type.class, GenericBooleanState.class, BlockBuilder.class);
private static final MethodHandle BLOCK_POSITION_OUTPUT_FUNCTION = methodHandle(BlockPositionState.class, "write", Type.class, BlockPositionState.class, BlockBuilder.class);

private static final MethodHandle LONG_COMBINE_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "combine", GenericLongState.class, GenericLongState.class);
private static final MethodHandle DOUBLE_COMBINE_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "combine", GenericDoubleState.class, GenericDoubleState.class);
private static final MethodHandle BOOLEAN_COMBINE_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "combine", GenericBooleanState.class, GenericBooleanState.class);
private static final MethodHandle BLOCK_POSITION_COMBINE_FUNCTION = methodHandle(ArbitraryAggregationFunction.class, "combine", BlockPositionState.class, BlockPositionState.class);

protected ArbitraryAggregationFunction()
private ArbitraryAggregationFunction() {}

@InputFunction
@TypeParameter("T")
public static void input(
@AggregationState("T") InOut state,
@BlockPosition @SqlType("T") Block block,
@BlockIndex int position)
throws Throwable
{
super(
FunctionMetadata.aggregateBuilder()
.signature(Signature.builder()
.name(NAME)
.typeVariable("T")
.returnType(new TypeSignature("T"))
.argumentType(new TypeSignature("T"))
.build())
.description("Return an arbitrary non-null input value")
.build(),
AggregationFunctionMetadata.builder()
.intermediateType(new TypeSignature("T"))
.build());
}

@Override
public AggregationMetadata specialize(BoundSignature boundSignature)
{
Type type = boundSignature.getReturnType();

MethodHandle inputFunction;
MethodHandle combineFunction;
MethodHandle outputFunction;
AccumulatorStateDescriptor<?> accumulatorStateDescriptor;

if (type.getJavaType() == long.class) {
accumulatorStateDescriptor = new AccumulatorStateDescriptor<>(
GenericLongState.class,
new GenericLongStateSerializer(type),
StateCompiler.generateStateFactory(GenericLongState.class));
inputFunction = LONG_INPUT_FUNCTION;
combineFunction = LONG_COMBINE_FUNCTION;
outputFunction = LONG_OUTPUT_FUNCTION;
}
else if (type.getJavaType() == double.class) {
accumulatorStateDescriptor = new AccumulatorStateDescriptor<>(
GenericDoubleState.class,
new GenericDoubleStateSerializer(type),
StateCompiler.generateStateFactory(GenericDoubleState.class));
inputFunction = DOUBLE_INPUT_FUNCTION;
combineFunction = DOUBLE_COMBINE_FUNCTION;
outputFunction = DOUBLE_OUTPUT_FUNCTION;
}
else if (type.getJavaType() == boolean.class) {
accumulatorStateDescriptor = new AccumulatorStateDescriptor<>(
GenericBooleanState.class,
new GenericBooleanStateSerializer(type),
StateCompiler.generateStateFactory(GenericBooleanState.class));
inputFunction = BOOLEAN_INPUT_FUNCTION;
combineFunction = BOOLEAN_COMBINE_FUNCTION;
outputFunction = BOOLEAN_OUTPUT_FUNCTION;
}
else {
// native container type is Slice or Block
accumulatorStateDescriptor = new AccumulatorStateDescriptor<>(
BlockPositionState.class,
new BlockPositionStateSerializer(type),
StateCompiler.generateStateFactory(BlockPositionState.class));
inputFunction = BLOCK_POSITION_INPUT_FUNCTION;
combineFunction = BLOCK_POSITION_COMBINE_FUNCTION;
outputFunction = BLOCK_POSITION_OUTPUT_FUNCTION;
}
inputFunction = inputFunction.bindTo(type);

return new AggregationMetadata(
inputFunction,
Optional.empty(),
Optional.of(combineFunction),
outputFunction.bindTo(type),
ImmutableList.of(accumulatorStateDescriptor));
}

public static void input(Type type, GenericDoubleState state, Block block, int position)
{
if (!state.isNull()) {
return;
if (state.isNull()) {
state.set(block, position);
}
state.setNull(false);
state.setValue(type.getDouble(block, position));
}

public static void input(Type type, GenericLongState state, Block block, int position)
@CombineFunction
public static void combine(
@AggregationState("T") InOut state,
@AggregationState("T") InOut otherState)
throws Throwable
{
if (!state.isNull()) {
return;
if (state.isNull()) {
state.set(otherState);
}
state.setNull(false);
state.setValue(type.getLong(block, position));
}

public static void input(Type type, GenericBooleanState state, Block block, int position)
@OutputFunction("T")
public static void output(@AggregationState("T") InOut state, BlockBuilder out)
{
if (!state.isNull()) {
return;
}
state.setNull(false);
state.setValue(type.getBoolean(block, position));
}

public static void input(Type type, BlockPositionState state, Block block, int position)
{
if (state.getBlock() != null) {
return;
}
state.setBlock(block);
state.setPosition(position);
}

public static void combine(GenericLongState state, GenericLongState otherState)
{
if (!state.isNull()) {
return;
}
state.set(otherState);
}

public static void combine(GenericDoubleState state, GenericDoubleState otherState)
{
if (!state.isNull()) {
return;
}
state.set(otherState);
}

public static void combine(GenericBooleanState state, GenericBooleanState otherState)
{
if (!state.isNull()) {
return;
}
state.set(otherState);
}

public static void combine(BlockPositionState state, BlockPositionState otherState)
{
if (state.getBlock() != null) {
return;
}
state.set(otherState);
state.get(out);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
approx_percentile | double | double, double | aggregate | true | |
approx_set | hyperloglog | bigint | aggregate | true | |
approx_set | hyperloglog | double | aggregate | true | |
arbitrary | T | T | aggregate | true | Return an arbitrary non-null input value |
arbitrary | t | t | aggregate | true | Return an arbitrary non-null input value |
asin | double | double | scalar | true | Arc sine |
atan | double | double | scalar | true | Arc tangent |
atan2 | double | double, double | scalar | true | Arc tangent of given fraction |
Expand Down

0 comments on commit 15511fe

Please sign in to comment.