diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java index 21cc1f4f8faf..1181a7f67bd5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowAndLongStateSerializer.java @@ -74,6 +74,8 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowAndLongSt return; } + index = block.getUnderlyingValuePosition(index); + block = block.getUnderlyingValueBlock(); VariableWidthBlock variableWidthBlock = (VariableWidthBlock) block; Slice slice = variableWidthBlock.getRawSlice(); int sliceOffset = variableWidthBlock.getRawSliceOffset(index); @@ -89,15 +91,15 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowAndLongSt long count = 1; switch (sliceLength) { case 4 * Long.BYTES: - overflow = slice.getLong(sliceOffset + 24); - count = slice.getLong(sliceOffset + 16); + overflow = slice.getLong(sliceOffset + Long.BYTES * 3); + count = slice.getLong(sliceOffset + Long.BYTES * 2); // fall through case 2 * Long.BYTES: - high = slice.getLong(sliceOffset + 8); + high = slice.getLong(sliceOffset + Long.BYTES); break; case 3 * Long.BYTES: - overflow = slice.getLong(sliceOffset + 16); - count = slice.getLong(sliceOffset + 8); + overflow = slice.getLong(sliceOffset + Long.BYTES * 2); + count = slice.getLong(sliceOffset + Long.BYTES); } decimal[offset + 1] = low; diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java index 341ea85359c6..ab1103ce96f8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/state/LongDecimalWithOverflowStateSerializer.java @@ -64,6 +64,8 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowState sta return; } + index = block.getUnderlyingValuePosition(index); + block = block.getUnderlyingValueBlock(); VariableWidthBlock variableWidthBlock = (VariableWidthBlock) block; Slice slice = variableWidthBlock.getRawSlice(); int sliceOffset = variableWidthBlock.getRawSliceOffset(index); @@ -75,12 +77,13 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowState sta long low = slice.getLong(sliceOffset); long high = 0; long overflow = 0; - if (sliceLength == 3 * Long.BYTES) { - overflow = slice.getLong(sliceOffset + 16); - high = slice.getLong(sliceOffset + 8); - } - else if (sliceLength == 2 * Long.BYTES) { - high = slice.getLong(sliceOffset + 8); + + switch (sliceLength) { + case Long.BYTES * 3: + overflow = slice.getLong(sliceOffset + Long.BYTES * 2); + // fall through + case Long.BYTES * 2: + high = slice.getLong(sliceOffset + Long.BYTES); } decimal[offset + 1] = low; diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java index 26bf3969b04a..98fe547b38b2 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowAndLongStateSerializer.java @@ -13,10 +13,15 @@ */ package io.trino.operator.aggregation.state; +import io.trino.spi.block.Block; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import org.junit.jupiter.api.Test; +import java.util.function.Function; + import static org.assertj.core.api.Assertions.assertThat; public class TestLongDecimalWithOverflowAndLongStateSerializer @@ -45,6 +50,11 @@ public void testSerde() } private void testSerde(long low, long high, long overflow, long count, int expectedLength) + { + testSerde(low, high, overflow, count, expectedLength, Function.identity()); + } + + private void testSerde(long low, long high, long overflow, long count, int expectedLength, Function serializedModification) { LongDecimalWithOverflowAndLongState state = STATE_FACTORY.createSingleState(); state.getDecimalArray()[0] = high; @@ -52,7 +62,7 @@ private void testSerde(long low, long high, long overflow, long count, int expec state.setOverflow(overflow); state.setLong(count); - LongDecimalWithOverflowAndLongState outState = roundTrip(state, expectedLength); + LongDecimalWithOverflowAndLongState outState = roundTrip(state, expectedLength, serializedModification); assertThat(outState.getDecimalArray()[0]).isEqualTo(high); assertThat(outState.getDecimalArray()[1]).isEqualTo(low); @@ -71,7 +81,24 @@ public void testNullSerde() assertThat(outState.getLong()).isEqualTo(0); } + @Test + public void testDictionaryDeserialization() + { + testSerde(3, 0, 0, 1, 1, block -> DictionaryBlock.create(2, block, new int[] {0, 0})); + } + + @Test + public void testRleDeserialization() + { + testSerde(3, 0, 0, 1, 1, block -> RunLengthEncodedBlock.create(block, 2)); + } + private LongDecimalWithOverflowAndLongState roundTrip(LongDecimalWithOverflowAndLongState state, int expectedLength) + { + return roundTrip(state, expectedLength, Function.identity()); + } + + private LongDecimalWithOverflowAndLongState roundTrip(LongDecimalWithOverflowAndLongState state, int expectedLength, Function serializedModification) { LongDecimalWithOverflowAndLongStateSerializer serializer = new LongDecimalWithOverflowAndLongStateSerializer(); VariableWidthBlockBuilder out = new VariableWidthBlockBuilder(null, 1, 0); @@ -81,7 +108,7 @@ private LongDecimalWithOverflowAndLongState roundTrip(LongDecimalWithOverflowAnd VariableWidthBlock serialized = out.buildValueBlock(); assertThat(serialized.getSliceLength(0)).isEqualTo(expectedLength * Long.BYTES); LongDecimalWithOverflowAndLongState outState = STATE_FACTORY.createSingleState(); - serializer.deserialize(serialized, 0, outState); + serializer.deserialize(serializedModification.apply(serialized), 0, outState); return outState; } } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowStateSerializer.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowStateSerializer.java index 6113ebea3e27..53746309205b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowStateSerializer.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/state/TestLongDecimalWithOverflowStateSerializer.java @@ -13,10 +13,15 @@ */ package io.trino.operator.aggregation.state; +import io.trino.spi.block.Block; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.block.VariableWidthBlockBuilder; import org.junit.jupiter.api.Test; +import java.util.function.Function; + import static org.assertj.core.api.Assertions.assertThat; public class TestLongDecimalWithOverflowStateSerializer @@ -37,6 +42,11 @@ public void testSerde() } private void testSerde(long low, long high, long overflow, int expectedLength) + { + testSerde(low, high, overflow, expectedLength, Function.identity()); + } + + private void testSerde(long low, long high, long overflow, int expectedLength, Function serializedModification) { LongDecimalWithOverflowState state = STATE_FACTORY.createSingleState(); state.getDecimalArray()[0] = high; @@ -44,7 +54,7 @@ private void testSerde(long low, long high, long overflow, int expectedLength) state.setOverflow(overflow); state.setNotNull(); - LongDecimalWithOverflowState outState = roundTrip(state, expectedLength); + LongDecimalWithOverflowState outState = roundTrip(state, expectedLength, serializedModification); assertThat(outState.isNotNull()).isTrue(); assertThat(outState.getDecimalArray()[0]).isEqualTo(high); @@ -63,7 +73,24 @@ public void testNullSerde() assertThat(outState.isNotNull()).isFalse(); } + @Test + public void testDictionaryDeserialization() + { + testSerde(3, 0, 0, 1, block -> DictionaryBlock.create(2, block, new int[] {0, 0})); + } + + @Test + public void testRleDeserialization() + { + testSerde(3, 0, 0, 1, block -> RunLengthEncodedBlock.create(block, 2)); + } + private LongDecimalWithOverflowState roundTrip(LongDecimalWithOverflowState state, int expectedLength) + { + return roundTrip(state, expectedLength, Function.identity()); + } + + private LongDecimalWithOverflowState roundTrip(LongDecimalWithOverflowState state, int expectedLength, Function serializedModification) { LongDecimalWithOverflowStateSerializer serializer = new LongDecimalWithOverflowStateSerializer(); VariableWidthBlockBuilder out = new VariableWidthBlockBuilder(null, 1, 0); @@ -73,7 +100,7 @@ private LongDecimalWithOverflowState roundTrip(LongDecimalWithOverflowState stat VariableWidthBlock serialized = out.buildValueBlock(); assertThat(serialized.getSliceLength(0)).isEqualTo(expectedLength * Long.BYTES); LongDecimalWithOverflowState outState = STATE_FACTORY.createSingleState(); - serializer.deserialize(serialized, 0, outState); + serializer.deserialize(serializedModification.apply(serialized), 0, outState); return outState; } } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestAggregation.java b/core/trino-main/src/test/java/io/trino/sql/query/TestAggregation.java index a0d6b67dffc6..c990a6aadacc 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestAggregation.java @@ -22,6 +22,7 @@ import java.time.ZoneId; import java.time.ZonedDateTime; +import static io.trino.SystemSessionProperties.PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN; import static io.trino.server.testing.TestingTrinoServer.SESSION_START_TIME_PROPERTY; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; @@ -72,20 +73,20 @@ void testSpecialDateTimeFunctionsInAggregation() assertThat(assertions.query( session, """ - WITH t(x) AS (VALUES 1) - SELECT max(x), current_timestamp, current_date, current_time, localtimestamp, localtime - FROM t - """)) + WITH t(x) AS (VALUES 1) + SELECT max(x), current_timestamp, current_date, current_time, localtimestamp, localtime + FROM t + """)) .matches( """ - VALUES ( - 1, - TIMESTAMP '2024-03-12 12:24:0.000 Pacific/Apia', - DATE '2024-03-12', - TIME '12:24:0.000+13:00', - TIMESTAMP '2024-03-12 12:24:0.000', - TIME '12:24:0.000') - """); + VALUES ( + 1, + TIMESTAMP '2024-03-12 12:24:0.000 Pacific/Apia', + DATE '2024-03-12', + TIME '12:24:0.000+13:00', + TIMESTAMP '2024-03-12 12:24:0.000', + TIME '12:24:0.000') + """); } /** @@ -96,16 +97,47 @@ public void testAggregationMaskOnDictionaryInput() { assertThat(assertions.query( """ - SELECT - max(update_ts) FILTER (WHERE step_type = 'Rest') - FROM (VALUES - ('cell_id', 'Rest', TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw'), - ('cell_id', 'Rest', TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw') - ) AS t(cell_id, step_type, update_ts) - -- UNNEST to produce DictionaryBlock - CROSS JOIN UNNEST (sequence(1, 1000)) AS a(e) - GROUP BY cell_id - """)) + SELECT + max(update_ts) FILTER (WHERE step_type = 'Rest') + FROM (VALUES + ('cell_id', 'Rest', TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw'), + ('cell_id', 'Rest', TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw') + ) AS t(cell_id, step_type, update_ts) + -- UNNEST to produce DictionaryBlock + CROSS JOIN UNNEST (sequence(1, 1000)) AS a(e) + GROUP BY cell_id + """)) .matches("VALUES TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw'"); } + + /** + * Regression test for #21099 + */ + @Test + public void testLongDecimalPartialAggregation() + { + Session session = Session.builder(assertions.getDefaultSession()) + .setSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, "true") + .build(); + assertThat(assertions.query(session, """ + -- sum and avg likely use different aggregation state classes + SELECT l.i, sum(v), avg(v) + FROM (VALUES 1, 2, 2, 3, 3, 3) l(i) + JOIN (VALUES + (1, DECIMAL '12345678901234567890.1234567890'), + (1, DECIMAL '11111111111111111111.1234567890'), + (2, DECIMAL '22222222222222222222.1234567890'), + (3, DECIMAL '33333333333333333333.1234567890'), + (3, DECIMAL '10101010101010101010.0987654321'), + (7, DECIMAL '77777777777777777777.1234567890')) r(i, v) ON l.i = r.i + GROUP BY l.i + """)) + .matches(""" + SELECT i, CAST(s AS decimal(38, 10)), v + FROM (VALUES + (1, DECIMAL '23456790012345679001.2469135780', DECIMAL '11728395006172839500.6234567890'), + (2, DECIMAL '44444444444444444444.2469135780', DECIMAL '22222222222222222222.1234567890'), + (3, DECIMAL '130303030303030303029.6666666633', DECIMAL '21717171717171717171.6111111106')) t(i, s, v) + """); + } }