From d358c57cf63b94616924e27b4c415ec90e5e0106 Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Fri, 24 Dec 2021 15:20:55 -0800 Subject: [PATCH] Add support for generic in-out state to annotated aggregation functions --- .../AggregationFromAnnotationsParser.java | 98 +++-- .../AggregationImplementation.java | 5 - .../aggregation/ParametricAggregation.java | 59 ++- .../state/InOutStateSerializer.java | 51 +++ .../aggregation/state/StateCompiler.java | 352 +++++++++++++++++- 5 files changed, 526 insertions(+), 39 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/operator/aggregation/state/InOutStateSerializer.java diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java index 389a00f32752..21c9e2c54518 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java @@ -23,6 +23,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationState; import io.trino.spi.function.CombineFunction; import io.trino.spi.function.FunctionDependency; import io.trino.spi.function.InputFunction; @@ -39,7 +40,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.stream.IntStream; import static com.google.common.base.Preconditions.checkArgument; @@ -65,11 +68,11 @@ public static List parseFunctionDefinitions(Class aggr ImmutableList.Builder functions = ImmutableList.builder(); // There must be a single state class and combine function - Class stateClass = getStateClass(aggregationDefinition); - Optional combineFunction = getCombineFunction(aggregationDefinition, stateClass); + AccumulatorStateDetails stateDetails = getStateDetails(aggregationDefinition); + Optional combineFunction = getCombineFunction(aggregationDefinition, stateDetails); // Each output function defines a new aggregation function - for (Method outputFunction : getOutputFunctions(aggregationDefinition, stateClass)) { + for (Method outputFunction : getOutputFunctions(aggregationDefinition, stateDetails)) { AggregationHeader header = parseHeader(aggregationDefinition, outputFunction); if (header.isDecomposable()) { checkArgument(combineFunction.isPresent(), "Decomposable method %s does not have a combine method", header.getName()); @@ -81,7 +84,7 @@ else if (combineFunction.isPresent()) { // Input functions can have either an exact signature, or generic/calculate signature List exactImplementations = new ArrayList<>(); List nonExactImplementations = new ArrayList<>(); - for (Method inputFunction : getInputFunctions(aggregationDefinition, stateClass)) { + for (Method inputFunction : getInputFunctions(aggregationDefinition, stateDetails)) { Optional removeInputFunction = getRemoveInputFunction(aggregationDefinition, inputFunction); AggregationImplementation implementation = parseImplementation( aggregationDefinition, @@ -99,9 +102,9 @@ else if (combineFunction.isPresent()) { } // register a set functions for the canonical name, and each alias - functions.addAll(buildFunctions(header.getName(), header, stateClass, exactImplementations, nonExactImplementations)); + functions.addAll(buildFunctions(header.getName(), header, stateDetails, exactImplementations, nonExactImplementations)); for (String alias : getAliases(aggregationDefinition.getAnnotation(AggregationFunction.class), outputFunction)) { - functions.addAll(buildFunctions(alias, header, stateClass, exactImplementations, nonExactImplementations)); + functions.addAll(buildFunctions(alias, header, stateDetails, exactImplementations, nonExactImplementations)); } } @@ -111,7 +114,7 @@ else if (combineFunction.isPresent()) { private static List buildFunctions( String name, AggregationHeader header, - Class stateClass, + AccumulatorStateDetails stateDetails, List exactImplementations, List nonExactImplementations) { @@ -122,7 +125,7 @@ private static List buildFunctions( functions.add(new ParametricAggregation( exactImplementation.getSignature().withName(name), header, - stateClass, + stateDetails, ParametricImplementationsGroup.of(exactImplementation).withAlias(name))); } @@ -134,7 +137,7 @@ private static List buildFunctions( functions.add(new ParametricAggregation( implementations.getSignature().withName(name), header, - stateClass, + stateDetails, implementations.withAlias(name))); } @@ -180,27 +183,27 @@ private static List getAliases(AggregationFunction aggregationAnnotation return ImmutableList.copyOf(aggregationAnnotation.alias()); } - private static Optional getCombineFunction(Class clazz, Class stateClass) + private static Optional getCombineFunction(Class clazz, AccumulatorStateDetails stateDetails) { List combineFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class); for (Method combineFunction : combineFunctions) { // verify parameter types List> parameterTypes = getNonDependencyParameterTypes(combineFunction); - List> expectedParameterTypes = nCopies(2, stateClass); + List> expectedParameterTypes = nCopies(2, stateDetails.getStateClass()); checkArgument(parameterTypes.equals(expectedParameterTypes), "Expected combine function non-dependency parameters to be %s: %s", expectedParameterTypes, combineFunction); } - checkArgument(combineFunctions.size() <= 1, "There must be only one @CombineFunction in class %s for the @AggregationState %s", clazz.toGenericString(), stateClass.toGenericString()); + checkArgument(combineFunctions.size() <= 1, "There must be only one @CombineFunction in class %s for the @AggregationState %s", clazz.toGenericString(), stateDetails.getStateClass().toGenericString()); return combineFunctions.stream().findFirst(); } - private static List getOutputFunctions(Class clazz, Class stateClass) + private static List getOutputFunctions(Class clazz, AccumulatorStateDetails stateDetails) { List outputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class); for (Method outputFunction : outputFunctions) { // verify parameter types List> parameterTypes = getNonDependencyParameterTypes(outputFunction); List> expectedParameterTypes = ImmutableList.>builder() - .add(stateClass) + .add(stateDetails.getStateClass()) .add(BlockBuilder.class) .build(); checkArgument(parameterTypes.equals(expectedParameterTypes), @@ -212,15 +215,15 @@ private static List getOutputFunctions(Class clazz, Class stateCla return outputFunctions; } - private static List getInputFunctions(Class clazz, Class stateClass) + private static List getInputFunctions(Class clazz, AccumulatorStateDetails stateDetails) { List inputFunctions = FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class); for (Method inputFunction : inputFunctions) { // verify state parameter is first non-dependency parameter Class actualStateType = getNonDependencyParameterTypes(inputFunction).get(0); - checkArgument(stateClass.equals(actualStateType), + checkArgument(stateDetails.getStateClass().equals(actualStateType), "Expected input function non-dependency parameters to begin with state type %s: %s", - stateClass.getSimpleName(), + stateDetails.getStateClass().getSimpleName(), inputFunction); } @@ -255,20 +258,69 @@ private static Optional getRemoveInputFunction(Class clazz, Method in .collect(MoreCollectors.toOptional()); } - private static Class getStateClass(Class clazz) + private static AccumulatorStateDetails getStateDetails(Class clazz) { - ImmutableSet.Builder> builder = ImmutableSet.builder(); + ImmutableSet.Builder builder = ImmutableSet.builder(); for (Method inputFunction : FunctionsParserHelper.findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class)) { checkArgument(inputFunction.getParameterTypes().length > 0, "Input function has no parameters"); - Class stateClass = AggregationImplementation.Parser.findAggregationStateParamType(inputFunction); + int aggregationStateParamIndex = AggregationImplementation.Parser.findAggregationStateParamId(inputFunction); + Class stateClass = inputFunction.getParameterTypes()[aggregationStateParamIndex].asSubclass(AccumulatorState.class); - checkArgument(AccumulatorState.class.isAssignableFrom(stateClass), "stateClass is not a subclass of AccumulatorState"); - builder.add(stateClass.asSubclass(AccumulatorState.class)); + Optional stateType = Arrays.stream(inputFunction.getParameterAnnotations()[aggregationStateParamIndex]) + .filter(AggregationState.class::isInstance) + .map(AggregationState.class::cast) + .findFirst() + .map(AggregationState::value) + .filter(type -> !type.isEmpty()) + .map(TypeSignature::new); + + builder.add(new AccumulatorStateDetails(stateClass, stateType)); } - ImmutableSet> stateClasses = builder.build(); + Set stateClasses = builder.build(); checkArgument(!stateClasses.isEmpty(), "No input functions found"); checkArgument(stateClasses.size() == 1, "There must be exactly one @AccumulatorState in class %s", clazz.toGenericString()); return getOnlyElement(stateClasses); } + + public static class AccumulatorStateDetails + { + private final Class stateClass; + private final Optional type; + + public AccumulatorStateDetails(Class stateClass, Optional type) + { + this.stateClass = requireNonNull(stateClass, "stateClass is null"); + this.type = requireNonNull(type, "type is null"); + } + + public Class getStateClass() + { + return stateClass; + } + + public Optional getStateType() + { + return type; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + AccumulatorStateDetails that = (AccumulatorStateDetails) o; + return Objects.equals(stateClass, that.stateClass) && Objects.equals(type, that.type); + } + + @Override + public int hashCode() + { + return Objects.hash(stateClass, type); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java index 9350b50e31f8..dc86987ed293 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java @@ -501,11 +501,6 @@ public List getInputTypesSignatures(Method inputFunction) return builder.build(); } - public static Class findAggregationStateParamType(Method inputFunction) - { - return inputFunction.getParameterTypes()[findAggregationStateParamId(inputFunction)]; - } - public static int findAggregationStateParamId(Method method) { return findAggregationStateParamId(method, 0); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java index 031441440220..3609065d6e97 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java @@ -28,11 +28,17 @@ import io.trino.metadata.SignatureBinder; import io.trino.metadata.SqlAggregationFunction; import io.trino.operator.ParametricImplementationsGroup; +import io.trino.operator.aggregation.AggregationFromAnnotationsParser.AccumulatorStateDetails; import io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind; import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; +import io.trino.operator.aggregation.state.InOutStateSerializer; import io.trino.operator.annotations.ImplementationDependency; import io.trino.spi.TrinoException; import io.trino.spi.function.AccumulatorState; +import io.trino.spi.function.AccumulatorStateFactory; +import io.trino.spi.function.InOut; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; import java.util.Collection; @@ -43,6 +49,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.trino.operator.ParametricFunctionHelpers.bindDependencies; import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; +import static io.trino.operator.aggregation.state.StateCompiler.generateInOutStateFactory; import static io.trino.operator.aggregation.state.StateCompiler.generateStateFactory; import static io.trino.operator.aggregation.state.StateCompiler.generateStateSerializer; import static io.trino.operator.aggregation.state.StateCompiler.getSerializedType; @@ -55,18 +62,18 @@ public class ParametricAggregation extends SqlAggregationFunction { private final ParametricImplementationsGroup implementations; - private final Class stateClass; + private final AccumulatorStateDetails stateDetails; public ParametricAggregation( Signature signature, AggregationHeader details, - Class stateClass, + AccumulatorStateDetails stateDetails, ParametricImplementationsGroup implementations) { super( createFunctionMetadata(signature, details, implementations.getFunctionNullability()), - createAggregationFunctionMetadata(details, stateClass)); - this.stateClass = requireNonNull(stateClass, "stateClass is null"); + createAggregationFunctionMetadata(details, stateDetails)); + this.stateDetails = requireNonNull(stateDetails, "stateDetails is null"); checkArgument(implementations.getFunctionNullability().isReturnNullable(), "currently aggregates are required to be nullable"); this.implementations = requireNonNull(implementations, "implementations is null"); } @@ -99,14 +106,14 @@ private static FunctionMetadata createFunctionMetadata(Signature signature, Aggr return functionMetadata.build(); } - private static AggregationFunctionMetadata createAggregationFunctionMetadata(AggregationHeader details, Class stateClass) + private static AggregationFunctionMetadata createAggregationFunctionMetadata(AggregationHeader details, AccumulatorStateDetails stateDetails) { AggregationFunctionMetadataBuilder builder = AggregationFunctionMetadata.builder(); if (details.isOrderSensitive()) { builder.orderSensitive(); } if (details.isDecomposable()) { - builder.intermediateType(getSerializedType(stateClass).getTypeSignature()); + builder.intermediateType(getSerializedType(stateDetails)); } return builder.build(); } @@ -143,7 +150,7 @@ public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDep AggregationImplementation concreteImplementation = findMatchingImplementation(boundSignature); // Build state factory and serializer - AccumulatorStateDescriptor accumulatorStateDescriptor = generateAccumulatorStateDescriptor(stateClass); + AccumulatorStateDescriptor accumulatorStateDescriptor = generateAccumulatorStateDescriptor(getFunctionMetadata().getSignature(), boundSignature, stateDetails); // Bind provided dependencies to aggregation method handlers FunctionMetadata metadata = getFunctionMetadata(); @@ -175,6 +182,42 @@ public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDep ImmutableList.of(accumulatorStateDescriptor)); } + private static AccumulatorStateDescriptor generateAccumulatorStateDescriptor(Signature signature, BoundSignature boundSignature, AccumulatorStateDetails stateDetails) + { + if (stateDetails.getStateClass().equals(InOut.class)) { + return createInOutAccumulatorStateDescriptor(signature, boundSignature, stateDetails); + } + return generateAccumulatorStateDescriptor(stateDetails.getStateClass()); + } + + private static AccumulatorStateDescriptor createInOutAccumulatorStateDescriptor(Signature signature, BoundSignature boundSignature, AccumulatorStateDetails stateDetails) + { + Type type = extractInOutType(signature, boundSignature, stateDetails); + InOutStateSerializer inOutStateSerializer = new InOutStateSerializer(type); + AccumulatorStateFactory inOutAccumulatorStateFactory = generateInOutStateFactory(type); + return new AccumulatorStateDescriptor<>( + InOut.class, + inOutStateSerializer, + inOutAccumulatorStateFactory); + } + + private static Type extractInOutType(Signature signature, BoundSignature boundSignature, AccumulatorStateDetails stateDetails) + { + TypeSignature inOutType = stateDetails.getStateType().orElseThrow(); + if (signature.getReturnType().equals(inOutType)) { + return boundSignature.getReturnType(); + } + List declaredArgumentTypes = signature.getArgumentTypes(); + List actualArgumentTypes = boundSignature.getArgumentTypes(); + for (int i = 0; i < declaredArgumentTypes.size(); i++) { + TypeSignature argumentType = declaredArgumentTypes.get(i); + if (argumentType.equals(inOutType)) { + return actualArgumentTypes.get(i); + } + } + throw new IllegalArgumentException(format("Could not determine type %s from function signature %s", inOutType, signature)); + } + private static AccumulatorStateDescriptor generateAccumulatorStateDescriptor(Class stateClass) { return new AccumulatorStateDescriptor<>( @@ -185,7 +228,7 @@ private static AccumulatorStateDescriptor genera public Class getStateClass() { - return stateClass; + return stateDetails.getStateClass(); } @VisibleForTesting diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/InOutStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/InOutStateSerializer.java new file mode 100644 index 000000000000..43acbcae860e --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/InOutStateSerializer.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation.state; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.function.InOut; +import io.trino.spi.type.Type; + +import static java.util.Objects.requireNonNull; + +public final class InOutStateSerializer + implements AccumulatorStateSerializer +{ + private final Type type; + + public InOutStateSerializer(Type type) + { + this.type = requireNonNull(type, "type is null"); + } + + @Override + public Type getSerializedType() + { + return type; + } + + @Override + public void serialize(InOut state, BlockBuilder out) + { + state.get(out); + } + + @Override + public void deserialize(Block block, int index, InOut state) + { + state.set(block, index); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java index b21d1a84db71..ce1b6e6856b3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/StateCompiler.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Ordering; import io.airlift.bytecode.BytecodeBlock; +import io.airlift.bytecode.BytecodeNode; import io.airlift.bytecode.ClassDefinition; import io.airlift.bytecode.DynamicClassLoader; import io.airlift.bytecode.FieldDefinition; @@ -36,14 +37,19 @@ import io.trino.array.LongBigArray; import io.trino.array.ObjectBigArray; import io.trino.array.SliceBigArray; +import io.trino.operator.aggregation.AggregationFromAnnotationsParser.AccumulatorStateDetails; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateFactory; import io.trino.spi.function.AccumulatorStateMetadata; import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.function.GroupedAccumulatorState; +import io.trino.spi.function.InOut; +import io.trino.spi.function.InternalDataAccessor; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; +import io.trino.spi.type.TypeSignature; import io.trino.sql.gen.CallSiteBinder; import io.trino.sql.gen.SqlTypeBytecodeExpression; import org.openjdk.jol.info.ClassLayout; @@ -59,6 +65,8 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; import static com.google.common.base.CaseFormat.LOWER_CAMEL; import static com.google.common.base.CaseFormat.UPPER_CAMEL; @@ -75,10 +83,13 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.add; import static io.airlift.bytecode.expression.BytecodeExpressions.constantBoolean; import static io.airlift.bytecode.expression.BytecodeExpressions.constantClass; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantLong; import static io.airlift.bytecode.expression.BytecodeExpressions.constantNull; import static io.airlift.bytecode.expression.BytecodeExpressions.constantNumber; import static io.airlift.bytecode.expression.BytecodeExpressions.constantString; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; import static io.airlift.bytecode.expression.BytecodeExpressions.defaultValue; import static io.airlift.bytecode.expression.BytecodeExpressions.equal; import static io.airlift.bytecode.expression.BytecodeExpressions.getStatic; @@ -127,20 +138,25 @@ private static Class getBigArrayType(Class type) return ObjectBigArray.class; } - public static Type getSerializedType(Class clazz) + public static TypeSignature getSerializedType(AccumulatorStateDetails stateDetails) { + if (stateDetails.getStateType().isPresent()) { + return stateDetails.getStateType().get(); + } + + Class clazz = stateDetails.getStateClass(); AccumulatorStateMetadata metadata = getMetadataAnnotation(clazz); if (metadata != null && metadata.stateSerializerClass() != AccumulatorStateSerializer.class) { try { AccumulatorStateSerializer stateSerializer = (AccumulatorStateSerializer) metadata.stateSerializerClass().getConstructor().newInstance(); - return stateSerializer.getSerializedType(); + return stateSerializer.getSerializedType().getTypeSignature(); } catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { throw new RuntimeException(e); } } - return getSerializedType(enumerateFields(clazz, ImmutableMap.of())); + return getSerializedType(enumerateFields(clazz, ImmutableMap.of())).getTypeSignature(); } public static AccumulatorStateSerializer generateStateSerializer(Class clazz) @@ -350,6 +366,327 @@ private static Method getGetter(Class clazz, StateField field) } } + public static AccumulatorStateFactory generateInOutStateFactory(Type type) + { + CallSiteBinder callSiteBinder = new CallSiteBinder(); + ClassDefinition singleStateClassDefinition = generateInOutSingleStateClass(type, callSiteBinder); + ClassDefinition groupedStateClassDefinition = generateInOutGroupedStateClass(type, callSiteBinder); + + DynamicClassLoader classLoader = new DynamicClassLoader(StateCompiler.class.getClassLoader(), callSiteBinder.getBindings()); + Class singleStateClass = defineClass(singleStateClassDefinition, InOut.class, classLoader); + Class groupedStateClass = defineClass(groupedStateClassDefinition, InOut.class, classLoader); + + return generateStateFactory(InOut.class, singleStateClass, groupedStateClass, classLoader); + } + + private static ClassDefinition generateInOutSingleStateClass(Type type, CallSiteBinder callSiteBinder) + { + ClassDefinition definition = new ClassDefinition( + a(PUBLIC, FINAL), + makeClassName("SingleInOut"), + type(Object.class), + type(InOut.class), + type(InternalDataAccessor.class)); + + estimatedSize(definition); + + // Generate constructor + MethodDefinition constructor = definition.declareConstructor(a(PUBLIC)); + + constructor.getBody() + .append(constructor.getThis()) + .invokeConstructor(Object.class); + + // Generate fields + FieldDefinition valueField = definition.declareField(a(PRIVATE), "value", inOutGetterReturnType(type)); + Function valueGetter = scope -> scope.getThis().getField(valueField); + + Optional nullField; + Function nullGetter; + if (type.getJavaType().isPrimitive()) { + nullField = Optional.of(definition.declareField(a(PRIVATE), "valueIdNull", boolean.class)); + constructor.getBody().append(constructor.getThis().setField(nullField.get(), constantTrue())); + nullGetter = scope -> scope.getThis().getField(nullField.get()); + } + else { + nullField = Optional.empty(); + nullGetter = scope -> isNull(valueGetter.apply(scope)); + } + + constructor.getBody() + .ret(); + + inOutSingleCopy(definition, valueField, nullField); + + Function setNullGenerator = scope -> { + Variable thisVariable = scope.getThis(); + BytecodeBlock bytecodeBlock = new BytecodeBlock(); + nullField.ifPresent(field -> bytecodeBlock.append(thisVariable.setField(field, constantTrue()))); + bytecodeBlock.append(thisVariable.setField(valueField, defaultValue(valueField.getType()))); + return bytecodeBlock; + }; + + BiFunction setValueGenerator = (scope, value) -> { + Variable thisVariable = scope.getThis(); + BytecodeBlock bytecodeBlock = new BytecodeBlock(); + nullField.ifPresent(field -> bytecodeBlock.append(thisVariable.setField(field, constantFalse()))); + bytecodeBlock.append(thisVariable.setField(valueField, value)); + return bytecodeBlock; + }; + + generateInOutMethods(type, definition, valueGetter, nullGetter, setNullGenerator, setValueGenerator, callSiteBinder); + return definition; + } + + private static ClassDefinition generateInOutGroupedStateClass(Type type, CallSiteBinder callSiteBinder) + { + ClassDefinition definition = new ClassDefinition( + a(PUBLIC, FINAL), + makeClassName("GroupedInOut"), // todo add type + type(Object.class), + type(InOut.class), + type(GroupedAccumulatorState.class), + type(InternalDataAccessor.class)); + + estimatedSize(definition); + + MethodDefinition constructor = definition.declareConstructor(a(PUBLIC)); + constructor.getBody() + .append(constructor.getThis()) + .invokeConstructor(Object.class); + + FieldDefinition groupIdField = definition.declareField(a(PRIVATE), "groupId", long.class); + + Class valueElementType = inOutGetterReturnType(type); + FieldDefinition valueField = definition.declareField(a(PRIVATE, FINAL), "value", getBigArrayType(valueElementType)); + constructor.getBody().append(constructor.getThis().setField(valueField, newInstance(valueField.getType()))); + Function valueGetter = scope -> scope.getThis().getField(valueField).invoke("get", valueElementType, scope.getThis().getField(groupIdField)); + + Optional nullField; + Function nullGetter; + if (type.getJavaType().isPrimitive()) { + nullField = Optional.of(definition.declareField(a(PRIVATE, FINAL), "valueIdNull", BooleanBigArray.class)); + constructor.getBody().append(constructor.getThis().setField(nullField.get(), newInstance(BooleanBigArray.class, constantTrue()))); + nullGetter = scope -> scope.getThis().getField(nullField.get()).invoke("get", boolean.class, scope.getThis().getField(groupIdField)); + } + else { + nullField = Optional.empty(); + nullGetter = scope -> isNull(valueGetter.apply(scope)); + } + + constructor.getBody() + .ret(); + + inOutGroupedSetGroupId(definition, groupIdField); + inOutGroupedEnsureCapacity(definition, valueField, nullField); + inOutGroupedCopy(definition, valueField, nullField); + + Function setNullGenerator = scope -> { + Variable thisVariable = scope.getThis(); + BytecodeBlock bytecodeBlock = new BytecodeBlock(); + nullField.ifPresent(field -> bytecodeBlock.append(thisVariable.getField(field).invoke("set", void.class, thisVariable.getField(groupIdField), constantTrue()))); + bytecodeBlock.append(thisVariable.getField(valueField).invoke("set", void.class, thisVariable.getField(groupIdField), defaultValue(valueElementType))); + return bytecodeBlock; + }; + BiFunction setValueGenerator = (scope, value) -> { + Variable thisVariable = scope.getThis(); + BytecodeBlock bytecodeBlock = new BytecodeBlock(); + nullField.ifPresent(field -> bytecodeBlock.append(thisVariable.getField(field).invoke("set", void.class, thisVariable.getField(groupIdField), constantFalse()))); + bytecodeBlock.append(thisVariable.getField(valueField).invoke("set", void.class, thisVariable.getField(groupIdField), value.cast(valueElementType))); + return bytecodeBlock; + }; + + generateInOutMethods(type, definition, valueGetter, nullGetter, setNullGenerator, setValueGenerator, callSiteBinder); + + return definition; + } + + private static void generateInOutMethods(Type type, + ClassDefinition definition, + Function valueGetter, + Function nullGetter, + Function setNullGenerator, + BiFunction setValueGenerator, + CallSiteBinder callSiteBinder) + { + SqlTypeBytecodeExpression sqlType = constantType(callSiteBinder, type); + + generateInOutGetType(definition, sqlType); + generateInOutIsNull(definition, nullGetter); + generateInOutGetBlockBuilder(definition, sqlType, valueGetter); + generateInOutSetBlockPosition(definition, sqlType, setNullGenerator, setValueGenerator); + generateInOutSetInOut(definition, type, setNullGenerator, setValueGenerator); + generateInOutGetValue(definition, type, valueGetter); + } + + private static void estimatedSize(ClassDefinition definition) + { + FieldDefinition instanceSize = generateInstanceSize(definition); + + // Add getter for class size + definition.declareMethod(a(PUBLIC), "getEstimatedSize", type(long.class)) + .getBody() + .getStaticField(instanceSize) + .retLong(); + } + + private static void inOutSingleCopy(ClassDefinition definition, FieldDefinition valueField, Optional nullField) + { + MethodDefinition copy = definition.declareMethod(a(PUBLIC), "copy", type(AccumulatorState.class)); + Variable thisVariable = copy.getThis(); + BytecodeBlock body = copy.getBody(); + + Variable copyVariable = copy.getScope().declareVariable(definition.getType(), "copy"); + body.append(copyVariable.set(newInstance(definition.getType()))); + body.append(copyVariable.setField(valueField, thisVariable.getField(valueField))); + nullField.ifPresent(field -> body.append(copyVariable.setField(field, thisVariable.getField(field)))); + body.append(copyVariable.ret()); + } + + private static void inOutGroupedSetGroupId(ClassDefinition definition, FieldDefinition groupIdField) + { + Parameter groupIdArg = arg("groupId", long.class); + MethodDefinition method = definition.declareMethod(a(PUBLIC), "setGroupId", type(void.class), groupIdArg); + method.getBody() + .append(method.getThis().setField(groupIdField, groupIdArg)) + .ret(); + } + + private static void inOutGroupedEnsureCapacity(ClassDefinition definition, FieldDefinition valueField, Optional nullField) + { + Parameter size = arg("size", long.class); + MethodDefinition method = definition.declareMethod(a(PUBLIC), "ensureCapacity", type(void.class), size); + Variable thisVariable = method.getThis(); + BytecodeBlock body = method.getBody(); + + body.append(thisVariable.getField(valueField).invoke("ensureCapacity", void.class, size)); + nullField.ifPresent(field -> body.append(thisVariable.getField(field).invoke("ensureCapacity", void.class, size))); + body.ret(); + } + + private static void inOutGroupedCopy(ClassDefinition definition, FieldDefinition valueField, Optional nullField) + { + MethodDefinition copy = definition.declareMethod(a(PUBLIC), "copy", type(AccumulatorState.class)); + Variable thisVariable = copy.getThis(); + BytecodeBlock body = copy.getBody(); + + Variable copyVariable = copy.getScope().declareVariable(definition.getType(), "copy"); + body.append(copyVariable.set(newInstance(definition.getType()))); + copyBigArray(body, thisVariable, copyVariable, valueField); + nullField.ifPresent(field -> copyBigArray(body, thisVariable, copyVariable, field)); + body.append(copyVariable.ret()); + } + + private static void copyBigArray(BytecodeBlock body, Variable source, Variable destination, FieldDefinition bigArrayField) + { + body.append(destination.getField(bigArrayField).invoke("ensureCapacity", void.class, source.getField(bigArrayField).invoke("getCapacity", long.class))); + body.append(source.getField(bigArrayField).invoke( + "copyTo", + void.class, + constantLong(0), + destination.getField(bigArrayField), + constantLong(0), + source.getField(bigArrayField).invoke("getCapacity", long.class))); + } + + private static void generateInOutGetType(ClassDefinition definition, SqlTypeBytecodeExpression sqlType) + { + definition.declareMethod(a(PUBLIC), "getType", type(Type.class)) + .getBody() + .append(sqlType.ret()); + } + + private static void generateInOutIsNull(ClassDefinition definition, Function nullGetter) + { + MethodDefinition isNullMethod = definition.declareMethod(a(PUBLIC), "isNull", type(boolean.class)); + isNullMethod.getBody().append(nullGetter.apply(isNullMethod.getScope()).ret()); + } + + private static void generateInOutGetBlockBuilder(ClassDefinition definition, SqlTypeBytecodeExpression sqlType, Function valueGetter) + { + Parameter blockBuilderArg = arg("blockBuilder", BlockBuilder.class); + MethodDefinition getBlockBuilderMethod = definition.declareMethod(a(PUBLIC), "get", type(void.class), blockBuilderArg); + Variable thisVariable = getBlockBuilderMethod.getThis(); + BytecodeBlock body = getBlockBuilderMethod.getBody(); + + body.append(new IfStatement() + .condition(thisVariable.invoke("isNull", boolean.class)) + .ifTrue(blockBuilderArg.invoke("appendNull", BlockBuilder.class).pop()) + .ifFalse(sqlType.writeValue(blockBuilderArg, valueGetter.apply(getBlockBuilderMethod.getScope())))); + body.ret(); + } + + private static void generateInOutSetBlockPosition( + ClassDefinition definition, + SqlTypeBytecodeExpression sqlType, + Function setNullGenerator, + BiFunction setValueGenerator) + { + Parameter blockArg = arg("block", Block.class); + Parameter positionArg = arg("position", int.class); + MethodDefinition setBlockBuilderMethod = definition.declareMethod(a(PUBLIC), "set", type(void.class), blockArg, positionArg); + BytecodeBlock body = setBlockBuilderMethod.getBody(); + + body.append(new IfStatement() + .condition(blockArg.invoke("isNull", boolean.class, positionArg)) + .ifTrue(setNullGenerator.apply(setBlockBuilderMethod.getScope())) + .ifFalse(setValueGenerator.apply(setBlockBuilderMethod.getScope(), sqlType.getValue(blockArg, positionArg)))); + body.ret(); + } + + private static void generateInOutSetInOut( + ClassDefinition definition, + Type type, + Function setNullGenerator, + BiFunction setValueGenerator) + { + Parameter otherState = arg("otherState", InOut.class); + MethodDefinition setter = definition.declareMethod(a(PUBLIC), "set", type(void.class), otherState); + BytecodeBlock body = setter.getBody(); + + body.append(new IfStatement() + .condition(otherState.invoke("isNull", boolean.class)) + .ifTrue(setNullGenerator.apply(setter.getScope())) + .ifFalse(setValueGenerator.apply(setter.getScope(), otherState.cast(InternalDataAccessor.class).invoke(inOutGetterName(type), inOutGetterReturnType(type))))); + body.ret(); + } + + private static void generateInOutGetValue(ClassDefinition definition, Type type, Function valueGetter) + { + MethodDefinition getter = definition.declareMethod(a(PUBLIC), inOutGetterName(type), type(inOutGetterReturnType(type))); + getter.getBody().append(valueGetter.apply(getter.getScope()).ret()); + } + + private static Class inOutGetterReturnType(Type type) + { + Class javaType = type.getJavaType(); + if (javaType.equals(boolean.class)) { + return boolean.class; + } + if (javaType.equals(long.class)) { + return long.class; + } + if (javaType.equals(double.class)) { + return double.class; + } + return Object.class; + } + + private static String inOutGetterName(Type type) + { + Class javaType = type.getJavaType(); + if (javaType.equals(boolean.class)) { + return "getBooleanValue"; + } + if (javaType.equals(long.class)) { + return "getLongValue"; + } + if (javaType.equals(double.class)) { + return "getDoubleValue"; + } + return "getObjectValue"; + } + public static AccumulatorStateFactory generateStateFactory(Class clazz) { return generateStateFactory(clazz, ImmutableMap.of()); @@ -374,6 +711,15 @@ static AccumulatorStateFactory generateStateFact Class singleStateClass = generateSingleStateClass(clazz, fieldTypes, classLoader); Class groupedStateClass = generateGroupedStateClass(clazz, fieldTypes, classLoader); + return generateStateFactory(clazz, singleStateClass, groupedStateClass, classLoader); + } + + private static AccumulatorStateFactory generateStateFactory( + Class clazz, + Class singleStateClass, + Class groupedStateClass, + DynamicClassLoader classLoader) + { ClassDefinition definition = new ClassDefinition( a(PUBLIC, FINAL), makeClassName(clazz.getSimpleName() + "Factory"),