Skip to content

Commit

Permalink
Require aggregation metadata at construction
Browse files Browse the repository at this point in the history
Require full aggregation function metadata during construction of aggregation.
Change intermediate types to allow generic type signatures.
  • Loading branch information
dain committed Oct 9, 2021
1 parent 8016601 commit aec7351
Show file tree
Hide file tree
Showing 26 changed files with 156 additions and 268 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,39 @@
*/
package io.trino.metadata;

import io.trino.spi.type.StandardTypes;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.TypeSignatureParameter;

import java.util.Arrays;
import java.util.Optional;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

public class AggregationFunctionMetadata
{
private final boolean orderSensitive;
private final Optional<TypeSignature> intermediateType;

public AggregationFunctionMetadata(boolean orderSensitive, TypeSignature... intermediateTypes)
{
this.orderSensitive = orderSensitive;

if (intermediateTypes.length == 0) {
intermediateType = Optional.empty();
}
else if (intermediateTypes.length == 1) {
intermediateType = Optional.of(intermediateTypes[0]);
}
else {
intermediateType = Optional.of(new TypeSignature(StandardTypes.ROW, Arrays.stream(intermediateTypes)
.map(TypeSignatureParameter::anonymousField)
.collect(toImmutableList())));
}
}

public AggregationFunctionMetadata(boolean orderSensitive, Optional<TypeSignature> intermediateType)
{
this.orderSensitive = orderSensitive;
Expand All @@ -36,6 +57,11 @@ public boolean isOrderSensitive()
return orderSensitive;
}

public boolean isDecomposable()
{
return intermediateType.isPresent();
}

public Optional<TypeSignature> getIntermediateType()
{
return intermediateType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -820,13 +820,13 @@ public FunctionMetadata get(FunctionId functionId)
return functions.get(functionId).getFunctionMetadata();
}

public AggregationFunctionMetadata getAggregationFunctionMetadata(FunctionBinding functionBinding)
public AggregationFunctionMetadata getAggregationFunctionMetadata(FunctionId functionId)
{
SqlFunction function = functions.get(functionBinding.getFunctionId());
checkArgument(function instanceof SqlAggregationFunction, "%s is not an aggregation function", functionBinding.getBoundSignature());
SqlFunction function = functions.get(functionId);
checkArgument(function instanceof SqlAggregationFunction, "%s is not an aggregation function", function.getFunctionMetadata().getSignature());

SqlAggregationFunction aggregationFunction = (SqlAggregationFunction) function;
return aggregationFunction.getAggregationMetadata(functionBinding);
return aggregationFunction.getAggregationMetadata();
}

public WindowFunctionSupplier getWindowFunctionImplementation(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2397,7 +2397,13 @@ public FunctionMetadata getFunctionMetadata(ResolvedFunction resolvedFunction)
@Override
public AggregationFunctionMetadata getAggregationFunctionMetadata(ResolvedFunction resolvedFunction)
{
return functions.getAggregationFunctionMetadata(toFunctionBinding(resolvedFunction));
AggregationFunctionMetadata aggregationFunctionMetadata = functions.getAggregationFunctionMetadata(resolvedFunction.getFunctionId());
return new AggregationFunctionMetadata(
aggregationFunctionMetadata.isOrderSensitive(),
aggregationFunctionMetadata.getIntermediateType().map(typeSignature -> {
FunctionBinding functionBinding = toFunctionBinding(resolvedFunction);
return applyBoundVariables(typeSignature, functionBinding);
}));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,26 @@
import com.google.common.collect.ImmutableList;
import io.trino.operator.aggregation.AggregationFromAnnotationsParser;
import io.trino.operator.aggregation.InternalAggregationFunction;
import io.trino.spi.type.StandardTypes;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.TypeSignatureParameter;

import java.util.List;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.util.Objects.requireNonNull;

public abstract class SqlAggregationFunction
implements SqlFunction
{
private final FunctionMetadata functionMetadata;
private final boolean orderSensitive;
private final boolean decomposable;
private final AggregationFunctionMetadata aggregationFunctionMetadata;

public static List<SqlAggregationFunction> createFunctionsByAnnotations(Class<?> aggregationDefinition)
{
return ImmutableList.copyOf(AggregationFromAnnotationsParser.parseFunctionDefinitions(aggregationDefinition));
}

protected SqlAggregationFunction(FunctionMetadata functionMetadata, boolean decomposable, boolean orderSensitive)
public SqlAggregationFunction(FunctionMetadata functionMetadata, AggregationFunctionMetadata aggregationFunctionMetadata)
{
this.functionMetadata = requireNonNull(functionMetadata, "functionMetadata is null");
checkArgument(functionMetadata.isDeterministic(), "Aggregation function must be deterministic");
this.orderSensitive = orderSensitive;
this.decomposable = decomposable;
this.aggregationFunctionMetadata = requireNonNull(aggregationFunctionMetadata, "aggregationFunctionMetadata is null");
}

@Override
Expand All @@ -54,28 +44,9 @@ public FunctionMetadata getFunctionMetadata()
return functionMetadata;
}

public AggregationFunctionMetadata getAggregationMetadata(FunctionBinding functionBinding)
public AggregationFunctionMetadata getAggregationMetadata()
{
if (!decomposable) {
return new AggregationFunctionMetadata(orderSensitive, Optional.empty());
}

List<TypeSignature> intermediateTypes = getIntermediateTypes(functionBinding);
TypeSignature intermediateType;
if (intermediateTypes.size() == 1) {
intermediateType = getOnlyElement(intermediateTypes);
}
else {
intermediateType = new TypeSignature(StandardTypes.ROW, intermediateTypes.stream()
.map(TypeSignatureParameter::anonymousField)
.collect(toImmutableList()));
}
return new AggregationFunctionMetadata(orderSensitive, Optional.of(intermediateType));
}

protected List<TypeSignature> getIntermediateTypes(FunctionBinding functionBinding)
{
throw new UnsupportedOperationException();
return aggregationFunctionMetadata;
}

public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.DynamicClassLoader;
import io.trino.annotation.UsedByGeneratedCode;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.FunctionArgumentDefinition;
import io.trino.metadata.FunctionBinding;
import io.trino.metadata.FunctionDependencies;
Expand All @@ -32,9 +33,6 @@
import io.trino.operator.aggregation.state.GenericDoubleStateSerializer;
import io.trino.operator.aggregation.state.GenericLongState;
import io.trino.operator.aggregation.state.GenericLongStateSerializer;
import io.trino.operator.aggregation.state.NullableBooleanState;
import io.trino.operator.aggregation.state.NullableDoubleState;
import io.trino.operator.aggregation.state.NullableLongState;
import io.trino.operator.aggregation.state.StateCompiler;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
Expand Down Expand Up @@ -104,8 +102,9 @@ protected AbstractMinMaxAggregationFunction(String name, boolean min, String des
true,
description,
AGGREGATE),
true,
false);
new AggregationFunctionMetadata(
false,
new TypeSignature("E")));
this.min = min;
}

Expand All @@ -115,23 +114,6 @@ public FunctionDependencyDeclaration getFunctionDependencies()
return getMinMaxCompareFunctionDependencies(new TypeSignature("E"), min);
}

@Override
public List<TypeSignature> getIntermediateTypes(FunctionBinding functionBinding)
{
Type type = functionBinding.getTypeVariable("E");
if (type.getJavaType() == long.class) {
return ImmutableList.of(StateCompiler.getSerializedType(NullableLongState.class).getTypeSignature());
}
if (type.getJavaType() == double.class) {
return ImmutableList.of(StateCompiler.getSerializedType(NullableDoubleState.class).getTypeSignature());
}
if (type.getJavaType() == boolean.class) {
return ImmutableList.of(StateCompiler.getSerializedType(NullableBooleanState.class).getTypeSignature());
}
// native container type is Slice or Block
return ImmutableList.of(new BlockPositionStateSerializer(type).getSerializedType().getTypeSignature());
}

@Override
public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.DynamicClassLoader;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.FunctionArgumentDefinition;
import io.trino.metadata.FunctionBinding;
import io.trino.metadata.FunctionDependencies;
Expand All @@ -30,7 +31,6 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.util.MinMaxCompare;
Expand Down Expand Up @@ -87,18 +87,13 @@ protected AbstractMinMaxNAggregationFunction(String name, boolean min, String de
true,
description,
AGGREGATE),
true,
false);
new AggregationFunctionMetadata(
false,
BIGINT.getTypeSignature(),
TypeSignature.arrayType(new TypeSignature("E"))));
this.min = min;
}

@Override
public List<TypeSignature> getIntermediateTypes(FunctionBinding functionBinding)
{
Type type = functionBinding.getTypeVariable("E");
return ImmutableList.of(RowType.anonymous(ImmutableList.of(BIGINT, type)).getTypeSignature());
}

@Override
public FunctionDependencyDeclaration getFunctionDependencies(FunctionBinding functionBinding)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.DynamicClassLoader;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.FunctionArgumentDefinition;
import io.trino.metadata.FunctionBinding;
import io.trino.metadata.FunctionMetadata;
Expand All @@ -29,9 +30,6 @@
import io.trino.operator.aggregation.state.GenericDoubleStateSerializer;
import io.trino.operator.aggregation.state.GenericLongState;
import io.trino.operator.aggregation.state.GenericLongStateSerializer;
import io.trino.operator.aggregation.state.NullableBooleanState;
import io.trino.operator.aggregation.state.NullableDoubleState;
import io.trino.operator.aggregation.state.NullableLongState;
import io.trino.operator.aggregation.state.StateCompiler;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
Expand Down Expand Up @@ -92,25 +90,9 @@ protected ArbitraryAggregationFunction()
true,
"Return an arbitrary non-null input value",
AGGREGATE),
true,
false);
}

@Override
public List<TypeSignature> getIntermediateTypes(FunctionBinding functionBinding)
{
Type type = functionBinding.getTypeVariable("T");
if (type.getJavaType() == long.class) {
return ImmutableList.of(StateCompiler.getSerializedType(NullableLongState.class).getTypeSignature());
}
if (type.getJavaType() == double.class) {
return ImmutableList.of(StateCompiler.getSerializedType(NullableDoubleState.class).getTypeSignature());
}
if (type.getJavaType() == boolean.class) {
return ImmutableList.of(StateCompiler.getSerializedType(NullableBooleanState.class).getTypeSignature());
}
// native container type is Slice or Block
return ImmutableList.of(new BlockPositionStateSerializer(type).getSerializedType().getTypeSignature());
new AggregationFunctionMetadata(
false,
new TypeSignature("T")));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.DynamicClassLoader;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.FunctionArgumentDefinition;
import io.trino.metadata.FunctionBinding;
import io.trino.metadata.FunctionMetadata;
Expand Down Expand Up @@ -44,6 +45,7 @@
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.NULLABLE_BLOCK_INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static io.trino.operator.aggregation.AggregationUtils.generateAggregationName;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.util.Reflection.methodHandle;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -77,17 +79,12 @@ public ChecksumAggregationFunction(BlockTypeOperators blockTypeOperators)
true,
"Checksum of the given values",
AGGREGATE),
true,
false);
new AggregationFunctionMetadata(
false,
BIGINT.getTypeSignature()));
this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null");
}

@Override
public List<TypeSignature> getIntermediateTypes(FunctionBinding functionBinding)
{
return ImmutableList.of(StateCompiler.getSerializedType(NullableLongState.class).getTypeSignature());
}

@Override
public InternalAggregationFunction specialize(FunctionBinding functionBinding)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.DynamicClassLoader;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.FunctionArgumentDefinition;
import io.trino.metadata.FunctionBinding;
import io.trino.metadata.FunctionMetadata;
Expand Down Expand Up @@ -72,14 +73,9 @@ public CountColumn()
true,
"Counts the non-null values",
AGGREGATE),
true,
false);
}

@Override
public List<TypeSignature> getIntermediateTypes(FunctionBinding functionBinding)
{
return ImmutableList.of(StateCompiler.getSerializedType(LongState.class).getTypeSignature());
new AggregationFunctionMetadata(
false,
BIGINT.getTypeSignature()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.DynamicClassLoader;
import io.airlift.slice.Slice;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.FunctionArgumentDefinition;
import io.trino.metadata.FunctionBinding;
import io.trino.metadata.FunctionMetadata;
Expand Down Expand Up @@ -57,6 +58,7 @@
import static io.trino.spi.type.Decimals.writeShortDecimal;
import static io.trino.spi.type.TypeSignatureParameter.typeVariable;
import static io.trino.spi.type.UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.util.Reflection.methodHandle;
import static java.math.BigDecimal.ROUND_HALF_UP;

Expand Down Expand Up @@ -95,14 +97,9 @@ public DecimalAverageAggregation()
true,
"Calculates the average value",
AGGREGATE),
true,
false);
}

@Override
public List<TypeSignature> getIntermediateTypes(FunctionBinding functionBinding)
{
return ImmutableList.of(new LongDecimalWithOverflowAndLongStateSerializer().getSerializedType().getTypeSignature());
new AggregationFunctionMetadata(
false,
VARBINARY.getTypeSignature()));
}

@Override
Expand Down
Loading

0 comments on commit aec7351

Please sign in to comment.