Skip to content

Commit

Permalink
Convert map_agg and map_union aggregations to annotated functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 27, 2022
1 parent 79ef11c commit 9d85654
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 211 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,8 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.aggregates(ListaggAggregationFunction.class)
.functions(new MapSubscriptOperator())
.functions(MAP_CONSTRUCTOR, JSON_TO_MAP, JSON_STRING_TO_MAP)
.functions(new MapAggregationFunction(blockTypeOperators), new MapUnionAggregation(blockTypeOperators))
.aggregates(MapAggregationFunction.class)
.aggregates(MapUnionAggregation.class)
.function(REDUCE_AGG)
.function(new MultimapAggregationFunction(blockTypeOperators))
.functions(DECIMAL_TO_VARCHAR_CAST, DECIMAL_TO_INTEGER_CAST, DECIMAL_TO_BIGINT_CAST, DECIMAL_TO_DOUBLE_CAST, DECIMAL_TO_REAL_CAST, DECIMAL_TO_BOOLEAN_CAST, DECIMAL_TO_TINYINT_CAST, DECIMAL_TO_SMALLINT_CAST)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,106 +13,56 @@
*/
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionMetadata;
import io.trino.metadata.Signature;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.trino.operator.aggregation.state.KeyValuePairStateSerializer;
import io.trino.operator.aggregation.state.KeyValuePairsState;
import io.trino.operator.aggregation.state.KeyValuePairsStateFactory;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.MapType;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.BlockIndex;
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.Convention;
import io.trino.spi.function.Description;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.OperatorDependency;
import io.trino.spi.function.OperatorType;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.type.BlockTypeOperators;
import io.trino.type.BlockTypeOperators.BlockPositionEqual;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.util.Optional;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;

import static io.trino.spi.type.TypeSignature.mapType;
import static io.trino.util.Reflection.methodHandle;
import static java.util.Objects.requireNonNull;

public class MapAggregationFunction
extends SqlAggregationFunction
@AggregationFunction(value = "map_agg", isOrderSensitive = true)
@Description("Aggregates all the rows (key/value pairs) into a single map")
public final class MapAggregationFunction
{
public static final String NAME = "map_agg";
private static final MethodHandle INPUT_FUNCTION = methodHandle(
MapAggregationFunction.class,
"input",
Type.class,
BlockPositionEqual.class,
BlockPositionHashCode.class,
Type.class,
KeyValuePairsState.class,
Block.class,
Block.class,
int.class);
private static final MethodHandle COMBINE_FUNCTION = methodHandle(MapAggregationFunction.class, "combine", KeyValuePairsState.class, KeyValuePairsState.class);
private static final MethodHandle OUTPUT_FUNCTION = methodHandle(MapAggregationFunction.class, "output", KeyValuePairsState.class, BlockBuilder.class);

private final BlockTypeOperators blockTypeOperators;

public MapAggregationFunction(BlockTypeOperators blockTypeOperators)
{
super(
FunctionMetadata.aggregateBuilder()
.signature(Signature.builder()
.name(NAME)
.comparableTypeParameter("K")
.typeVariable("V")
.returnType(mapType(new TypeSignature("K"), new TypeSignature("V")))
.argumentType(new TypeSignature("K"))
.argumentType(new TypeSignature("V"))
.build())
.argumentNullability(false, true)
.description("Aggregates all the rows (key/value pairs) into a single map")
.build(),
AggregationFunctionMetadata.builder()
.orderSensitive()
.intermediateType(mapType(new TypeSignature("K"), new TypeSignature("V")))
.build());
this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null");
}

@Override
public AggregationMetadata specialize(BoundSignature boundSignature)
{
MapType outputType = (MapType) boundSignature.getReturnType();
Type keyType = outputType.getKeyType();
BlockPositionEqual keyEqual = blockTypeOperators.getEqualOperator(keyType);
BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(keyType);

Type valueType = outputType.getValueType();
KeyValuePairStateSerializer stateSerializer = new KeyValuePairStateSerializer(outputType, keyEqual, keyHashCode);

return new AggregationMetadata(
MethodHandles.insertArguments(INPUT_FUNCTION, 0, keyType, keyEqual, keyHashCode, valueType),
Optional.empty(),
Optional.of(COMBINE_FUNCTION),
OUTPUT_FUNCTION,
ImmutableList.of(new AccumulatorStateDescriptor<>(
KeyValuePairsState.class,
stateSerializer,
new KeyValuePairsStateFactory(keyType, valueType))));
}
private MapAggregationFunction() {}

@InputFunction
@TypeParameter("K")
@TypeParameter("V")
public static void input(
Type keyType,
BlockPositionEqual keyEqual,
BlockPositionHashCode keyHashCode,
Type valueType,
KeyValuePairsState state,
Block key,
Block value,
int position)
@TypeParameter("K") Type keyType,
@OperatorDependency(
operator = OperatorType.EQUAL,
argumentTypes = {"K", "K"},
convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN))
BlockPositionEqual keyEqual,
@OperatorDependency(
operator = OperatorType.HASH_CODE,
argumentTypes = "K",
convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL))
BlockPositionHashCode keyHashCode,
@TypeParameter("V") Type valueType,
@AggregationState({"K", "V"}) KeyValuePairsState state,
@BlockPosition @SqlType("K") Block key,
@NullablePosition @BlockPosition @SqlType("V") Block value,
@BlockIndex int position)
{
KeyValuePairs pairs = state.get();
if (pairs == null) {
Expand All @@ -125,7 +75,10 @@ public static void input(
state.addMemoryUsage(pairs.estimatedInMemorySize() - startSize);
}

public static void combine(KeyValuePairsState state, KeyValuePairsState otherState)
@CombineFunction
public static void combine(
@AggregationState({"K", "V"}) KeyValuePairsState state,
@AggregationState({"K", "V"}) KeyValuePairsState otherState)
{
if (state.get() != null && otherState.get() != null) {
Block keys = otherState.get().getKeys();
Expand All @@ -142,7 +95,8 @@ else if (state.get() == null) {
}
}

public static void output(KeyValuePairsState state, BlockBuilder out)
@OutputFunction("map(K, V)")
public static void output(@AggregationState({"K", "V"}) KeyValuePairsState state, BlockBuilder out)
{
KeyValuePairs pairs = state.get();
if (pairs == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,105 +13,52 @@
*/
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionMetadata;
import io.trino.metadata.Signature;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.trino.operator.aggregation.state.KeyValuePairStateSerializer;
import io.trino.operator.aggregation.state.KeyValuePairsState;
import io.trino.operator.aggregation.state.KeyValuePairsStateFactory;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.MapType;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.Convention;
import io.trino.spi.function.Description;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.OperatorDependency;
import io.trino.spi.function.OperatorType;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.type.BlockTypeOperators;
import io.trino.type.BlockTypeOperators.BlockPositionEqual;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.util.Optional;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;

import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod;
import static io.trino.spi.type.TypeSignature.mapType;
import static io.trino.util.Reflection.methodHandle;
import static java.util.Objects.requireNonNull;

public class MapUnionAggregation
extends SqlAggregationFunction
@AggregationFunction("map_union")
@Description("Aggregate all the maps into a single map")
public final class MapUnionAggregation
{
public static final String NAME = "map_union";
private static final MethodHandle OUTPUT_FUNCTION = methodHandle(MapUnionAggregation.class, "output", KeyValuePairsState.class, BlockBuilder.class);
private static final MethodHandle INPUT_FUNCTION = methodHandle(MapUnionAggregation.class,
"input",
Type.class,
BlockPositionEqual.class,
BlockPositionHashCode.class,
Type.class,
KeyValuePairsState.class,
Block.class);
private static final MethodHandle COMBINE_FUNCTION = methodHandle(MapUnionAggregation.class, "combine", KeyValuePairsState.class, KeyValuePairsState.class);

private final BlockTypeOperators blockTypeOperators;

public MapUnionAggregation(BlockTypeOperators blockTypeOperators)
{
super(
FunctionMetadata.aggregateBuilder()
.signature(Signature.builder()
.name(NAME)
.comparableTypeParameter("K")
.typeVariable("V")
.returnType(mapType(new TypeSignature("K"), new TypeSignature("V")))
.argumentType(mapType(new TypeSignature("K"), new TypeSignature("V")))
.build())
.description("Aggregate all the maps into a single map")
.build(),
AggregationFunctionMetadata.builder()
.intermediateType(mapType(new TypeSignature("K"), new TypeSignature("V")))
.build());
this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null");
}

@Override
public AggregationMetadata specialize(BoundSignature boundSignature)
{
MapType outputType = (MapType) boundSignature.getReturnType();
Type keyType = outputType.getKeyType();
BlockPositionEqual keyEqual = blockTypeOperators.getEqualOperator(keyType);
BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(keyType);

Type valueType = outputType.getValueType();

KeyValuePairStateSerializer stateSerializer = new KeyValuePairStateSerializer(outputType, keyEqual, keyHashCode);

MethodHandle inputFunction = MethodHandles.insertArguments(INPUT_FUNCTION, 0, keyType, keyEqual, keyHashCode, valueType);
inputFunction = normalizeInputMethod(inputFunction, boundSignature, STATE, INPUT_CHANNEL);

return new AggregationMetadata(
inputFunction,
Optional.empty(),
Optional.of(COMBINE_FUNCTION),
OUTPUT_FUNCTION,
ImmutableList.of(new AccumulatorStateDescriptor<>(
KeyValuePairsState.class,
stateSerializer,
new KeyValuePairsStateFactory(keyType, valueType))));
}
private MapUnionAggregation() {}

@InputFunction
@TypeParameter("K")
@TypeParameter("V")
public static void input(
Type keyType,
BlockPositionEqual keyEqual,
BlockPositionHashCode keyHashCode,
Type valueType,
KeyValuePairsState state,
Block value)
@TypeParameter("K") Type keyType,
@OperatorDependency(
operator = OperatorType.EQUAL,
argumentTypes = {"K", "K"},
convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN))
BlockPositionEqual keyEqual,
@OperatorDependency(
operator = OperatorType.HASH_CODE,
argumentTypes = "K",
convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL))
BlockPositionHashCode keyHashCode,
@TypeParameter("V") Type valueType,
@AggregationState({"K", "V"}) KeyValuePairsState state,
@SqlType("map(K,V)") Block value)
{
KeyValuePairs pairs = state.get();
if (pairs == null) {
Expand All @@ -126,12 +73,16 @@ public static void input(
state.addMemoryUsage(pairs.estimatedInMemorySize() - startSize);
}

public static void combine(KeyValuePairsState state, KeyValuePairsState otherState)
@CombineFunction
public static void combine(
@AggregationState({"K", "V"}) KeyValuePairsState state,
@AggregationState({"K", "V"}) KeyValuePairsState otherState)
{
MapAggregationFunction.combine(state, otherState);
}

public static void output(KeyValuePairsState state, BlockBuilder out)
@OutputFunction("map(K, V)")
public static void output(@AggregationState({"K", "V"}) KeyValuePairsState state, BlockBuilder out)
{
MapAggregationFunction.output(state, out);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,42 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.MapType;
import io.trino.spi.function.Convention;
import io.trino.spi.function.OperatorDependency;
import io.trino.spi.function.OperatorType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.Type;
import io.trino.type.BlockTypeOperators.BlockPositionEqual;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;

import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static java.util.Objects.requireNonNull;

public class KeyValuePairStateSerializer
implements AccumulatorStateSerializer<KeyValuePairsState>
{
private final MapType mapType;
private final Type mapType;
private final BlockPositionEqual keyEqual;
private final BlockPositionHashCode keyHashCode;

public KeyValuePairStateSerializer(MapType mapType, BlockPositionEqual keyEqual, BlockPositionHashCode keyHashCode)
public KeyValuePairStateSerializer(
@TypeParameter("MAP(K, V)") Type mapType,
@OperatorDependency(
operator = OperatorType.EQUAL,
argumentTypes = {"K", "K"},
convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN))
BlockPositionEqual keyEqual,
@OperatorDependency(
operator = OperatorType.HASH_CODE,
argumentTypes = "K",
convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL))
BlockPositionHashCode keyHashCode)
{
this.mapType = mapType;
this.keyEqual = keyEqual;
this.keyHashCode = keyHashCode;
this.mapType = requireNonNull(mapType, "mapType is null");
this.keyEqual = requireNonNull(keyEqual, "keyEqual is null");
this.keyHashCode = requireNonNull(keyHashCode, "keyHashCode is null");
}

@Override
Expand All @@ -56,6 +75,6 @@ public void serialize(KeyValuePairsState state, BlockBuilder out)
@Override
public void deserialize(Block block, int index, KeyValuePairsState state)
{
state.set(new KeyValuePairs(mapType.getObject(block, index), state.getKeyType(), keyEqual, keyHashCode, state.getValueType()));
state.set(new KeyValuePairs((Block) mapType.getObject(block, index), state.getKeyType(), keyEqual, keyHashCode, state.getValueType()));
}
}
Loading

0 comments on commit 9d85654

Please sign in to comment.