Skip to content

Commit

Permalink
Make multimap_agg compare keys with 'is distinct' semantics
Browse files Browse the repository at this point in the history
  • Loading branch information
jirassimok authored and losipiuk committed Apr 7, 2022
1 parent 013d110 commit 1e2f094
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.type.BlockTypeOperators;
import io.trino.type.BlockTypeOperators.BlockPositionEqual;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;
import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
Expand All @@ -40,7 +40,7 @@
import static io.trino.metadata.FunctionKind.AGGREGATE;
import static io.trino.metadata.Signature.comparableTypeParameter;
import static io.trino.metadata.Signature.typeVariable;
import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet;
import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet;
import static io.trino.spi.type.TypeSignature.arrayType;
import static io.trino.spi.type.TypeSignature.mapType;
import static io.trino.spi.type.TypeSignature.rowType;
Expand All @@ -57,7 +57,7 @@ public class MultimapAggregationFunction
MultimapAggregationFunction.class,
"output",
Type.class,
BlockPositionEqual.class,
BlockPositionIsDistinctFrom.class,
BlockPositionHashCode.class,
Type.class,
MultimapAggregationState.class,
Expand Down Expand Up @@ -103,7 +103,7 @@ public MultimapAggregationFunction(BlockTypeOperators blockTypeOperators)
public AggregationMetadata specialize(BoundSignature boundSignature)
{
Type keyType = boundSignature.getArgumentType(0);
BlockPositionEqual keyEqual = blockTypeOperators.getEqualOperator(keyType);
BlockPositionIsDistinctFrom keyDistinctOperator = blockTypeOperators.getDistinctFromOperator(keyType);
BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(keyType);

Type valueType = boundSignature.getArgumentType(1);
Expand All @@ -114,7 +114,7 @@ public AggregationMetadata specialize(BoundSignature boundSignature)
INPUT_FUNCTION,
Optional.empty(),
Optional.of(COMBINE_FUNCTION),
MethodHandles.insertArguments(OUTPUT_FUNCTION, 0, keyType, keyEqual, keyHashCode, valueType),
MethodHandles.insertArguments(OUTPUT_FUNCTION, 0, keyType, keyDistinctOperator, keyHashCode, valueType),
ImmutableList.of(new AccumulatorStateDescriptor<>(
MultimapAggregationState.class,
stateSerializer,
Expand All @@ -131,7 +131,7 @@ public static void combine(MultimapAggregationState state, MultimapAggregationSt
state.merge(otherState);
}

public static void output(Type keyType, BlockPositionEqual keyEqual, BlockPositionHashCode keyHashCode, Type valueType, MultimapAggregationState state, BlockBuilder out)
public static void output(Type keyType, BlockPositionIsDistinctFrom keyDistinctOperator, BlockPositionHashCode keyHashCode, Type valueType, MultimapAggregationState state, BlockBuilder out)
{
if (state.isEmpty()) {
out.appendNull();
Expand All @@ -141,7 +141,7 @@ public static void output(Type keyType, BlockPositionEqual keyEqual, BlockPositi
ObjectBigArray<BlockBuilder> valueArrayBlockBuilders = new ObjectBigArray<>();
valueArrayBlockBuilders.ensureCapacity(state.getEntryCount());
BlockBuilder distinctKeyBlockBuilder = keyType.createBlockBuilder(null, state.getEntryCount(), expectedValueSize(keyType, 100));
TypedSet keySet = createEqualityTypedSet(keyType, keyEqual, keyHashCode, state.getEntryCount(), NAME);
TypedSet keySet = createDistinctTypedSet(keyType, keyDistinctOperator, keyHashCode, state.getEntryCount(), NAME);

state.forEach((key, value, keyValueIndex) -> {
// Merge values of the same key into an array
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ public void testNullMap()
testMultimapAgg(DOUBLE, ImmutableList.<Double>of(), VARCHAR, ImmutableList.<String>of());
}

@Test
public void testKeysUseIsDistinctSemantics()
{
    // Every key below is NaN. With "is distinct from" key comparison (as the
    // test name states), repeated NaN keys are expected to be treated as the
    // same key rather than as separate entries.

    // Two NaN keys carrying identical values.
    List<Double> twoNanKeys = ImmutableList.of(Double.NaN, Double.NaN);
    List<Long> identicalValues = ImmutableList.of(1L, 1L);
    testMultimapAgg(DOUBLE, twoNanKeys, BIGINT, identicalValues);

    // Three NaN keys carrying a mix of repeated and differing values.
    List<Double> threeNanKeys = ImmutableList.of(Double.NaN, Double.NaN, Double.NaN);
    List<Long> mixedValues = ImmutableList.of(2L, 1L, 2L);
    testMultimapAgg(DOUBLE, threeNanKeys, BIGINT, mixedValues);
}

@Test
public void testDoubleMapMultimap()
{
Expand Down Expand Up @@ -177,6 +184,11 @@ private static TestingAggregationFunction getAggregationFunction(Type keyType, T
return FUNCTION_RESOLUTION.getAggregateFunction(QualifiedName.of(MultimapAggregationFunction.NAME), fromTypes(keyType, valueType));
}

/**
* Given a list of keys and a list of corresponding values, manually
* aggregate them into a map of lists and check that Trino's aggregation
* produces the same result.
*/
private static <K, V> void testMultimapAgg(Type keyType, List<K> expectedKeys, Type valueType, List<V> expectedValues)
{
checkState(expectedKeys.size() == expectedValues.size(), "expectedKeys and expectedValues should have equal size");
Expand Down

0 comments on commit 1e2f094

Please sign in to comment.