Skip to content

Commit

Permalink
Add support for multiple state variables in annotated aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 27, 2022
1 parent bdbfc01 commit b8d6287
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Strings.emptyToNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.operator.aggregation.AggregationImplementation.Parser.parseImplementation;
import static io.trino.operator.annotations.FunctionsParserHelper.parseDescription;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;

public final class AggregationFromAnnotationsParser
Expand All @@ -67,8 +67,8 @@ public static List<ParametricAggregation> parseFunctionDefinitions(Class<?> aggr

ImmutableList.Builder<ParametricAggregation> functions = ImmutableList.builder();

// There must be a single state class and combine function
AccumulatorStateDetails stateDetails = getStateDetails(aggregationDefinition);
// There must be a single set of state classes and a single combine function
List<AccumulatorStateDetails> stateDetails = getStateDetails(aggregationDefinition);
Optional<Method> combineFunction = getCombineFunction(aggregationDefinition, stateDetails);

// Each output function defines a new aggregation function
Expand Down Expand Up @@ -114,7 +114,7 @@ else if (combineFunction.isPresent()) {
private static List<ParametricAggregation> buildFunctions(
String name,
AggregationHeader header,
AccumulatorStateDetails stateDetails,
List<AccumulatorStateDetails> stateDetails,
List<AggregationImplementation> exactImplementations,
List<AggregationImplementation> nonExactImplementations)
{
Expand Down Expand Up @@ -183,48 +183,95 @@ private static List<String> getAliases(AggregationFunction aggregationAnnotation
return ImmutableList.copyOf(aggregationAnnotation.alias());
}

private static Optional<Method> getCombineFunction(Class<?> clazz, AccumulatorStateDetails stateDetails)
private static Optional<Method> getCombineFunction(Class<?> clazz, List<AccumulatorStateDetails> stateDetails)
{
List<Method> combineFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class);
for (Method combineFunction : combineFunctions) {
// verify parameter types
List<Class<?>> parameterTypes = getNonDependencyParameterTypes(combineFunction);
List<Class<?>> expectedParameterTypes = nCopies(2, stateDetails.getStateClass());
checkArgument(parameterTypes.equals(expectedParameterTypes), "Expected combine function non-dependency parameters to be %s: %s", expectedParameterTypes, combineFunction);
if (combineFunctions.isEmpty()) {
return Optional.empty();
}
checkArgument(combineFunctions.size() == 1, "There must be only one @CombineFunction in class %s", clazz.toGenericString());
Method combineFunction = getOnlyElement(combineFunctions);

// verify parameter types
List<Class<?>> parameterTypes = getNonDependencyParameterTypes(combineFunction);
List<Class<?>> expectedParameterTypes = Stream.concat(stateDetails.stream(), stateDetails.stream())
.map(AccumulatorStateDetails::getStateClass)
.collect(toImmutableList());
checkArgument(parameterTypes.equals(expectedParameterTypes),
"Expected combine function non-dependency parameters to be %s: %s",
expectedParameterTypes,
combineFunction);

// legacy combine functions did not require parameters to be fully annotated,
// so the annotation-level comparison is only enforced when there is more than one state
if (stateDetails.size() > 1) {
    List<List<Annotation>> parameterAnnotations = getNonDependencyParameterAnnotations(combineFunction);
    List<AccumulatorStateDetails> actualStateDetails = new ArrayList<>();
    for (int parameterIndex = 0; parameterIndex < parameterTypes.size(); parameterIndex++) {
        actualStateDetails.add(toAccumulatorStateDetails(parameterTypes.get(parameterIndex), parameterAnnotations.get(parameterIndex), combineFunction, true));
    }
    // the expected parameters are the state list repeated twice, matching the type check above
    List<AccumulatorStateDetails> expectedStateDetails = ImmutableList.<AccumulatorStateDetails>builder().addAll(stateDetails).addAll(stateDetails).build();
    // report what was expected versus what was actually declared on the method;
    // previously the message printed the expected states in both positions and never the actual ones
    checkArgument(actualStateDetails.equals(expectedStateDetails), "Expected combine function to have state parameters %s, but has %s", expectedStateDetails, actualStateDetails);
}
checkArgument(combineFunctions.size() <= 1, "There must be only one @CombineFunction in class %s for the @AggregationState %s", clazz.toGenericString(), stateDetails.getStateClass().toGenericString());
return combineFunctions.stream().findFirst();
return Optional.of(combineFunction);
}

private static List<Method> getOutputFunctions(Class<?> clazz, AccumulatorStateDetails stateDetails)
private static List<Method> getOutputFunctions(Class<?> clazz, List<AccumulatorStateDetails> stateDetails)
{
List<Method> outputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class);
for (Method outputFunction : outputFunctions) {
// verify parameter types
List<Class<?>> parameterTypes = getNonDependencyParameterTypes(outputFunction);
List<Class<?>> expectedParameterTypes = ImmutableList.<Class<?>>builder()
.add(stateDetails.getStateClass())
.addAll(stateDetails.stream().map(AccumulatorStateDetails::getStateClass).collect(toImmutableList()))
.add(BlockBuilder.class)
.build();
checkArgument(parameterTypes.equals(expectedParameterTypes),
"Expected output function non-dependency parameters to be %s: %s",
expectedParameterTypes.stream().map(Class::getSimpleName).collect(toImmutableList()),
outputFunction);

// legacy output functions did not require parameters to be fully annotated
if (stateDetails.size() > 1) {
List<List<Annotation>> parameterAnnotations = getNonDependencyParameterAnnotations(outputFunction);

List<AccumulatorStateDetails> actualStateDetails = new ArrayList<>();
for (int parameterIndex = 0; parameterIndex < stateDetails.size(); parameterIndex++) {
actualStateDetails.add(toAccumulatorStateDetails(parameterTypes.get(parameterIndex), parameterAnnotations.get(parameterIndex), outputFunction, true));
}
checkArgument(actualStateDetails.equals(stateDetails), "Expected output function to have state parameters %s, but has %s", stateDetails, actualStateDetails);
}
}
checkArgument(!outputFunctions.isEmpty(), "Aggregation has no output functions");
return outputFunctions;
}

private static List<Method> getInputFunctions(Class<?> clazz, AccumulatorStateDetails stateDetails)
private static List<Method> getInputFunctions(Class<?> clazz, List<AccumulatorStateDetails> stateDetails)
{
List<Method> inputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class);
for (Method inputFunction : inputFunctions) {
// verify state parameter is first non-dependency parameter
Class<?> actualStateType = getNonDependencyParameterTypes(inputFunction).get(0);
checkArgument(stateDetails.getStateClass().equals(actualStateType),
"Expected input function non-dependency parameters to begin with state type %s: %s",
stateDetails.getStateClass().getSimpleName(),
// verify state parameter types
List<Class<?>> parameterTypes = getNonDependencyParameterTypes(inputFunction)
.subList(0, stateDetails.size());
List<Class<?>> expectedParameterTypes = ImmutableList.<Class<?>>builder()
.addAll(stateDetails.stream().map(AccumulatorStateDetails::getStateClass).collect(toImmutableList()))
.build()
.subList(0, stateDetails.size());
checkArgument(parameterTypes.equals(expectedParameterTypes),
"Expected input function non-dependency parameters to begin with state types %s: %s",
expectedParameterTypes.stream().map(Class::getSimpleName).collect(toImmutableList()),
inputFunction);

// legacy input functions did not require parameters to be fully annotated
if (stateDetails.size() > 1) {
List<List<Annotation>> parameterAnnotations = getNonDependencyParameterAnnotations(inputFunction)
.subList(0, stateDetails.size());

List<AccumulatorStateDetails> actualStateDetails = new ArrayList<>();
for (int parameterIndex = 0; parameterIndex < stateDetails.size(); parameterIndex++) {
actualStateDetails.add(toAccumulatorStateDetails(parameterTypes.get(parameterIndex), parameterAnnotations.get(parameterIndex), inputFunction, false));
}
checkArgument(actualStateDetails.equals(stateDetails), "Expected input function to have state parameters %s, but has %s", stateDetails, actualStateDetails);
}
}

checkArgument(!inputFunctions.isEmpty(), "Aggregation has no input functions");
Expand All @@ -249,6 +296,14 @@ private static List<Class<?>> getNonDependencyParameterTypes(Method function)
.collect(toImmutableList());
}

/**
 * Returns the annotations declared on each parameter position produced by
 * {@code getNonDependencyParameters(function)}, in the order those positions are produced.
 * Each inner list is an immutable copy of that parameter's annotation array.
 */
private static List<List<Annotation>> getNonDependencyParameterAnnotations(Method function)
{
    Annotation[][] annotationsByPosition = function.getParameterAnnotations();
    ImmutableList.Builder<List<Annotation>> collected = ImmutableList.builder();
    // forEachOrdered keeps the encounter order of the selected parameter positions
    getNonDependencyParameters(function)
            .forEachOrdered(position -> collected.add(ImmutableList.copyOf(annotationsByPosition[position])));
    return collected.build();
}

private static Optional<Method> getRemoveInputFunction(Class<?> clazz, Method inputFunction)
{
// Only include methods which take the same parameters as the corresponding input function
Expand All @@ -258,29 +313,51 @@ private static Optional<Method> getRemoveInputFunction(Class<?> clazz, Method in
.collect(MoreCollectors.toOptional());
}

private static AccumulatorStateDetails getStateDetails(Class<?> clazz)
private static List<AccumulatorStateDetails> getStateDetails(Class<?> clazz)
{
ImmutableSet.Builder<AccumulatorStateDetails> builder = ImmutableSet.builder();
ImmutableSet.Builder<List<AccumulatorStateDetails>> builder = ImmutableSet.builder();
for (Method inputFunction : FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class)) {
checkArgument(inputFunction.getParameterTypes().length > 0, "Input function has no parameters");
int aggregationStateParamIndex = AggregationImplementation.Parser.findAggregationStateParamId(inputFunction);
Class<? extends AccumulatorState> stateClass = inputFunction.getParameterTypes()[aggregationStateParamIndex].asSubclass(AccumulatorState.class);

Optional<TypeSignature> stateType = Arrays.stream(inputFunction.getParameterAnnotations()[aggregationStateParamIndex])
.filter(AggregationState.class::isInstance)
.map(AggregationState.class::cast)
.findFirst()
.map(AggregationState::value)
.filter(type -> !type.isEmpty())
.map(TypeSignature::new);

builder.add(new AccumulatorStateDetails(stateClass, stateType));
List<Class<?>> parameterTypes = getNonDependencyParameterTypes(inputFunction);
checkArgument(!parameterTypes.isEmpty(), "Input function has no parameters");
List<List<Annotation>> parameterAnnotations = getNonDependencyParameterAnnotations(inputFunction);

ImmutableList.Builder<AccumulatorStateDetails> stateParameters = ImmutableList.builder();
for (int parameterIndex = 0; parameterIndex < parameterTypes.size(); parameterIndex++) {
Class<?> parameterType = parameterTypes.get(parameterIndex);
if (!AccumulatorState.class.isAssignableFrom(parameterType)) {
continue;
}

stateParameters.add(toAccumulatorStateDetails(parameterType, parameterAnnotations.get(parameterIndex), inputFunction, false));
}
List<AccumulatorStateDetails> states = stateParameters.build();
checkArgument(!states.isEmpty(), "Input function must have at least one state parameter");
builder.add(states);
}
Set<List<AccumulatorStateDetails>> functionStateClasses = builder.build();
checkArgument(!functionStateClasses.isEmpty(), "No input functions found");
checkArgument(functionStateClasses.size() == 1, "There must be exactly one set of @AccumulatorState in class %s", clazz.toGenericString());

return getOnlyElement(functionStateClasses);
}

private static AccumulatorStateDetails toAccumulatorStateDetails(Class<?> parameterType, List<Annotation> parameterAnnotations, Method method, boolean requireAnnotation)
{
Optional<AggregationState> state = parameterAnnotations.stream()
.filter(AggregationState.class::isInstance)
.map(AggregationState.class::cast)
.findFirst();

if (requireAnnotation) {
checkArgument(state.isPresent(), "AggregationState must be present on AccumulatorState parameters: %s", method);
}
Set<AccumulatorStateDetails> stateClasses = builder.build();
checkArgument(!stateClasses.isEmpty(), "No input functions found");
checkArgument(stateClasses.size() == 1, "There must be exactly one @AccumulatorState in class %s", clazz.toGenericString());

return getOnlyElement(stateClasses);
Optional<TypeSignature> stateSqlType = state.map(AggregationState::value)
.filter(type -> !type.isEmpty())
.map(TypeSignature::new);

AccumulatorStateDetails accumulatorStateDetails = new AccumulatorStateDetails(parameterType.asSubclass(AccumulatorState.class), stateSqlType);
return accumulatorStateDetails;
}

public static class AccumulatorStateDetails
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import java.util.StringJoiner;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.operator.ParametricFunctionHelpers.bindDependencies;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod;
import static io.trino.operator.aggregation.state.StateCompiler.generateInOutStateFactory;
Expand All @@ -62,18 +63,18 @@ public class ParametricAggregation
extends SqlAggregationFunction
{
private final ParametricImplementationsGroup<AggregationImplementation> implementations;
private final AccumulatorStateDetails stateDetails;
private final List<AccumulatorStateDetails> stateDetails;

public ParametricAggregation(
Signature signature,
AggregationHeader details,
AccumulatorStateDetails stateDetails,
List<AccumulatorStateDetails> stateDetails,
ParametricImplementationsGroup<AggregationImplementation> implementations)
{
super(
createFunctionMetadata(signature, details, implementations.getFunctionNullability()),
createAggregationFunctionMetadata(details, stateDetails));
this.stateDetails = requireNonNull(stateDetails, "stateDetails is null");
this.stateDetails = ImmutableList.copyOf(requireNonNull(stateDetails, "stateDetails is null"));
checkArgument(implementations.getFunctionNullability().isReturnNullable(), "currently aggregates are required to be nullable");
this.implementations = requireNonNull(implementations, "implementations is null");
}
Expand Down Expand Up @@ -106,14 +107,16 @@ private static FunctionMetadata createFunctionMetadata(Signature signature, Aggr
return functionMetadata.build();
}

private static AggregationFunctionMetadata createAggregationFunctionMetadata(AggregationHeader details, AccumulatorStateDetails stateDetails)
private static AggregationFunctionMetadata createAggregationFunctionMetadata(AggregationHeader details, List<AccumulatorStateDetails> stateDetails)
{
AggregationFunctionMetadataBuilder builder = AggregationFunctionMetadata.builder();
if (details.isOrderSensitive()) {
builder.orderSensitive();
}
if (details.isDecomposable()) {
builder.intermediateType(getSerializedType(stateDetails));
for (AccumulatorStateDetails stateDetail : stateDetails) {
builder.intermediateType(getSerializedType(stateDetail));
}
}
return builder.build();
}
Expand Down Expand Up @@ -150,7 +153,9 @@ public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDep
AggregationImplementation concreteImplementation = findMatchingImplementation(boundSignature);

// Build state factory and serializer
AccumulatorStateDescriptor<?> accumulatorStateDescriptor = generateAccumulatorStateDescriptor(getFunctionMetadata().getSignature(), boundSignature, stateDetails);
List<AccumulatorStateDescriptor<?>> accumulatorStateDescriptors = stateDetails.stream()
.map(state -> generateAccumulatorStateDescriptor(getFunctionMetadata().getSignature(), boundSignature, state))
.collect(toImmutableList());

// Bind provided dependencies to aggregation method handlers
FunctionMetadata metadata = getFunctionMetadata();
Expand Down Expand Up @@ -179,7 +184,7 @@ public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDep
removeInputHandle,
combineHandle,
outputHandle,
ImmutableList.of(accumulatorStateDescriptor));
accumulatorStateDescriptors);
}

private static AccumulatorStateDescriptor<?> generateAccumulatorStateDescriptor(Signature signature, BoundSignature boundSignature, AccumulatorStateDetails stateDetails)
Expand Down Expand Up @@ -226,9 +231,10 @@ private static <T extends AccumulatorState> AccumulatorStateDescriptor<T> genera
generateStateFactory(stateClass));
}

public Class<?> getStateClass()
@VisibleForTesting
public List<AccumulatorStateDetails> getStateDetails()
{
return stateDetails.getStateClass();
return stateDetails;
}

@VisibleForTesting
Expand Down
Loading

0 comments on commit b8d6287

Please sign in to comment.