Skip to content

Commit

Permalink
Convert listagg aggregation to annotated function
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 27, 2022
1 parent 5f90789 commit fbf075e
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import io.trino.operator.aggregation.VarcharApproximateMostFrequent;
import io.trino.operator.aggregation.VarianceAggregation;
import io.trino.operator.aggregation.histogram.Histogram;
import io.trino.operator.aggregation.listagg.ListaggAggregationFunction;
import io.trino.operator.aggregation.minmaxby.MaxByNAggregationFunction;
import io.trino.operator.aggregation.minmaxby.MinByNAggregationFunction;
import io.trino.operator.aggregation.multimapagg.MultimapAggregationFunction;
Expand Down Expand Up @@ -269,7 +270,6 @@
import static io.trino.operator.aggregation.RealAverageAggregation.REAL_AVERAGE_AGGREGATION;
import static io.trino.operator.aggregation.ReduceAggregationFunction.REDUCE_AGG;
import static io.trino.operator.aggregation.arrayagg.ArrayAggregationFunction.ARRAY_AGG;
import static io.trino.operator.aggregation.listagg.ListaggAggregationFunction.LISTAGG;
import static io.trino.operator.aggregation.minmaxby.MaxByAggregationFunction.MAX_BY;
import static io.trino.operator.aggregation.minmaxby.MinByAggregationFunction.MIN_BY;
import static io.trino.operator.scalar.ArrayConcatFunction.ARRAY_CONCAT_FUNCTION;
Expand Down Expand Up @@ -523,7 +523,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.function(ARRAY_CONCAT_FUNCTION)
.functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, JSON_TO_ARRAY, JSON_STRING_TO_ARRAY)
.function(ARRAY_AGG)
.function(LISTAGG)
.aggregates(ListaggAggregationFunction.class)
.functions(new MapSubscriptOperator())
.functions(MAP_CONSTRUCTOR, JSON_TO_MAP, JSON_STRING_TO_MAP)
.functions(new MapAggregationFunction(blockTypeOperators), new MapUnionAggregation(blockTypeOperators))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import io.trino.operator.aggregation.AbstractGroupCollectionAggregationState;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.Block;
import io.trino.spi.type.Type;

import static io.trino.spi.type.VarcharType.VARCHAR;

public final class GroupListaggAggregationState
extends AbstractGroupCollectionAggregationState<ListaggAggregationStateConsumer>
Expand All @@ -32,9 +33,9 @@ public final class GroupListaggAggregationState
private Slice overflowFiller;
private boolean showOverflowEntryCount;

public GroupListaggAggregationState(Type valueType)
public GroupListaggAggregationState()
{
super(PageBuilder.withMaxPageSize(MAX_BLOCK_SIZE, ImmutableList.of(valueType)));
super(PageBuilder.withMaxPageSize(MAX_BLOCK_SIZE, ImmutableList.of(VARCHAR)));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,113 +14,44 @@
package io.trino.operator.aggregation.listagg;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionMetadata;
import io.trino.metadata.Signature;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateFactory;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.StandardTypes;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;

import java.lang.invoke.MethodHandle;
import java.util.Optional;

import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INDEX;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.NULLABLE_BLOCK_INPUT_CHANNEL;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE;
import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.BlockIndex;
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.Description;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;

import static io.trino.spi.StandardErrorCode.EXCEEDED_FUNCTION_MEMORY_LIMIT;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.TypeSignature.arrayType;
import static io.trino.spi.type.TypeSignatureParameter.typeVariable;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.util.Reflection.methodHandle;
import static java.lang.String.format;

public class ListaggAggregationFunction
extends SqlAggregationFunction
@AggregationFunction(value = "listagg", isOrderSensitive = true)
@Description("concatenates the input values with the specified separator")
public final class ListaggAggregationFunction
{
public static final ListaggAggregationFunction LISTAGG = new ListaggAggregationFunction();
public static final String NAME = "listagg";
private static final MethodHandle INPUT_FUNCTION = methodHandle(ListaggAggregationFunction.class, "input", Type.class, ListaggAggregationState.class, Block.class, Slice.class, boolean.class, Slice.class, boolean.class, int.class);
private static final MethodHandle COMBINE_FUNCTION = methodHandle(ListaggAggregationFunction.class, "combine", Type.class, ListaggAggregationState.class, ListaggAggregationState.class);
private static final MethodHandle OUTPUT_FUNCTION = methodHandle(ListaggAggregationFunction.class, "output", Type.class, ListaggAggregationState.class, BlockBuilder.class);

private static final int MAX_OUTPUT_LENGTH = DEFAULT_MAX_PAGE_SIZE_IN_BYTES;
private static final int MAX_OVERFLOW_FILLER_LENGTH = 65_536;

private ListaggAggregationFunction()
{
super(
FunctionMetadata.aggregateBuilder()
.signature(Signature.builder()
.name(NAME)
.returnType(VARCHAR)
.argumentType(new TypeSignature(StandardTypes.VARCHAR, typeVariable("v")))
.argumentType(new TypeSignature(StandardTypes.VARCHAR, typeVariable("d")))
.argumentType(BOOLEAN)
.argumentType(new TypeSignature(StandardTypes.VARCHAR, typeVariable("f")))
.argumentType(BOOLEAN)
.build())
.nullable()
.description("concatenates the input values with the specified separator")
.build(),
AggregationFunctionMetadata.builder()
.orderSensitive()
.intermediateType(VARCHAR.getTypeSignature())
.intermediateType(BOOLEAN.getTypeSignature())
.intermediateType(VARCHAR.getTypeSignature())
.intermediateType(BOOLEAN.getTypeSignature())
.intermediateType(arrayType(VARCHAR.getTypeSignature()))
.build());
}

@Override
public AggregationMetadata specialize(BoundSignature boundSignature)
{
Type type = VARCHAR;
AccumulatorStateSerializer<ListaggAggregationState> stateSerializer = new ListaggAggregationStateSerializer(type);
AccumulatorStateFactory<ListaggAggregationState> stateFactory = new ListaggAggregationStateFactory(type);

MethodHandle inputFunction = normalizeInputMethod(
INPUT_FUNCTION.bindTo(type),
boundSignature,
STATE,
NULLABLE_BLOCK_INPUT_CHANNEL,
INPUT_CHANNEL,
INPUT_CHANNEL,
INPUT_CHANNEL,
INPUT_CHANNEL,
BLOCK_INDEX);
MethodHandle combineFunction = COMBINE_FUNCTION.bindTo(type);
MethodHandle outputFunction = OUTPUT_FUNCTION.bindTo(type);

return new AggregationMetadata(
inputFunction,
Optional.empty(),
Optional.of(combineFunction),
outputFunction,
ImmutableList.of(new AccumulatorStateDescriptor<>(
ListaggAggregationState.class,
stateSerializer,
stateFactory)));
}

public static void input(Type type, ListaggAggregationState state, Block value, Slice separator, boolean overflowError, Slice overflowFiller, boolean showOverflowEntryCount, int position)
private ListaggAggregationFunction() {}

@InputFunction
public static void input(
@AggregationState ListaggAggregationState state,
@BlockPosition @SqlType("VARCHAR") Block value,
@SqlType("VARCHAR") Slice separator,
@SqlType("BOOLEAN") boolean overflowError,
@SqlType("VARCHAR") Slice overflowFiller,
@SqlType("BOOLEAN") boolean showOverflowEntryCount,
@BlockIndex int position)
{
if (state.isEmpty()) {
if (overflowFiller.length() > MAX_OVERFLOW_FILLER_LENGTH) {
Expand All @@ -136,7 +67,8 @@ public static void input(Type type, ListaggAggregationState state, Block value,
state.add(value, position);
}

public static void combine(Type type, ListaggAggregationState state, ListaggAggregationState otherState)
@CombineFunction
public static void combine(@AggregationState ListaggAggregationState state, @AggregationState ListaggAggregationState otherState)
{
Slice previousSeparator = state.getSeparator();
if (previousSeparator == null) {
Expand All @@ -149,7 +81,8 @@ public static void combine(Type type, ListaggAggregationState state, ListaggAggr
state.merge(otherState);
}

public static void output(Type type, ListaggAggregationState state, BlockBuilder out)
@OutputFunction("VARCHAR")
public static void output(ListaggAggregationState state, BlockBuilder out)
{
if (state.isEmpty()) {
out.appendNull();
Expand All @@ -160,7 +93,7 @@ public static void output(Type type, ListaggAggregationState state, BlockBuilder
}

@VisibleForTesting
protected static void outputState(ListaggAggregationState state, BlockBuilder out, int maxOutputLength)
public static void outputState(ListaggAggregationState state, BlockBuilder out, int maxOutputLength)
{
Slice separator = state.getSeparator();
int separatorLength = separator.length();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
import io.airlift.slice.Slice;
import io.trino.spi.block.Block;
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AccumulatorStateMetadata;

@AccumulatorStateMetadata(
stateFactoryClass = ListaggAggregationStateFactory.class,
stateSerializerClass = ListaggAggregationStateSerializer.class)
public interface ListaggAggregationState
extends AccumulatorState
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,19 @@
package io.trino.operator.aggregation.listagg;

import io.trino.spi.function.AccumulatorStateFactory;
import io.trino.spi.type.Type;

public class ListaggAggregationStateFactory
implements AccumulatorStateFactory<ListaggAggregationState>
{
private final Type type;

public ListaggAggregationStateFactory(Type type)
{
this.type = type;
}

@Override
public ListaggAggregationState createSingleState()
{
return new SingleListaggAggregationState(type);
return new SingleListaggAggregationState();
}

@Override
public ListaggAggregationState createGroupedState()
{
return new GroupListaggAggregationState(type);
return new GroupListaggAggregationState();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,12 @@
public class ListaggAggregationStateSerializer
implements AccumulatorStateSerializer<ListaggAggregationState>
{
private final Type elementType;
private final Type arrayType;
private final Type serializedType;

public ListaggAggregationStateSerializer(Type elementType)
public ListaggAggregationStateSerializer()
{
this.elementType = elementType;
this.arrayType = new ArrayType(elementType);
this.arrayType = new ArrayType(VARCHAR);
this.serializedType = RowType.anonymous(ImmutableList.of(VARCHAR, BOOLEAN, VARCHAR, BOOLEAN, arrayType));
}

Expand All @@ -64,7 +62,7 @@ public void serialize(ListaggAggregationState state, BlockBuilder out)

BlockBuilder stateElementsBlockBuilder = rowBlockBuilder.beginBlockEntry();
state.forEach((block, position) -> {
elementType.appendTo(block, position, stateElementsBlockBuilder);
VARCHAR.appendTo(block, position, stateElementsBlockBuilder);
return true;
});
rowBlockBuilder.closeEntry();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
import io.airlift.slice.Slice;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.Type;
import org.openjdk.jol.info.ClassLayout;

import static com.google.common.base.Verify.verify;
import static java.util.Objects.requireNonNull;
import static io.trino.spi.type.VarcharType.VARCHAR;

public class SingleListaggAggregationState
implements ListaggAggregationState
Expand All @@ -31,12 +30,6 @@ public class SingleListaggAggregationState
private boolean overflowError;
private Slice overflowFiller;
private boolean showOverflowEntryCount;
private final Type type;

public SingleListaggAggregationState(Type type)
{
this.type = requireNonNull(type, "type is null");
}

@Override
public long getEstimatedSize()
Expand Down Expand Up @@ -100,9 +93,9 @@ public boolean showOverflowEntryCount()
public void add(Block block, int position)
{
if (blockBuilder == null) {
blockBuilder = type.createBlockBuilder(null, 16);
blockBuilder = VARCHAR.createBlockBuilder(null, 16);
}
type.appendTo(block, position, blockBuilder);
VARCHAR.appendTo(block, position, blockBuilder);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ public class TestListaggAggregationFunction
@Test
public void testInputEmptyState()
{
SingleListaggAggregationState state = new SingleListaggAggregationState(VARCHAR);
SingleListaggAggregationState state = new SingleListaggAggregationState();

String s = "value1";
Block value = createStringsBlock(s);
Slice separator = utf8Slice(",");
Slice overflowFiller = utf8Slice("...");
ListaggAggregationFunction.input(VARCHAR,
ListaggAggregationFunction.input(
state,
value,
separator,
Expand Down Expand Up @@ -84,9 +84,9 @@ public void testInputOverflowOverflowFillerTooLong()
{
String overflowFillerTooLong = StringUtils.repeat(".", 65_537);

SingleListaggAggregationState state = new SingleListaggAggregationState(VARCHAR);
SingleListaggAggregationState state = new SingleListaggAggregationState();

assertThatThrownBy(() -> ListaggAggregationFunction.input(VARCHAR,
assertThatThrownBy(() -> ListaggAggregationFunction.input(
state,
createStringsBlock("value1"),
utf8Slice(","),
Expand Down Expand Up @@ -252,7 +252,7 @@ private static String getOutputStateOnlyValue(SingleListaggAggregationState stat

private static SingleListaggAggregationState createListaggAggregationState(String separator, boolean overflowError, String overflowFiller, boolean showOverflowEntryCount, String... values)
{
SingleListaggAggregationState state = new SingleListaggAggregationState(VARCHAR);
SingleListaggAggregationState state = new SingleListaggAggregationState();
state.setSeparator(utf8Slice(separator));
state.setOverflowError(overflowError);
state.setOverflowFiller(utf8Slice(overflowFiller));
Expand Down

0 comments on commit fbf075e

Please sign in to comment.