Skip to content

Commit

Permalink
Convert count aggregation to annotated function
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 27, 2022
1 parent 760a54b commit 8f709f9
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.trino.operator.aggregation.CentralMomentsAggregation;
import io.trino.operator.aggregation.ChecksumAggregationFunction;
import io.trino.operator.aggregation.CountAggregation;
import io.trino.operator.aggregation.CountColumn;
import io.trino.operator.aggregation.CountIfAggregation;
import io.trino.operator.aggregation.DefaultApproximateCountDistinctAggregation;
import io.trino.operator.aggregation.DoubleCorrelationAggregation;
Expand Down Expand Up @@ -258,7 +259,6 @@
import io.trino.type.setdigest.SetDigestOperators;

import static io.trino.operator.aggregation.ArbitraryAggregationFunction.ARBITRARY_AGGREGATION;
import static io.trino.operator.aggregation.CountColumn.COUNT_COLUMN;
import static io.trino.operator.aggregation.DecimalAverageAggregation.DECIMAL_AVERAGE_AGGREGATION;
import static io.trino.operator.aggregation.DecimalSumAggregation.DECIMAL_SUM_AGGREGATION;
import static io.trino.operator.aggregation.MaxAggregationFunction.MAX_AGGREGATION;
Expand Down Expand Up @@ -544,7 +544,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.functions(GREATEST, LEAST)
.functions(MAX_BY, MIN_BY, new MaxByNAggregationFunction(blockTypeOperators), new MinByNAggregationFunction(blockTypeOperators))
.functions(MAX_AGGREGATION, MIN_AGGREGATION, new MaxNAggregationFunction(blockTypeOperators), new MinNAggregationFunction(blockTypeOperators))
.function(COUNT_COLUMN)
.aggregates(CountColumn.class)
.functions(JSON_TO_ROW, JSON_STRING_TO_ROW, ROW_TO_ROW_CAST)
.functions(VARCHAR_CONCAT, VARBINARY_CONCAT)
.function(CONCAT_WS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,87 +13,56 @@
*/
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionMetadata;
import io.trino.metadata.Signature;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.trino.operator.aggregation.state.LongState;
import io.trino.operator.aggregation.state.StateCompiler;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateFactory;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.TypeSignature;

import java.lang.invoke.MethodHandle;
import java.util.Optional;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.BlockIndex;
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.Description;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.RemoveInputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;

import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.util.Reflection.methodHandle;

public class CountColumn
extends SqlAggregationFunction
@AggregationFunction("count")
@Description("Counts the non-null values")
public final class CountColumn
{
public static final CountColumn COUNT_COLUMN = new CountColumn();
private static final String NAME = "count";
private static final MethodHandle INPUT_FUNCTION = methodHandle(CountColumn.class, "input", LongState.class, Block.class, int.class);
private static final MethodHandle REMOVE_INPUT_FUNCTION = methodHandle(CountColumn.class, "removeInput", LongState.class, Block.class, int.class);
private static final MethodHandle COMBINE_FUNCTION = methodHandle(CountColumn.class, "combine", LongState.class, LongState.class);
private static final MethodHandle OUTPUT_FUNCTION = methodHandle(CountColumn.class, "output", LongState.class, BlockBuilder.class);

public CountColumn()
{
super(
FunctionMetadata.aggregateBuilder()
.signature(Signature.builder()
.name(NAME)
.typeVariable("T")
.returnType(BIGINT)
.argumentType(new TypeSignature("T"))
.build())
.description("Counts the non-null values")
.build(),
AggregationFunctionMetadata.builder()
.intermediateType(BIGINT)
.build());
}

@Override
public AggregationMetadata specialize(BoundSignature boundSignature)
{
AccumulatorStateSerializer<LongState> stateSerializer = StateCompiler.generateStateSerializer(LongState.class);
AccumulatorStateFactory<LongState> stateFactory = StateCompiler.generateStateFactory(LongState.class);

return new AggregationMetadata(
INPUT_FUNCTION,
Optional.of(REMOVE_INPUT_FUNCTION),
Optional.of(COMBINE_FUNCTION),
OUTPUT_FUNCTION,
ImmutableList.of(new AccumulatorStateDescriptor<>(
LongState.class,
stateSerializer,
stateFactory)));
}
private CountColumn() {}

public static void input(LongState state, Block block, int index)
@InputFunction
@TypeParameter("T")
public static void input(
@AggregationState LongState state,
@BlockPosition @SqlType("T") Block block,
@BlockIndex int position)
{
state.setValue(state.getValue() + 1);
}

public static void removeInput(LongState state, Block block, int index)
@RemoveInputFunction
public static void removeInput(
@AggregationState LongState state,
@BlockPosition @SqlType("T") Block block,
@BlockIndex int position)
{
state.setValue(state.getValue() - 1);
}

public static void combine(LongState state, LongState otherState)
@CombineFunction
public static void combine(@AggregationState LongState state, LongState otherState)
{
state.setValue(state.getValue() + otherState.getValue());
}

public static void output(LongState state, BlockBuilder out)
@OutputFunction("BIGINT")
public static void output(@AggregationState LongState state, BlockBuilder out)
{
BIGINT.writeLong(out, state.getValue());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ public void testTypeCombinations()
@Test
public void testFunctionParameter()
{
assertInvalidFunction("count(x -> x)", FUNCTION_NOT_FOUND, "line 1:1: Unexpected parameters (<function>) for function count. Expected: count(), count(T) T");
assertInvalidFunction("count(x -> x)", FUNCTION_NOT_FOUND, "line 1:1: Unexpected parameters (<function>) for function count. Expected: count(), count(t) T");
assertInvalidFunction("max(x -> x)", FUNCTION_NOT_FOUND, "line 1:1: Unexpected parameters (<function>) for function max. Expected: max(E) E:orderable, max(E, bigint) E:orderable");
assertInvalidFunction("sqrt(x -> x)", FUNCTION_NOT_FOUND, "line 1:1: Unexpected parameters (<function>) for function sqrt. Expected: sqrt(double)");
assertInvalidFunction("sqrt(x -> x, 123, x -> x)", FUNCTION_NOT_FOUND, "line 1:1: Unexpected parameters (<function>, integer, <function>) for function sqrt. Expected: sqrt(double)");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
cos | double | double | scalar | true | Cosine |
cosh | double | double | scalar | true | Hyperbolic cosine |
count | bigint | | aggregate | true | |
count | bigint | T | aggregate | true | Counts the non-null values |
count | bigint | t | aggregate | true | Counts the non-null values |
count_if | bigint | boolean | aggregate | true | |
covar_pop | double | double, double | aggregate | true | |
covar_samp | double | double, double | aggregate | true | |
Expand Down

0 comments on commit 8f709f9

Please sign in to comment.