diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java index 9fe86559908a..7e3892f88a3a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java @@ -30,8 +30,8 @@ import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.type.BlockTypeOperators; -import io.trino.type.BlockTypeOperators.BlockPositionEqual; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; +import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; @@ -40,7 +40,7 @@ import static io.trino.metadata.FunctionKind.AGGREGATE; import static io.trino.metadata.Signature.comparableTypeParameter; import static io.trino.metadata.Signature.typeVariable; -import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet; +import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet; import static io.trino.spi.type.TypeSignature.arrayType; import static io.trino.spi.type.TypeSignature.mapType; import static io.trino.spi.type.TypeSignature.rowType; @@ -57,7 +57,7 @@ public class MultimapAggregationFunction MultimapAggregationFunction.class, "output", Type.class, - BlockPositionEqual.class, + BlockPositionIsDistinctFrom.class, BlockPositionHashCode.class, Type.class, MultimapAggregationState.class, @@ -103,7 +103,7 @@ public MultimapAggregationFunction(BlockTypeOperators blockTypeOperators) public AggregationMetadata specialize(BoundSignature boundSignature) { Type keyType = boundSignature.getArgumentType(0); - BlockPositionEqual keyEqual = blockTypeOperators.getEqualOperator(keyType); + BlockPositionIsDistinctFrom keyDistinctOperator = blockTypeOperators.getDistinctFromOperator(keyType); BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(keyType); Type valueType = boundSignature.getArgumentType(1); @@ -114,7 +114,7 @@ public AggregationMetadata specialize(BoundSignature boundSignature) INPUT_FUNCTION, Optional.empty(), Optional.of(COMBINE_FUNCTION), - MethodHandles.insertArguments(OUTPUT_FUNCTION, 0, keyType, keyEqual, keyHashCode, valueType), + MethodHandles.insertArguments(OUTPUT_FUNCTION, 0, keyType, keyDistinctOperator, keyHashCode, valueType), ImmutableList.of(new AccumulatorStateDescriptor<>( MultimapAggregationState.class, stateSerializer, @@ -131,7 +131,7 @@ public static void combine(MultimapAggregationState state, MultimapAggregationSt state.merge(otherState); } - public static void output(Type keyType, BlockPositionEqual keyEqual, BlockPositionHashCode keyHashCode, Type valueType, MultimapAggregationState state, BlockBuilder out) + public static void output(Type keyType, BlockPositionIsDistinctFrom keyDistinctOperator, BlockPositionHashCode keyHashCode, Type valueType, MultimapAggregationState state, BlockBuilder out) { if (state.isEmpty()) { out.appendNull(); @@ -141,7 +141,7 @@ public static void output(Type keyType, BlockPositionEqual keyEqual, BlockPositi ObjectBigArray valueArrayBlockBuilders = new ObjectBigArray<>(); valueArrayBlockBuilders.ensureCapacity(state.getEntryCount()); BlockBuilder distinctKeyBlockBuilder = keyType.createBlockBuilder(null, state.getEntryCount(), expectedValueSize(keyType, 100)); - TypedSet keySet = createEqualityTypedSet(keyType, keyEqual, keyHashCode, state.getEntryCount(), NAME); + TypedSet keySet = createDistinctTypedSet(keyType, keyDistinctOperator, keyHashCode, state.getEntryCount(), NAME); state.forEach((key, value, keyValueIndex) -> { // Merge values of the same key into an array diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java index c3d5174c62e8..0a6cc2d2d623 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java @@ -88,6 +88,13 @@ public void testNullMap() testMultimapAgg(DOUBLE, ImmutableList.of(), VARCHAR, ImmutableList.of()); } + @Test + public void testKeysUseIsDistinctSemantics() + { + testMultimapAgg(DOUBLE, ImmutableList.of(Double.NaN, Double.NaN), BIGINT, ImmutableList.of(1L, 1L)); + testMultimapAgg(DOUBLE, ImmutableList.of(Double.NaN, Double.NaN, Double.NaN), BIGINT, ImmutableList.of(2L, 1L, 2L)); + } + @Test public void testDoubleMapMultimap() { @@ -177,6 +184,11 @@ private static TestingAggregationFunction getAggregationFunction(Type keyType, T return FUNCTION_RESOLUTION.getAggregateFunction(QualifiedName.of(MultimapAggregationFunction.NAME), fromTypes(keyType, valueType)); } + /** + * Given a list of keys and a list of corresponding values, manually + * aggregate them into a map of list and check that Trino's aggregation has + * the same results. + */ private static void testMultimapAgg(Type keyType, List expectedKeys, Type valueType, List expectedValues) { checkState(expectedKeys.size() == expectedValues.size(), "expectedKeys and expectedValues should have equal size");