Skip to content

Commit

Permalink
Add support for generic in-out state to annotated aggregation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 27, 2022
1 parent b343bba commit d358c57
Show file tree
Hide file tree
Showing 5 changed files with 526 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.FunctionDependency;
import io.trino.spi.function.InputFunction;
Expand All @@ -39,7 +40,9 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
Expand All @@ -65,11 +68,11 @@ public static List<ParametricAggregation> parseFunctionDefinitions(Class<?> aggr
ImmutableList.Builder<ParametricAggregation> functions = ImmutableList.builder();

// There must be a single state class and combine function
Class<? extends AccumulatorState> stateClass = getStateClass(aggregationDefinition);
Optional<Method> combineFunction = getCombineFunction(aggregationDefinition, stateClass);
AccumulatorStateDetails stateDetails = getStateDetails(aggregationDefinition);
Optional<Method> combineFunction = getCombineFunction(aggregationDefinition, stateDetails);

// Each output function defines a new aggregation function
for (Method outputFunction : getOutputFunctions(aggregationDefinition, stateClass)) {
for (Method outputFunction : getOutputFunctions(aggregationDefinition, stateDetails)) {
AggregationHeader header = parseHeader(aggregationDefinition, outputFunction);
if (header.isDecomposable()) {
checkArgument(combineFunction.isPresent(), "Decomposable method %s does not have a combine method", header.getName());
Expand All @@ -81,7 +84,7 @@ else if (combineFunction.isPresent()) {
// Input functions can have either an exact signature, or generic/calculate signature
List<AggregationImplementation> exactImplementations = new ArrayList<>();
List<AggregationImplementation> nonExactImplementations = new ArrayList<>();
for (Method inputFunction : getInputFunctions(aggregationDefinition, stateClass)) {
for (Method inputFunction : getInputFunctions(aggregationDefinition, stateDetails)) {
Optional<Method> removeInputFunction = getRemoveInputFunction(aggregationDefinition, inputFunction);
AggregationImplementation implementation = parseImplementation(
aggregationDefinition,
Expand All @@ -99,9 +102,9 @@ else if (combineFunction.isPresent()) {
}

// register a set of functions for the canonical name, and for each alias
functions.addAll(buildFunctions(header.getName(), header, stateClass, exactImplementations, nonExactImplementations));
functions.addAll(buildFunctions(header.getName(), header, stateDetails, exactImplementations, nonExactImplementations));
for (String alias : getAliases(aggregationDefinition.getAnnotation(AggregationFunction.class), outputFunction)) {
functions.addAll(buildFunctions(alias, header, stateClass, exactImplementations, nonExactImplementations));
functions.addAll(buildFunctions(alias, header, stateDetails, exactImplementations, nonExactImplementations));
}
}

Expand All @@ -111,7 +114,7 @@ else if (combineFunction.isPresent()) {
private static List<ParametricAggregation> buildFunctions(
String name,
AggregationHeader header,
Class<? extends AccumulatorState> stateClass,
AccumulatorStateDetails stateDetails,
List<AggregationImplementation> exactImplementations,
List<AggregationImplementation> nonExactImplementations)
{
Expand All @@ -122,7 +125,7 @@ private static List<ParametricAggregation> buildFunctions(
functions.add(new ParametricAggregation(
exactImplementation.getSignature().withName(name),
header,
stateClass,
stateDetails,
ParametricImplementationsGroup.of(exactImplementation).withAlias(name)));
}

Expand All @@ -134,7 +137,7 @@ private static List<ParametricAggregation> buildFunctions(
functions.add(new ParametricAggregation(
implementations.getSignature().withName(name),
header,
stateClass,
stateDetails,
implementations.withAlias(name)));
}

Expand Down Expand Up @@ -180,27 +183,27 @@ private static List<String> getAliases(AggregationFunction aggregationAnnotation
return ImmutableList.copyOf(aggregationAnnotation.alias());
}

private static Optional<Method> getCombineFunction(Class<?> clazz, Class<?> stateClass)
private static Optional<Method> getCombineFunction(Class<?> clazz, AccumulatorStateDetails stateDetails)
{
List<Method> combineFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class);
for (Method combineFunction : combineFunctions) {
// verify parameter types
List<Class<?>> parameterTypes = getNonDependencyParameterTypes(combineFunction);
List<Class<?>> expectedParameterTypes = nCopies(2, stateClass);
List<Class<?>> expectedParameterTypes = nCopies(2, stateDetails.getStateClass());
checkArgument(parameterTypes.equals(expectedParameterTypes), "Expected combine function non-dependency parameters to be %s: %s", expectedParameterTypes, combineFunction);
}
checkArgument(combineFunctions.size() <= 1, "There must be only one @CombineFunction in class %s for the @AggregationState %s", clazz.toGenericString(), stateClass.toGenericString());
checkArgument(combineFunctions.size() <= 1, "There must be only one @CombineFunction in class %s for the @AggregationState %s", clazz.toGenericString(), stateDetails.getStateClass().toGenericString());
return combineFunctions.stream().findFirst();
}

private static List<Method> getOutputFunctions(Class<?> clazz, Class<?> stateClass)
private static List<Method> getOutputFunctions(Class<?> clazz, AccumulatorStateDetails stateDetails)
{
List<Method> outputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class);
for (Method outputFunction : outputFunctions) {
// verify parameter types
List<Class<?>> parameterTypes = getNonDependencyParameterTypes(outputFunction);
List<Class<?>> expectedParameterTypes = ImmutableList.<Class<?>>builder()
.add(stateClass)
.add(stateDetails.getStateClass())
.add(BlockBuilder.class)
.build();
checkArgument(parameterTypes.equals(expectedParameterTypes),
Expand All @@ -212,15 +215,15 @@ private static List<Method> getOutputFunctions(Class<?> clazz, Class<?> stateCla
return outputFunctions;
}

private static List<Method> getInputFunctions(Class<?> clazz, Class<?> stateClass)
private static List<Method> getInputFunctions(Class<?> clazz, AccumulatorStateDetails stateDetails)
{
List<Method> inputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class);
for (Method inputFunction : inputFunctions) {
// verify state parameter is first non-dependency parameter
Class<?> actualStateType = getNonDependencyParameterTypes(inputFunction).get(0);
checkArgument(stateClass.equals(actualStateType),
checkArgument(stateDetails.getStateClass().equals(actualStateType),
"Expected input function non-dependency parameters to begin with state type %s: %s",
stateClass.getSimpleName(),
stateDetails.getStateClass().getSimpleName(),
inputFunction);
}

Expand Down Expand Up @@ -255,20 +258,69 @@ private static Optional<Method> getRemoveInputFunction(Class<?> clazz, Method in
.collect(MoreCollectors.toOptional());
}

private static Class<? extends AccumulatorState> getStateClass(Class<?> clazz)
private static AccumulatorStateDetails getStateDetails(Class<?> clazz)
{
ImmutableSet.Builder<Class<? extends AccumulatorState>> builder = ImmutableSet.builder();
ImmutableSet.Builder<AccumulatorStateDetails> builder = ImmutableSet.builder();
for (Method inputFunction : FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class)) {
checkArgument(inputFunction.getParameterTypes().length > 0, "Input function has no parameters");
Class<?> stateClass = AggregationImplementation.Parser.findAggregationStateParamType(inputFunction);
int aggregationStateParamIndex = AggregationImplementation.Parser.findAggregationStateParamId(inputFunction);
Class<? extends AccumulatorState> stateClass = inputFunction.getParameterTypes()[aggregationStateParamIndex].asSubclass(AccumulatorState.class);

checkArgument(AccumulatorState.class.isAssignableFrom(stateClass), "stateClass is not a subclass of AccumulatorState");
builder.add(stateClass.asSubclass(AccumulatorState.class));
Optional<TypeSignature> stateType = Arrays.stream(inputFunction.getParameterAnnotations()[aggregationStateParamIndex])
.filter(AggregationState.class::isInstance)
.map(AggregationState.class::cast)
.findFirst()
.map(AggregationState::value)
.filter(type -> !type.isEmpty())
.map(TypeSignature::new);

builder.add(new AccumulatorStateDetails(stateClass, stateType));
}
ImmutableSet<Class<? extends AccumulatorState>> stateClasses = builder.build();
Set<AccumulatorStateDetails> stateClasses = builder.build();
checkArgument(!stateClasses.isEmpty(), "No input functions found");
checkArgument(stateClasses.size() == 1, "There must be exactly one @AccumulatorState in class %s", clazz.toGenericString());

return getOnlyElement(stateClasses);
}

/**
 * Immutable description of the accumulator state used by an annotated aggregation:
 * the Java state class plus an optional type signature attached to the state.
 * <p>
 * Instances are value objects: two instances are equal when both the state class
 * and the optional type signature are equal, which allows them to be de-duplicated
 * in a {@code Set}.
 */
public static class AccumulatorStateDetails
{
    private final Class<? extends AccumulatorState> stateClass;
    private final Optional<TypeSignature> type;

    /**
     * @param stateClass the accumulator state class; must not be null
     * @param type the type signature declared for the state, or empty when none was declared; must not be null
     * @throws NullPointerException if either argument is null
     */
    public AccumulatorStateDetails(Class<? extends AccumulatorState> stateClass, Optional<TypeSignature> type)
    {
        this.stateClass = requireNonNull(stateClass, "stateClass is null");
        this.type = requireNonNull(type, "type is null");
    }

    /**
     * @return the accumulator state class, never null
     */
    public Class<? extends AccumulatorState> getStateClass()
    {
        return stateClass;
    }

    /**
     * @return the type signature declared for the state, or empty when none was declared
     */
    public Optional<TypeSignature> getStateType()
    {
        return type;
    }

    @Override
    public boolean equals(Object o)
    {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        AccumulatorStateDetails that = (AccumulatorStateDetails) o;
        return Objects.equals(stateClass, that.stateClass) && Objects.equals(type, that.type);
    }

    @Override
    public int hashCode()
    {
        return Objects.hash(stateClass, type);
    }

    /**
     * Human-readable form for logs and debugging; the exact format is not a stable contract.
     */
    @Override
    public String toString()
    {
        return "AccumulatorStateDetails{stateClass=" + stateClass.getName() + ", type=" + type + "}";
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -501,11 +501,6 @@ public List<TypeSignature> getInputTypesSignatures(Method inputFunction)
return builder.build();
}

public static Class<?> findAggregationStateParamType(Method inputFunction)
{
return inputFunction.getParameterTypes()[findAggregationStateParamId(inputFunction)];
}

public static int findAggregationStateParamId(Method method)
{
return findAggregationStateParamId(method, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,17 @@
import io.trino.metadata.SignatureBinder;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.ParametricImplementationsGroup;
import io.trino.operator.aggregation.AggregationFromAnnotationsParser.AccumulatorStateDetails;
import io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind;
import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.trino.operator.aggregation.state.InOutStateSerializer;
import io.trino.operator.annotations.ImplementationDependency;
import io.trino.spi.TrinoException;
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AccumulatorStateFactory;
import io.trino.spi.function.InOut;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;

import java.lang.invoke.MethodHandle;
import java.util.Collection;
Expand All @@ -43,6 +49,7 @@
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.operator.ParametricFunctionHelpers.bindDependencies;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod;
import static io.trino.operator.aggregation.state.StateCompiler.generateInOutStateFactory;
import static io.trino.operator.aggregation.state.StateCompiler.generateStateFactory;
import static io.trino.operator.aggregation.state.StateCompiler.generateStateSerializer;
import static io.trino.operator.aggregation.state.StateCompiler.getSerializedType;
Expand All @@ -55,18 +62,18 @@ public class ParametricAggregation
extends SqlAggregationFunction
{
private final ParametricImplementationsGroup<AggregationImplementation> implementations;
private final Class<? extends AccumulatorState> stateClass;
private final AccumulatorStateDetails stateDetails;

public ParametricAggregation(
Signature signature,
AggregationHeader details,
Class<? extends AccumulatorState> stateClass,
AccumulatorStateDetails stateDetails,
ParametricImplementationsGroup<AggregationImplementation> implementations)
{
super(
createFunctionMetadata(signature, details, implementations.getFunctionNullability()),
createAggregationFunctionMetadata(details, stateClass));
this.stateClass = requireNonNull(stateClass, "stateClass is null");
createAggregationFunctionMetadata(details, stateDetails));
this.stateDetails = requireNonNull(stateDetails, "stateDetails is null");
checkArgument(implementations.getFunctionNullability().isReturnNullable(), "currently aggregates are required to be nullable");
this.implementations = requireNonNull(implementations, "implementations is null");
}
Expand Down Expand Up @@ -99,14 +106,14 @@ private static FunctionMetadata createFunctionMetadata(Signature signature, Aggr
return functionMetadata.build();
}

private static AggregationFunctionMetadata createAggregationFunctionMetadata(AggregationHeader details, Class<? extends AccumulatorState> stateClass)
private static AggregationFunctionMetadata createAggregationFunctionMetadata(AggregationHeader details, AccumulatorStateDetails stateDetails)
{
AggregationFunctionMetadataBuilder builder = AggregationFunctionMetadata.builder();
if (details.isOrderSensitive()) {
builder.orderSensitive();
}
if (details.isDecomposable()) {
builder.intermediateType(getSerializedType(stateClass).getTypeSignature());
builder.intermediateType(getSerializedType(stateDetails));
}
return builder.build();
}
Expand Down Expand Up @@ -143,7 +150,7 @@ public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDep
AggregationImplementation concreteImplementation = findMatchingImplementation(boundSignature);

// Build state factory and serializer
AccumulatorStateDescriptor<?> accumulatorStateDescriptor = generateAccumulatorStateDescriptor(stateClass);
AccumulatorStateDescriptor<?> accumulatorStateDescriptor = generateAccumulatorStateDescriptor(getFunctionMetadata().getSignature(), boundSignature, stateDetails);

// Bind provided dependencies to aggregation method handlers
FunctionMetadata metadata = getFunctionMetadata();
Expand Down Expand Up @@ -175,6 +182,42 @@ public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDep
ImmutableList.of(accumulatorStateDescriptor));
}

/**
 * Builds the accumulator state descriptor for the given state details.
 * <p>
 * A state whose class is exactly {@link InOut} is handled by
 * {@code createInOutAccumulatorStateDescriptor}, which needs the declared and bound
 * signatures; every other state class goes through the single-argument overload.
 */
private static AccumulatorStateDescriptor<?> generateAccumulatorStateDescriptor(Signature signature, BoundSignature boundSignature, AccumulatorStateDetails stateDetails)
{
    Class<? extends AccumulatorState> declaredStateClass = stateDetails.getStateClass();
    if (!declaredStateClass.equals(InOut.class)) {
        // ordinary state class: the signatures are not needed
        return generateAccumulatorStateDescriptor(declaredStateClass);
    }
    return createInOutAccumulatorStateDescriptor(signature, boundSignature, stateDetails);
}

/**
 * Creates the descriptor for an {@link InOut} state.
 * <p>
 * The concrete type is resolved by {@code extractInOutType}; both the serializer and
 * the state factory are then built for that type.
 */
private static AccumulatorStateDescriptor<InOut> createInOutAccumulatorStateDescriptor(Signature signature, BoundSignature boundSignature, AccumulatorStateDetails stateDetails)
{
    Type resolvedType = extractInOutType(signature, boundSignature, stateDetails);
    // arguments are evaluated left to right: serializer first, then the generated factory
    return new AccumulatorStateDescriptor<>(
            InOut.class,
            new InOutStateSerializer(resolvedType),
            generateInOutStateFactory(resolvedType));
}

/**
 * Resolves the concrete {@link Type} for the state's declared type signature.
 * <p>
 * The declared return type is checked first; if it equals the state's type signature,
 * the bound return type is used. Otherwise the declared argument types are scanned in
 * order and the bound argument at the first matching position is returned.
 *
 * @throws java.util.NoSuchElementException if the state details carry no type signature
 * @throws IllegalArgumentException if neither the return type nor any argument type matches
 */
private static Type extractInOutType(Signature signature, BoundSignature boundSignature, AccumulatorStateDetails stateDetails)
{
    TypeSignature stateTypeSignature = stateDetails.getStateType().orElseThrow();

    boolean matchesReturnType = signature.getReturnType().equals(stateTypeSignature);
    if (matchesReturnType) {
        return boundSignature.getReturnType();
    }

    List<TypeSignature> declaredArguments = signature.getArgumentTypes();
    List<Type> boundArguments = boundSignature.getArgumentTypes();
    int position = 0;
    for (TypeSignature declaredArgument : declaredArguments) {
        if (declaredArgument.equals(stateTypeSignature)) {
            // declared and bound argument lists are index-aligned
            return boundArguments.get(position);
        }
        position++;
    }

    throw new IllegalArgumentException(format("Could not determine type %s from function signature %s", stateTypeSignature, signature));
}

private static <T extends AccumulatorState> AccumulatorStateDescriptor<T> generateAccumulatorStateDescriptor(Class<T> stateClass)
{
return new AccumulatorStateDescriptor<>(
Expand All @@ -185,7 +228,7 @@ private static <T extends AccumulatorState> AccumulatorStateDescriptor<T> genera

public Class<?> getStateClass()
{
return stateClass;
return stateDetails.getStateClass();
}

@VisibleForTesting
Expand Down
Loading

0 comments on commit d358c57

Please sign in to comment.