Skip to content

Commit

Permalink
Convert min/max by N aggregation to MinMaxCompare helper
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Oct 9, 2021
1 parent 2e96304 commit fc3af01
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
*/
package io.trino.operator.aggregation;

import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.type.BlockTypeOperators.BlockPositionComparison;
import org.openjdk.jol.info.ClassLayout;

import java.lang.invoke.MethodHandle;

import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.slice.SizeOf.sizeOf;
import static io.trino.spi.type.BigintType.BIGINT;
Expand All @@ -34,7 +36,7 @@ public class TypedKeyValueHeap
private static final int COMPACT_THRESHOLD_BYTES = 32768;
private static final int COMPACT_THRESHOLD_RATIO = 3; // when 2/3 of elements in keyBlockBuilder is unreferenced, do compact

private final BlockPositionComparison keyComparison;
private final MethodHandle keyGreaterThan;
private final Type keyType;
private final Type valueType;
private final int capacity;
Expand All @@ -44,9 +46,9 @@ public class TypedKeyValueHeap
private BlockBuilder keyBlockBuilder;
private BlockBuilder valueBlockBuilder;

public TypedKeyValueHeap(BlockPositionComparison keyComparison, Type keyType, Type valueType, int capacity)
public TypedKeyValueHeap(MethodHandle keyGreaterThan, Type keyType, Type valueType, int capacity)
{
this.keyComparison = keyComparison;
this.keyGreaterThan = keyGreaterThan;
this.keyType = keyType;
this.valueType = valueType;
this.capacity = capacity;
Expand Down Expand Up @@ -95,12 +97,12 @@ public void serialize(BlockBuilder out)
out.closeEntry();
}

public static TypedKeyValueHeap deserialize(Block block, Type keyType, Type valueType, BlockPositionComparison comparison)
public static TypedKeyValueHeap deserialize(Block block, Type keyType, Type valueType, MethodHandle keyComparisonOperator)
{
int capacity = toIntExact(BIGINT.getLong(block, 0));
Block keysBlock = new ArrayType(keyType).getObject(block, 1);
Block valuesBlock = new ArrayType(valueType).getObject(block, 2);
TypedKeyValueHeap heap = new TypedKeyValueHeap(comparison, keyType, valueType, capacity);
TypedKeyValueHeap heap = new TypedKeyValueHeap(keyComparisonOperator, keyType, valueType, capacity);
heap.addAll(keysBlock, valuesBlock);
return heap;
}
Expand Down Expand Up @@ -129,7 +131,7 @@ public void add(Block keyBlock, Block valueBlock, int position)
{
checkArgument(!keyBlock.isNull(position));
if (positionCount == capacity) {
if (keyComparison.compare(keyBlockBuilder, heapIndex[0], keyBlock, position) >= 0) {
if (keyGreaterThanOrEqual(keyBlockBuilder, heapIndex[0], keyBlock, position)) {
return; // and new element is not larger than heap top: do not add
}
heapIndex[0] = keyBlockBuilder.getPositionCount();
Expand Down Expand Up @@ -173,9 +175,9 @@ private void siftDown()
smallerChildPosition = leftPosition;
}
else {
smallerChildPosition = keyComparison.compare(keyBlockBuilder, heapIndex[leftPosition], keyBlockBuilder, heapIndex[rightPosition]) >= 0 ? rightPosition : leftPosition;
smallerChildPosition = keyGreaterThanOrEqual(keyBlockBuilder, heapIndex[leftPosition], keyBlockBuilder, heapIndex[rightPosition]) ? rightPosition : leftPosition;
}
if (keyComparison.compare(keyBlockBuilder, heapIndex[smallerChildPosition], keyBlockBuilder, heapIndex[position]) >= 0) {
if (keyGreaterThanOrEqual(keyBlockBuilder, heapIndex[smallerChildPosition], keyBlockBuilder, heapIndex[position])) {
break; // child is larger or equal
}
int swapTemp = heapIndex[position];
Expand All @@ -190,7 +192,7 @@ private void siftUp()
int position = positionCount - 1;
while (position != 0) {
int parentPosition = (position - 1) / 2;
if (keyComparison.compare(keyBlockBuilder, heapIndex[position], keyBlockBuilder, heapIndex[parentPosition]) >= 0) {
if (keyGreaterThanOrEqual(keyBlockBuilder, heapIndex[position], keyBlockBuilder, heapIndex[parentPosition])) {
break; // child is larger or equal
}
int swapTemp = heapIndex[position];
Expand Down Expand Up @@ -218,4 +220,18 @@ private void compactIfNecessary()
keyBlockBuilder = newHeapKeyBlockBuilder;
valueBlockBuilder = newHeapValueBlockBuilder;
}

private boolean keyGreaterThanOrEqual(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition)
{
try {
// Swap the argument order to get a less than operator, and negate the result to get greater than or equals.
// Note: the keyGreaterThan operator is based comparison, and specifically is not a pure greater than operator.
// This means negation of the result is safe for unordered values.
return !((boolean) keyGreaterThan.invokeExact(rightBlock, rightPosition, leftBlock, leftPosition));
}
catch (Throwable throwable) {
Throwables.throwIfUnchecked(throwable);
throw new RuntimeException(throwable);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import io.airlift.bytecode.DynamicClassLoader;
import io.trino.metadata.FunctionArgumentDefinition;
import io.trino.metadata.FunctionBinding;
import io.trino.metadata.FunctionDependencies;
import io.trino.metadata.FunctionDependencyDeclaration;
import io.trino.metadata.FunctionMetadata;
import io.trino.metadata.Signature;
import io.trino.metadata.SqlAggregationFunction;
Expand All @@ -33,12 +35,11 @@
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.type.BlockTypeOperators.BlockPositionComparison;
import io.trino.util.MinMaxCompare;

import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.metadata.FunctionKind.AGGREGATE;
Expand All @@ -51,24 +52,28 @@
import static io.trino.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static io.trino.operator.aggregation.AggregationUtils.generateAggregationName;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.simpleConvention;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.util.Failures.checkCondition;
import static io.trino.util.MinMaxCompare.getMinMaxCompare;
import static io.trino.util.Reflection.methodHandle;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

public abstract class AbstractMinMaxByNAggregationFunction
extends SqlAggregationFunction
{
private static final MethodHandle INPUT_FUNCTION = methodHandle(AbstractMinMaxByNAggregationFunction.class, "input", BlockPositionComparison.class, Type.class, Type.class, MinMaxByNState.class, Block.class, Block.class, int.class, long.class);
private static final MethodHandle INPUT_FUNCTION = methodHandle(AbstractMinMaxByNAggregationFunction.class, "input", MethodHandle.class, Type.class, Type.class, MinMaxByNState.class, Block.class, Block.class, int.class, long.class);
private static final MethodHandle COMBINE_FUNCTION = methodHandle(AbstractMinMaxByNAggregationFunction.class, "combine", MinMaxByNState.class, MinMaxByNState.class);
private static final MethodHandle OUTPUT_FUNCTION = methodHandle(AbstractMinMaxByNAggregationFunction.class, "output", ArrayType.class, MinMaxByNState.class, BlockBuilder.class);
private static final long MAX_NUMBER_OF_VALUES = 10_000;

private final String name;
private final Function<Type, BlockPositionComparison> typeToComparison;
private final boolean min;

protected AbstractMinMaxByNAggregationFunction(String name, Function<Type, BlockPositionComparison> typeToComparison, String description)
protected AbstractMinMaxByNAggregationFunction(String name, boolean min, String description)
{
super(
new FunctionMetadata(
Expand All @@ -91,34 +96,41 @@ protected AbstractMinMaxByNAggregationFunction(String name, Function<Type, Block
true,
false);
this.name = requireNonNull(name, "name is null");
this.typeToComparison = requireNonNull(typeToComparison, "typeToComparison is null");
this.min = min;
}

@Override
public List<TypeSignature> getIntermediateTypes(FunctionBinding functionBinding)
{
Type keyType = functionBinding.getTypeVariable("K");
Type valueType = functionBinding.getTypeVariable("V");
return ImmutableList.of(new MinMaxByNStateSerializer(typeToComparison.apply(keyType), keyType, valueType).getSerializedType().getTypeSignature());
return ImmutableList.of(TypedKeyValueHeap.getSerializedType(keyType, valueType).getTypeSignature());
}

@Override
public InternalAggregationFunction specialize(FunctionBinding functionBinding)
public FunctionDependencyDeclaration getFunctionDependencies(FunctionBinding functionBinding)
{
return MinMaxCompare.getMinMaxCompareFunctionDependencies(functionBinding.getTypeVariable("K").getTypeSignature(), min);
}

@Override
public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
{
Type keyType = functionBinding.getTypeVariable("K");
Type valueType = functionBinding.getTypeVariable("V");
return generateAggregation(valueType, keyType);
MethodHandle keyComparisonMethod = getMinMaxCompare(functionDependencies, keyType, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION), min);
return generateAggregation(keyComparisonMethod, valueType, keyType);
}

public static void input(BlockPositionComparison comparison, Type valueType, Type keyType, MinMaxByNState state, Block value, Block key, int blockIndex, long n)
public static void input(MethodHandle keyComparisonMethod, Type valueType, Type keyType, MinMaxByNState state, Block value, Block key, int blockIndex, long n)
{
TypedKeyValueHeap heap = state.getTypedKeyValueHeap();
if (heap == null) {
if (n <= 0) {
throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "third argument of max_by/min_by must be a positive integer");
}
checkCondition(n <= MAX_NUMBER_OF_VALUES, INVALID_FUNCTION_ARGUMENT, "third argument of max_by/min_by must be less than or equal to %s; found %s", MAX_NUMBER_OF_VALUES, n);
heap = new TypedKeyValueHeap(comparison, keyType, valueType, toIntExact(n));
heap = new TypedKeyValueHeap(keyComparisonMethod, keyType, valueType, toIntExact(n));
state.setTypedKeyValueHeap(heap);
}

Expand Down Expand Up @@ -167,13 +179,12 @@ public static void output(ArrayType outputType, MinMaxByNState state, BlockBuild
out.closeEntry();
}

protected InternalAggregationFunction generateAggregation(Type valueType, Type keyType)
protected InternalAggregationFunction generateAggregation(MethodHandle keyComparisonMethod, Type valueType, Type keyType)
{
DynamicClassLoader classLoader = new DynamicClassLoader(AbstractMinMaxNAggregationFunction.class.getClassLoader());

BlockPositionComparison comparison = typeToComparison.apply(keyType);
List<Type> inputTypes = ImmutableList.of(valueType, keyType, BIGINT);
MinMaxByNStateSerializer stateSerializer = new MinMaxByNStateSerializer(comparison, keyType, valueType);
MinMaxByNStateSerializer stateSerializer = new MinMaxByNStateSerializer(keyComparisonMethod, keyType, valueType);
Type intermediateType = stateSerializer.getSerializedType();
ArrayType outputType = new ArrayType(valueType);

Expand All @@ -187,7 +198,7 @@ protected InternalAggregationFunction generateAggregation(Type valueType, Type k
AggregationMetadata metadata = new AggregationMetadata(
generateAggregationName(name, valueType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())),
inputParameterMetadata,
INPUT_FUNCTION.bindTo(comparison).bindTo(valueType).bindTo(keyType),
INPUT_FUNCTION.bindTo(keyComparisonMethod).bindTo(valueType).bindTo(keyType),
Optional.empty(),
COMBINE_FUNCTION,
OUTPUT_FUNCTION.bindTo(outputType),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ public class MaxByNAggregationFunction

public MaxByNAggregationFunction(BlockTypeOperators blockTypeOperators)
{
super(NAME,
blockTypeOperators::getComparisonUnorderedFirstOperator,
"Returns the values of the first argument associated with the maximum values of the second argument");
super(NAME, false, "Returns the values of the first argument associated with the maximum values of the second argument");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ public class MinByNAggregationFunction

public MinByNAggregationFunction(BlockTypeOperators blockTypeOperators)
{
super(NAME,
type -> blockTypeOperators.getComparisonUnorderedLastOperator(type).reversed(),
"Returns the values of the first argument associated with the minimum values of the second argument");
super(NAME, true, "Returns the values of the first argument associated with the minimum values of the second argument");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,20 @@
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.Type;
import io.trino.type.BlockTypeOperators.BlockPositionComparison;

import java.lang.invoke.MethodHandle;

public class MinMaxByNStateSerializer
implements AccumulatorStateSerializer<MinMaxByNState>
{
private final BlockPositionComparison blockComparison;
private final MethodHandle keyComparisonMethod;
private final Type keyType;
private final Type valueType;
private final Type serializedType;

public MinMaxByNStateSerializer(BlockPositionComparison blockComparison, Type keyType, Type valueType)
public MinMaxByNStateSerializer(MethodHandle keyComparisonMethod, Type keyType, Type valueType)
{
this.blockComparison = blockComparison;
this.keyComparisonMethod = keyComparisonMethod;
this.keyType = keyType;
this.valueType = valueType;
this.serializedType = TypedKeyValueHeap.getSerializedType(keyType, valueType);
Expand Down Expand Up @@ -58,6 +59,6 @@ public void serialize(MinMaxByNState state, BlockBuilder out)
public void deserialize(Block block, int index, MinMaxByNState state)
{
Block currentBlock = (Block) serializedType.getObject(block, index);
state.setTypedKeyValueHeap(TypedKeyValueHeap.deserialize(currentBlock, keyType, valueType, blockComparison));
state.setTypedKeyValueHeap(TypedKeyValueHeap.deserialize(currentBlock, keyType, valueType, keyComparisonMethod));
}
}
13 changes: 13 additions & 0 deletions core/trino-main/src/main/java/io/trino/util/MinMaxCompare.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.spi.type.TypeSignature;

import java.lang.invoke.MethodHandle;
Expand Down Expand Up @@ -52,6 +53,18 @@ public static MethodHandle getMinMaxCompare(FunctionDependencies dependencies, T
return filterReturnValue(handle, min ? MIN_FUNCTION : MAX_FUNCTION);
}

public static MethodHandle getMinMaxCompare(TypeOperators typeOperators, Type type, InvocationConvention convention, boolean min)
{
MethodHandle handle;
if (min) {
handle = typeOperators.getComparisonUnorderedLastOperator(type, convention);
}
else {
handle = typeOperators.getComparisonUnorderedFirstOperator(type, convention);
}
return filterReturnValue(handle, min ? MIN_FUNCTION : MAX_FUNCTION);
}

@UsedByGeneratedCode
public static boolean min(long comparisonResult)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,32 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.TypeOperators;
import io.trino.type.BlockTypeOperators;
import io.trino.type.BlockTypeOperators.BlockPositionComparison;
import org.testng.annotations.Test;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.simpleConvention;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.util.MinMaxCompare.getMinMaxCompare;
import static org.testng.Assert.assertEquals;

public class TestTypedKeyValueHeap
{
private static final int INPUT_SIZE = 1_000_000; // larger than COMPACT_THRESHOLD_* to guarantee coverage of compact
private static final int OUTPUT_SIZE = 1_000;

private static final BlockTypeOperators TYPE_OPERATOR_FACTORY = new BlockTypeOperators(new TypeOperators());
private static final BlockPositionComparison MAX_ELEMENTS_COMPARATOR = TYPE_OPERATOR_FACTORY.getComparisonUnorderedFirstOperator(BIGINT);
private static final BlockPositionComparison MIN_ELEMENTS_COMPARATOR = TYPE_OPERATOR_FACTORY.getComparisonUnorderedLastOperator(BIGINT).reversed();
private static final TypeOperators TYPE_OPERATOR_FACTORY = new TypeOperators();
private static final MethodHandle MAX_ELEMENTS_COMPARATOR = getMinMaxCompare(TYPE_OPERATOR_FACTORY, BIGINT, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION), false);
private static final MethodHandle MIN_ELEMENTS_COMPARATOR = getMinMaxCompare(TYPE_OPERATOR_FACTORY, BIGINT, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION, BLOCK_POSITION), true);

@Test
public void testAscending()
Expand Down Expand Up @@ -81,14 +84,14 @@ public void testShuffled()
IntStream.range(0, OUTPUT_SIZE).map(x -> OUTPUT_SIZE - 1 - x).mapToObj(key -> Integer.toString(key * 2)).iterator());
}

private static void test(IntStream keyInputStream, Stream<String> valueInputStream, BlockPositionComparison comparison, Iterator<String> outputIterator)
private static void test(IntStream keyInputStream, Stream<String> valueInputStream, MethodHandle keyComparisonMethod, Iterator<String> outputIterator)
{
BlockBuilder keysBlockBuilder = BIGINT.createBlockBuilder(null, INPUT_SIZE);
BlockBuilder valuesBlockBuilder = VARCHAR.createBlockBuilder(null, INPUT_SIZE);
keyInputStream.forEach(x -> BIGINT.writeLong(keysBlockBuilder, x));
valueInputStream.forEach(x -> VARCHAR.writeString(valuesBlockBuilder, x));

TypedKeyValueHeap heap = new TypedKeyValueHeap(comparison, BIGINT, VARCHAR, OUTPUT_SIZE);
TypedKeyValueHeap heap = new TypedKeyValueHeap(keyComparisonMethod, BIGINT, VARCHAR, OUTPUT_SIZE);
heap.addAll(keysBlockBuilder, valuesBlockBuilder);

BlockBuilder resultBlockBuilder = VARCHAR.createBlockBuilder(null, OUTPUT_SIZE);
Expand Down

0 comments on commit fc3af01

Please sign in to comment.