Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite most aggregations as annotated functions #11477

Merged
merged 27 commits into from
May 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
dc3529f
Do not apply SingleDistinctAggregationToGroupBy to ordering functions
dain May 15, 2022
d7d8b61
Fix assertion test in AggregationTestUtils
dain Dec 26, 2021
866a12c
Move TestStateCompiler to state package and clean up
dain Dec 24, 2021
10959a8
Convert count aggregation to annotated function
dain Dec 29, 2021
19faa04
Convert checksum aggregation to annotated function
dain Dec 29, 2021
c9ef201
Fix null parameter declaration in listagg
dain Dec 31, 2021
dce2219
Convert listagg aggregation to annotated function
dain Dec 31, 2021
d75b6c9
Convert qdigest_agg and merge aggregations to annotated functions
dain Dec 31, 2021
c5f1259
Add in-out calling convention
dain Dec 24, 2021
43265e5
Add support for generic in-out state to annotated aggregation functions
dain Dec 24, 2021
def6e8e
Convert arbitrary aggregation to annotated function
dain Dec 28, 2021
6e38de1
Convert min/max aggregations to annotated functions
dain Dec 24, 2021
16a742c
Add support for multiple state variables in annotated aggregations
dain Dec 26, 2021
797b442
Convert min/max_by aggregation to annotated function
dain Dec 26, 2021
3f297d4
Remove unused typed constructors from nullable aggregation states
dain May 10, 2022
d96547f
Convert avg(REAL) aggregation to annotated function
dain Dec 29, 2021
7e1e6bd
Add support for generic aggregation state classes
dain Dec 30, 2021
42259d4
Convert min/max N to annotated aggregate function
dain Dec 30, 2021
1f13732
Convert min_by/max_by N to annotated aggregate function
dain Dec 31, 2021
50d0a41
Convert array_agg aggregation to annotated function
dain Dec 31, 2021
cebed5c
Convert map_agg and map_union aggregations to annotated functions
dain Dec 31, 2021
d26754c
Convert multimap_agg aggregation to annotated function
dain Dec 31, 2021
a5c2f1c
Convert histogram aggregation to annotated function
dain Dec 31, 2021
d74f39f
Add support for specialized BlockPosition in annotated aggregations
dain Dec 29, 2021
19b0b3f
Convert decimal sum aggregation to annotated function
dain Dec 29, 2021
ab834ac
Convert decimal avg aggregation to annotated function
dain Dec 31, 2021
7c24aa0
Add error message for invalid aggregation class
dain May 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.InOut;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention;
import io.trino.spi.type.Type;
Expand Down Expand Up @@ -196,7 +197,10 @@ private static void verifyMethodHandleSignature(BoundSignature boundSignature, F
break;
case BLOCK_POSITION:
verifyFunctionSignature(parameterType.equals(Block.class) && methodType.parameterType(parameterIndex + 1).equals(int.class),
"Expected BLOCK_POSITION argument have parameters Block and int");
"Expected BLOCK_POSITION argument types to be Block and int");
break;
case IN_OUT:
verifyFunctionSignature(parameterType.equals(InOut.class), "Expected IN_OUT argument type to be InOut");
break;
case FUNCTION:
Class<?> lambdaInterface = functionInvoker.getLambdaInterfaces().get(lambdaArgumentIndex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ private static boolean matchesParameterAndReturnTypes(
}
actualType = resolvedType.getJavaType();
break;
case IN_OUT:
// any type is supported, so just ignore this check
actualType = resolvedType.getJavaType();
expectedType = resolvedType.getJavaType();
break;
default:
throw new UnsupportedOperationException("Unknown argument convention: " + argumentConvention);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ public abstract class SqlAggregationFunction

/**
 * Builds the aggregation functions declared by an annotated aggregation class.
 *
 * <p>The class is handed to {@code AggregationFromAnnotationsParser.parseFunctionDefinitions}
 * and the parsed definitions are returned as an immutable list.
 *
 * @param aggregationDefinition the annotated class describing one or more aggregation functions
 * @return an immutable list of the functions parsed from {@code aggregationDefinition}
 * @throws IllegalArgumentException if parsing fails with a {@link RuntimeException}; the message
 *         names the offending class and the original exception is attached as the cause
 */
public static List<SqlAggregationFunction> createFunctionsByAnnotations(Class<?> aggregationDefinition)
{
    try {
        return ImmutableList.copyOf(AggregationFromAnnotationsParser.parseFunctionDefinitions(aggregationDefinition));
    }
    catch (RuntimeException e) {
        // Keep the parser's exception as the cause: the class name alone does not say
        // which annotation or method signature was rejected.
        throw new IllegalArgumentException("Invalid aggregation class " + aggregationDefinition.getSimpleName(), e);
    }
}

public SqlAggregationFunction(FunctionMetadata functionMetadata, AggregationFunctionMetadata aggregationFunctionMetadata)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.trino.operator.aggregation.ApproximateRealPercentileArrayAggregations;
import io.trino.operator.aggregation.ApproximateSetAggregation;
import io.trino.operator.aggregation.ApproximateSetGenericAggregation;
import io.trino.operator.aggregation.ArbitraryAggregationFunction;
import io.trino.operator.aggregation.AverageAggregations;
import io.trino.operator.aggregation.BigintApproximateMostFrequent;
import io.trino.operator.aggregation.BitwiseAndAggregation;
Expand All @@ -36,7 +37,10 @@
import io.trino.operator.aggregation.CentralMomentsAggregation;
import io.trino.operator.aggregation.ChecksumAggregationFunction;
import io.trino.operator.aggregation.CountAggregation;
import io.trino.operator.aggregation.CountColumn;
import io.trino.operator.aggregation.CountIfAggregation;
import io.trino.operator.aggregation.DecimalAverageAggregation;
import io.trino.operator.aggregation.DecimalSumAggregation;
import io.trino.operator.aggregation.DefaultApproximateCountDistinctAggregation;
import io.trino.operator.aggregation.DoubleCorrelationAggregation;
import io.trino.operator.aggregation.DoubleCovarianceAggregation;
Expand All @@ -54,12 +58,18 @@
import io.trino.operator.aggregation.LongSumAggregation;
import io.trino.operator.aggregation.MapAggregationFunction;
import io.trino.operator.aggregation.MapUnionAggregation;
import io.trino.operator.aggregation.MaxAggregationFunction;
import io.trino.operator.aggregation.MaxByAggregationFunction;
import io.trino.operator.aggregation.MaxDataSizeForStats;
import io.trino.operator.aggregation.MaxNAggregationFunction;
import io.trino.operator.aggregation.MergeHyperLogLogAggregation;
import io.trino.operator.aggregation.MergeQuantileDigestFunction;
import io.trino.operator.aggregation.MergeTDigestAggregation;
import io.trino.operator.aggregation.MinNAggregationFunction;
import io.trino.operator.aggregation.MinAggregationFunction;
import io.trino.operator.aggregation.MinByAggregationFunction;
import io.trino.operator.aggregation.QuantileDigestAggregationFunction.BigintQuantileDigestAggregationFunction;
import io.trino.operator.aggregation.QuantileDigestAggregationFunction.DoubleQuantileDigestAggregationFunction;
import io.trino.operator.aggregation.QuantileDigestAggregationFunction.RealQuantileDigestAggregationFunction;
import io.trino.operator.aggregation.RealAverageAggregation;
import io.trino.operator.aggregation.RealCorrelationAggregation;
import io.trino.operator.aggregation.RealCovarianceAggregation;
import io.trino.operator.aggregation.RealGeometricMeanAggregations;
Expand All @@ -70,9 +80,13 @@
import io.trino.operator.aggregation.TDigestAggregationFunction;
import io.trino.operator.aggregation.VarcharApproximateMostFrequent;
import io.trino.operator.aggregation.VarianceAggregation;
import io.trino.operator.aggregation.arrayagg.ArrayAggregationFunction;
import io.trino.operator.aggregation.histogram.Histogram;
import io.trino.operator.aggregation.minmaxby.MaxByNAggregationFunction;
import io.trino.operator.aggregation.minmaxby.MinByNAggregationFunction;
import io.trino.operator.aggregation.listagg.ListaggAggregationFunction;
import io.trino.operator.aggregation.minmaxbyn.MaxByNAggregationFunction;
import io.trino.operator.aggregation.minmaxbyn.MinByNAggregationFunction;
import io.trino.operator.aggregation.minmaxn.MaxNAggregationFunction;
import io.trino.operator.aggregation.minmaxn.MinNAggregationFunction;
import io.trino.operator.aggregation.multimapagg.MultimapAggregationFunction;
import io.trino.operator.scalar.ArrayAllMatchFunction;
import io.trino.operator.scalar.ArrayAnyMatchFunction;
Expand Down Expand Up @@ -255,21 +269,7 @@
import io.trino.type.setdigest.SetDigestFunctions;
import io.trino.type.setdigest.SetDigestOperators;

import static io.trino.operator.aggregation.ArbitraryAggregationFunction.ARBITRARY_AGGREGATION;
import static io.trino.operator.aggregation.CountColumn.COUNT_COLUMN;
import static io.trino.operator.aggregation.DecimalAverageAggregation.DECIMAL_AVERAGE_AGGREGATION;
import static io.trino.operator.aggregation.DecimalSumAggregation.DECIMAL_SUM_AGGREGATION;
import static io.trino.operator.aggregation.MaxAggregationFunction.MAX_AGGREGATION;
import static io.trino.operator.aggregation.MinAggregationFunction.MIN_AGGREGATION;
import static io.trino.operator.aggregation.QuantileDigestAggregationFunction.QDIGEST_AGG;
import static io.trino.operator.aggregation.QuantileDigestAggregationFunction.QDIGEST_AGG_WITH_WEIGHT;
import static io.trino.operator.aggregation.QuantileDigestAggregationFunction.QDIGEST_AGG_WITH_WEIGHT_AND_ERROR;
import static io.trino.operator.aggregation.RealAverageAggregation.REAL_AVERAGE_AGGREGATION;
import static io.trino.operator.aggregation.ReduceAggregationFunction.REDUCE_AGG;
import static io.trino.operator.aggregation.arrayagg.ArrayAggregationFunction.ARRAY_AGG;
import static io.trino.operator.aggregation.listagg.ListaggAggregationFunction.LISTAGG;
import static io.trino.operator.aggregation.minmaxby.MaxByAggregationFunction.MAX_BY;
import static io.trino.operator.aggregation.minmaxby.MinByAggregationFunction.MIN_BY;
import static io.trino.operator.scalar.ArrayConcatFunction.ARRAY_CONCAT_FUNCTION;
import static io.trino.operator.scalar.ArrayConstructor.ARRAY_CONSTRUCTOR;
import static io.trino.operator.scalar.ArrayFlattenFunction.ARRAY_FLATTEN_FUNCTION;
Expand Down Expand Up @@ -391,7 +391,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.aggregates(IntervalDayToSecondSumAggregation.class)
.aggregates(IntervalYearToMonthSumAggregation.class)
.aggregates(AverageAggregations.class)
.function(REAL_AVERAGE_AGGREGATION)
.aggregates(RealAverageAggregation.class)
.aggregates(IntervalDayToSecondAverageAggregation.class)
.aggregates(IntervalYearToMonthAverageAggregation.class)
.aggregates(GeometricMeanAggregations.class)
Expand All @@ -400,8 +400,10 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.aggregates(ApproximateSetAggregation.class)
.aggregates(ApproximateSetGenericAggregation.class)
.aggregates(TDigestAggregationFunction.class)
.functions(QDIGEST_AGG, QDIGEST_AGG_WITH_WEIGHT, QDIGEST_AGG_WITH_WEIGHT_AND_ERROR)
.function(MergeQuantileDigestFunction.MERGE)
.aggregates(DoubleQuantileDigestAggregationFunction.class)
.aggregates(RealQuantileDigestAggregationFunction.class)
.aggregates(BigintQuantileDigestAggregationFunction.class)
.aggregates(MergeQuantileDigestFunction.class)
.aggregates(MergeTDigestAggregation.class)
.aggregates(DoubleHistogramAggregation.class)
.aggregates(RealHistogramAggregation.class)
Expand Down Expand Up @@ -518,13 +520,14 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.function(ARRAY_FLATTEN_FUNCTION)
.function(ARRAY_CONCAT_FUNCTION)
.functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, JSON_TO_ARRAY, JSON_STRING_TO_ARRAY)
.function(ARRAY_AGG)
.function(LISTAGG)
.aggregates(ArrayAggregationFunction.class)
.aggregates(ListaggAggregationFunction.class)
.functions(new MapSubscriptOperator())
.functions(MAP_CONSTRUCTOR, JSON_TO_MAP, JSON_STRING_TO_MAP)
.functions(new MapAggregationFunction(blockTypeOperators), new MapUnionAggregation(blockTypeOperators))
.aggregates(MapAggregationFunction.class)
.aggregates(MapUnionAggregation.class)
.function(REDUCE_AGG)
.function(new MultimapAggregationFunction(blockTypeOperators))
.aggregates(MultimapAggregationFunction.class)
.functions(DECIMAL_TO_VARCHAR_CAST, DECIMAL_TO_INTEGER_CAST, DECIMAL_TO_BIGINT_CAST, DECIMAL_TO_DOUBLE_CAST, DECIMAL_TO_REAL_CAST, DECIMAL_TO_BOOLEAN_CAST, DECIMAL_TO_TINYINT_CAST, DECIMAL_TO_SMALLINT_CAST)
.functions(VARCHAR_TO_DECIMAL_CAST, INTEGER_TO_DECIMAL_CAST, BIGINT_TO_DECIMAL_CAST, DOUBLE_TO_DECIMAL_CAST, REAL_TO_DECIMAL_CAST, BOOLEAN_TO_DECIMAL_CAST, TINYINT_TO_DECIMAL_CAST, SMALLINT_TO_DECIMAL_CAST)
.functions(JSON_TO_DECIMAL_CAST, DECIMAL_TO_JSON_CAST)
Expand All @@ -534,21 +537,27 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.functions(DECIMAL_TO_INTEGER_SATURATED_FLOOR_CAST, INTEGER_TO_DECIMAL_SATURATED_FLOOR_CAST)
.functions(DECIMAL_TO_SMALLINT_SATURATED_FLOOR_CAST, SMALLINT_TO_DECIMAL_SATURATED_FLOOR_CAST)
.functions(DECIMAL_TO_TINYINT_SATURATED_FLOOR_CAST, TINYINT_TO_DECIMAL_SATURATED_FLOOR_CAST)
.function(new Histogram(blockTypeOperators))
.function(new ChecksumAggregationFunction(blockTypeOperators))
.function(ARBITRARY_AGGREGATION)
.aggregates(Histogram.class)
.aggregates(ChecksumAggregationFunction.class)
.aggregates(ArbitraryAggregationFunction.class)
.functions(GREATEST, LEAST)
.functions(MAX_BY, MIN_BY, new MaxByNAggregationFunction(blockTypeOperators), new MinByNAggregationFunction(blockTypeOperators))
.functions(MAX_AGGREGATION, MIN_AGGREGATION, new MaxNAggregationFunction(blockTypeOperators), new MinNAggregationFunction(blockTypeOperators))
.function(COUNT_COLUMN)
.aggregates(MinAggregationFunction.class)
.aggregates(MaxAggregationFunction.class)
.aggregates(MinByAggregationFunction.class)
.aggregates(MaxByAggregationFunction.class)
.aggregates(MaxNAggregationFunction.class)
.aggregates(MinNAggregationFunction.class)
.aggregates(MinByNAggregationFunction.class)
.aggregates(MaxByNAggregationFunction.class)
.aggregates(CountColumn.class)
.functions(JSON_TO_ROW, JSON_STRING_TO_ROW, ROW_TO_ROW_CAST)
.functions(VARCHAR_CONCAT, VARBINARY_CONCAT)
.function(CONCAT_WS)
.function(DECIMAL_TO_DECIMAL_CAST)
.function(castVarcharToRe2JRegexp(featuresConfig.getRe2JDfaStatesLimit(), featuresConfig.getRe2JDfaRetries()))
.function(castCharToRe2JRegexp(featuresConfig.getRe2JDfaStatesLimit(), featuresConfig.getRe2JDfaRetries()))
.function(DECIMAL_AVERAGE_AGGREGATION)
.function(DECIMAL_SUM_AGGREGATION)
.aggregates(DecimalAverageAggregation.class)
.aggregates(DecimalSumAggregation.class)
.function(DECIMAL_MOD_FUNCTION)
.functions(ARRAY_TRANSFORM_FUNCTION, ARRAY_REDUCE_FUNCTION)
.functions(MAP_FILTER_FUNCTION, new MapTransformKeysFunction(blockTypeOperators), MAP_TRANSFORM_VALUES_FUNCTION)
Expand Down
Loading