Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix long decimal partial aggregation below join #21083

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowAndLongSt
return;
}

index = block.getUnderlyingValuePosition(index);
block = block.getUnderlyingValueBlock();
VariableWidthBlock variableWidthBlock = (VariableWidthBlock) block;
Slice slice = variableWidthBlock.getRawSlice();
int sliceOffset = variableWidthBlock.getRawSliceOffset(index);
Expand All @@ -89,15 +91,15 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowAndLongSt
long count = 1;
switch (sliceLength) {
case 4 * Long.BYTES:
overflow = slice.getLong(sliceOffset + 24);
count = slice.getLong(sliceOffset + 16);
overflow = slice.getLong(sliceOffset + Long.BYTES * 3);
count = slice.getLong(sliceOffset + Long.BYTES * 2);
// fall through
case 2 * Long.BYTES:
high = slice.getLong(sliceOffset + 8);
high = slice.getLong(sliceOffset + Long.BYTES);
break;
case 3 * Long.BYTES:
overflow = slice.getLong(sliceOffset + 16);
count = slice.getLong(sliceOffset + 8);
overflow = slice.getLong(sliceOffset + Long.BYTES * 2);
count = slice.getLong(sliceOffset + Long.BYTES);
}

decimal[offset + 1] = low;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowState sta
return;
}

index = block.getUnderlyingValuePosition(index);
block = block.getUnderlyingValueBlock();
VariableWidthBlock variableWidthBlock = (VariableWidthBlock) block;
Slice slice = variableWidthBlock.getRawSlice();
int sliceOffset = variableWidthBlock.getRawSliceOffset(index);
Expand All @@ -75,12 +77,13 @@ public void deserialize(Block block, int index, LongDecimalWithOverflowState sta
long low = slice.getLong(sliceOffset);
long high = 0;
long overflow = 0;
if (sliceLength == 3 * Long.BYTES) {
overflow = slice.getLong(sliceOffset + 16);
high = slice.getLong(sliceOffset + 8);
}
else if (sliceLength == 2 * Long.BYTES) {
high = slice.getLong(sliceOffset + 8);

switch (sliceLength) {
case Long.BYTES * 3:
overflow = slice.getLong(sliceOffset + Long.BYTES * 2);
// fall through
case Long.BYTES * 2:
high = slice.getLong(sliceOffset + Long.BYTES);
}

decimal[offset + 1] = low;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@
*/
package io.trino.operator.aggregation.state;

import io.trino.spi.block.Block;
import io.trino.spi.block.DictionaryBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.block.VariableWidthBlock;
import io.trino.spi.block.VariableWidthBlockBuilder;
import org.junit.jupiter.api.Test;

import java.util.function.Function;

import static org.assertj.core.api.Assertions.assertThat;

public class TestLongDecimalWithOverflowAndLongStateSerializer
Expand Down Expand Up @@ -45,14 +50,19 @@ public void testSerde()
}

private void testSerde(long low, long high, long overflow, long count, int expectedLength)
{
testSerde(low, high, overflow, count, expectedLength, Function.identity());
}

private void testSerde(long low, long high, long overflow, long count, int expectedLength, Function<Block, Block> serializedModification)
{
LongDecimalWithOverflowAndLongState state = STATE_FACTORY.createSingleState();
state.getDecimalArray()[0] = high;
state.getDecimalArray()[1] = low;
state.setOverflow(overflow);
state.setLong(count);

LongDecimalWithOverflowAndLongState outState = roundTrip(state, expectedLength);
LongDecimalWithOverflowAndLongState outState = roundTrip(state, expectedLength, serializedModification);

assertThat(outState.getDecimalArray()[0]).isEqualTo(high);
assertThat(outState.getDecimalArray()[1]).isEqualTo(low);
Expand All @@ -71,7 +81,24 @@ public void testNullSerde()
assertThat(outState.getLong()).isEqualTo(0);
}

@Test
public void testDictionaryDeserialization()
{
testSerde(3, 0, 0, 1, 1, block -> DictionaryBlock.create(2, block, new int[] {0, 0}));
}

@Test
public void testRleDeserialization()
{
testSerde(3, 0, 0, 1, 1, block -> RunLengthEncodedBlock.create(block, 2));
}

private LongDecimalWithOverflowAndLongState roundTrip(LongDecimalWithOverflowAndLongState state, int expectedLength)
{
return roundTrip(state, expectedLength, Function.identity());
}

private LongDecimalWithOverflowAndLongState roundTrip(LongDecimalWithOverflowAndLongState state, int expectedLength, Function<Block, Block> serializedModification)
{
LongDecimalWithOverflowAndLongStateSerializer serializer = new LongDecimalWithOverflowAndLongStateSerializer();
VariableWidthBlockBuilder out = new VariableWidthBlockBuilder(null, 1, 0);
Expand All @@ -81,7 +108,7 @@ private LongDecimalWithOverflowAndLongState roundTrip(LongDecimalWithOverflowAnd
VariableWidthBlock serialized = out.buildValueBlock();
assertThat(serialized.getSliceLength(0)).isEqualTo(expectedLength * Long.BYTES);
LongDecimalWithOverflowAndLongState outState = STATE_FACTORY.createSingleState();
serializer.deserialize(serialized, 0, outState);
serializer.deserialize(serializedModification.apply(serialized), 0, outState);
return outState;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@
*/
package io.trino.operator.aggregation.state;

import io.trino.spi.block.Block;
import io.trino.spi.block.DictionaryBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.block.VariableWidthBlock;
import io.trino.spi.block.VariableWidthBlockBuilder;
import org.junit.jupiter.api.Test;

import java.util.function.Function;

import static org.assertj.core.api.Assertions.assertThat;

public class TestLongDecimalWithOverflowStateSerializer
Expand All @@ -37,14 +42,19 @@ public void testSerde()
}

private void testSerde(long low, long high, long overflow, int expectedLength)
{
testSerde(low, high, overflow, expectedLength, Function.identity());
}

private void testSerde(long low, long high, long overflow, int expectedLength, Function<Block, Block> serializedModification)
{
LongDecimalWithOverflowState state = STATE_FACTORY.createSingleState();
state.getDecimalArray()[0] = high;
state.getDecimalArray()[1] = low;
state.setOverflow(overflow);
state.setNotNull();

LongDecimalWithOverflowState outState = roundTrip(state, expectedLength);
LongDecimalWithOverflowState outState = roundTrip(state, expectedLength, serializedModification);

assertThat(outState.isNotNull()).isTrue();
assertThat(outState.getDecimalArray()[0]).isEqualTo(high);
Expand All @@ -63,7 +73,24 @@ public void testNullSerde()
assertThat(outState.isNotNull()).isFalse();
}

@Test
public void testDictionaryDeserialization()
{
testSerde(3, 0, 0, 1, block -> DictionaryBlock.create(2, block, new int[] {0, 0}));
}

@Test
public void testRleDeserialization()
{
testSerde(3, 0, 0, 1, block -> RunLengthEncodedBlock.create(block, 2));
}

private LongDecimalWithOverflowState roundTrip(LongDecimalWithOverflowState state, int expectedLength)
{
return roundTrip(state, expectedLength, Function.identity());
}

private LongDecimalWithOverflowState roundTrip(LongDecimalWithOverflowState state, int expectedLength, Function<Block, Block> serializedModification)
{
LongDecimalWithOverflowStateSerializer serializer = new LongDecimalWithOverflowStateSerializer();
VariableWidthBlockBuilder out = new VariableWidthBlockBuilder(null, 1, 0);
Expand All @@ -73,7 +100,7 @@ private LongDecimalWithOverflowState roundTrip(LongDecimalWithOverflowState stat
VariableWidthBlock serialized = out.buildValueBlock();
assertThat(serialized.getSliceLength(0)).isEqualTo(expectedLength * Long.BYTES);
LongDecimalWithOverflowState outState = STATE_FACTORY.createSingleState();
serializer.deserialize(serialized, 0, outState);
serializer.deserialize(serializedModification.apply(serialized), 0, outState);
return outState;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.time.ZoneId;
import java.time.ZonedDateTime;

import static io.trino.SystemSessionProperties.PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN;
import static io.trino.server.testing.TestingTrinoServer.SESSION_START_TIME_PROPERTY;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
Expand Down Expand Up @@ -72,20 +73,20 @@ void testSpecialDateTimeFunctionsInAggregation()
assertThat(assertions.query(
session,
"""
WITH t(x) AS (VALUES 1)
SELECT max(x), current_timestamp, current_date, current_time, localtimestamp, localtime
FROM t
"""))
WITH t(x) AS (VALUES 1)
SELECT max(x), current_timestamp, current_date, current_time, localtimestamp, localtime
FROM t
"""))
.matches(
"""
VALUES (
1,
TIMESTAMP '2024-03-12 12:24:0.000 Pacific/Apia',
DATE '2024-03-12',
TIME '12:24:0.000+13:00',
TIMESTAMP '2024-03-12 12:24:0.000',
TIME '12:24:0.000')
""");
VALUES (
1,
TIMESTAMP '2024-03-12 12:24:0.000 Pacific/Apia',
DATE '2024-03-12',
TIME '12:24:0.000+13:00',
TIMESTAMP '2024-03-12 12:24:0.000',
TIME '12:24:0.000')
""");
}

/**
Expand All @@ -96,16 +97,47 @@ public void testAggregationMaskOnDictionaryInput()
{
assertThat(assertions.query(
"""
SELECT
max(update_ts) FILTER (WHERE step_type = 'Rest')
FROM (VALUES
('cell_id', 'Rest', TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw'),
('cell_id', 'Rest', TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw')
) AS t(cell_id, step_type, update_ts)
-- UNNEST to produce DictionaryBlock
CROSS JOIN UNNEST (sequence(1, 1000)) AS a(e)
GROUP BY cell_id
"""))
SELECT
max(update_ts) FILTER (WHERE step_type = 'Rest')
FROM (VALUES
('cell_id', 'Rest', TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw'),
('cell_id', 'Rest', TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw')
) AS t(cell_id, step_type, update_ts)
-- UNNEST to produce DictionaryBlock
CROSS JOIN UNNEST (sequence(1, 1000)) AS a(e)
GROUP BY cell_id
"""))
.matches("VALUES TIMESTAMP '2005-09-10 13:31:00.123 Europe/Warsaw'");
}

/**
* Regression test for <a href="https://github.com/trinodb/trino/issues/21099">#21099</a>
*/
@Test
public void testLongDecimalPartialAggregation()
{
Session session = Session.builder(assertions.getDefaultSession())
.setSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, "true")
.build();
assertThat(assertions.query(session, """
-- sum and avg likely use different aggregation state classes
SELECT l.i, sum(v), avg(v)
FROM (VALUES 1, 2, 2, 3, 3, 3) l(i)
JOIN (VALUES
(1, DECIMAL '12345678901234567890.1234567890'),
(1, DECIMAL '11111111111111111111.1234567890'),
(2, DECIMAL '22222222222222222222.1234567890'),
(3, DECIMAL '33333333333333333333.1234567890'),
(3, DECIMAL '10101010101010101010.0987654321'),
(7, DECIMAL '77777777777777777777.1234567890')) r(i, v) ON l.i = r.i
GROUP BY l.i
"""))
.matches("""
SELECT i, CAST(s AS decimal(38, 10)), v
FROM (VALUES
(1, DECIMAL '23456790012345679001.2469135780', DECIMAL '11728395006172839500.6234567890'),
(2, DECIMAL '44444444444444444444.2469135780', DECIMAL '22222222222222222222.1234567890'),
(3, DECIMAL '130303030303030303029.6666666633', DECIMAL '21717171717171717171.6111111106')) t(i, s, v)
""");
}
}