Skip to content

Commit

Permalink
Improve test coverage of BigintGroupByHash and MultiChannelGroupByHash
Browse files Browse the repository at this point in the history
  • Loading branch information
sopel39 committed May 31, 2022
1 parent 5771df6 commit bbe9441
Showing 1 changed file with 81 additions and 44 deletions.
125 changes: 81 additions & 44 deletions core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.trino.spi.block.LongArrayBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.block.VariableWidthBlock;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.sql.gen.JoinCompiler;
Expand Down Expand Up @@ -73,10 +74,46 @@ public Object[][] dataType()
return new Object[][] {{VARCHAR}, {BIGINT}};
}

@Test
public void testAddPage()
// Supplies both GroupByHash implementations so each parameterized test
// runs once against BigintGroupByHash and once against MultiChannelGroupByHash.
@DataProvider
public Object[][] groupByHashType()
{
    Object[][] cases = {
            {GroupByHashType.BIGINT},
            {GroupByHashType.MULTI_CHANNEL},
    };
    return cases;
}

// Test fixture selecting which GroupByHash implementation a parameterized test
// exercises. Both variants group over a single BIGINT value channel (index 0)
// with a precomputed hash channel at index 1.
private enum GroupByHashType
{
    BIGINT, MULTI_CHANNEL;

    // Convenience overload: default expected size of 100 and the no-op
    // memory-update callback.
    public GroupByHash createGroupByHash()
    {
        return createGroupByHash(100, NOOP);
    }

    // Instantiates the hash implementation corresponding to this constant.
    public GroupByHash createGroupByHash(int expectedSize, UpdateMemory updateMemory)
    {
        if (this == BIGINT) {
            return new BigintGroupByHash(0, true, expectedSize, updateMemory);
        }
        if (this == MULTI_CHANNEL) {
            return new MultiChannelGroupByHash(
                    ImmutableList.of(BigintType.BIGINT),
                    new int[] {0},
                    Optional.of(1),
                    expectedSize,
                    true,
                    JOIN_COMPILER,
                    TYPE_OPERATOR_FACTORY,
                    updateMemory);
        }

        // Unreachable for the constants declared above; guards future additions.
        throw new UnsupportedOperationException();
    }
}

@Test(dataProvider = "groupByHashType")
public void testAddPage(GroupByHashType groupByHashType)
{
GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(BIGINT), new int[] {0}, Optional.of(1), 100, JOIN_COMPILER, TYPE_OPERATOR_FACTORY, NOOP);
GroupByHash groupByHash = groupByHashType.createGroupByHash();
for (int tries = 0; tries < 2; tries++) {
for (int value = 0; value < MAX_GROUP_ID; value++) {
Block block = BlockAssertions.createLongsBlock(value);
Expand All @@ -102,10 +139,10 @@ public void testAddPage()
}
}

@Test
public void testRunLengthEncodedBigintGroupByHash()
@Test(dataProvider = "groupByHashType")
public void testRunLengthEncodedInputPage(GroupByHashType groupByHashType)
{
GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(BIGINT), new int[] {0}, Optional.of(1), 100, JOIN_COMPILER, TYPE_OPERATOR_FACTORY, NOOP);
GroupByHash groupByHash = groupByHashType.createGroupByHash();
Block block = BlockAssertions.createLongsBlock(0L);
Block hashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(BIGINT), block);
Page page = new Page(
Expand All @@ -130,10 +167,10 @@ public void testRunLengthEncodedBigintGroupByHash()
assertTrue(children.get(0) instanceof RunLengthEncodedBlock);
}

@Test
public void testDictionaryBigintGroupByHash()
@Test(dataProvider = "groupByHashType")
public void testDictionaryInputPage(GroupByHashType groupByHashType)
{
GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(BIGINT), new int[] {0}, Optional.of(1), 100, JOIN_COMPILER, TYPE_OPERATOR_FACTORY, NOOP);
GroupByHash groupByHash = groupByHashType.createGroupByHash();
Block block = BlockAssertions.createLongsBlock(0L, 1L);
Block hashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(BIGINT), block);
int[] ids = new int[] {0, 0, 1, 1};
Expand All @@ -157,10 +194,10 @@ public void testDictionaryBigintGroupByHash()
assertEquals(groupIds.getGroupId(3), 1);
}

@Test
public void testNullGroup()
@Test(dataProvider = "groupByHashType")
public void testNullGroup(GroupByHashType groupByHashType)
{
GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(BIGINT), new int[] {0}, Optional.of(1), 100, JOIN_COMPILER, TYPE_OPERATOR_FACTORY, NOOP);
GroupByHash groupByHash = groupByHashType.createGroupByHash();

Block block = createLongsBlock((Long) null);
Block hashBlock = getHashBlock(ImmutableList.of(BIGINT), block);
Expand All @@ -179,10 +216,10 @@ public void testNullGroup()
assertFalse(groupByHash.contains(0, page, CONTAINS_CHANNELS));
}

@Test
public void testGetGroupIds()
@Test(dataProvider = "groupByHashType")
public void testGetGroupIds(GroupByHashType groupByHashType)
{
GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(BIGINT), new int[] {0}, Optional.of(1), 100, JOIN_COMPILER, TYPE_OPERATOR_FACTORY, NOOP);
GroupByHash groupByHash = groupByHashType.createGroupByHash();
for (int tries = 0; tries < 2; tries++) {
for (int value = 0; value < MAX_GROUP_ID; value++) {
Block block = BlockAssertions.createLongsBlock(value);
Expand All @@ -209,12 +246,12 @@ public void testTypes()
assertEquals(groupByHash.getTypes(), ImmutableList.of(VARCHAR, BIGINT));
}

@Test
public void testAppendTo()
@Test(dataProvider = "groupByHashType")
public void testAppendTo(GroupByHashType groupByHashType)
{
Block valuesBlock = BlockAssertions.createStringSequenceBlock(0, 100);
Block hashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(VARCHAR), valuesBlock);
GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(VARCHAR), new int[] {0}, Optional.of(1), 100, JOIN_COMPILER, TYPE_OPERATOR_FACTORY, NOOP);
Block valuesBlock = BlockAssertions.createLongSequenceBlock(0, 100);
Block hashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(BIGINT), valuesBlock);
GroupByHash groupByHash = groupByHashType.createGroupByHash();

Work<GroupByIdBlock> work = groupByHash.getGroupIds(new Page(valuesBlock, hashBlock));
work.process();
Expand All @@ -235,12 +272,12 @@ public void testAppendTo()
assertEquals(page.getBlock(i).getPositionCount(), 100);
}
assertEquals(page.getPositionCount(), 100);
BlockAssertions.assertBlockEquals(VARCHAR, page.getBlock(0), valuesBlock);
BlockAssertions.assertBlockEquals(BIGINT, page.getBlock(0), valuesBlock);
BlockAssertions.assertBlockEquals(BIGINT, page.getBlock(1), hashBlock);
}

@Test
public void testAppendToMultipleTuplesPerGroup()
@Test(dataProvider = "groupByHashType")
public void testAppendToMultipleTuplesPerGroup(GroupByHashType groupByHashType)
{
List<Long> values = new ArrayList<>();
for (long i = 0; i < 100; i++) {
Expand All @@ -249,7 +286,7 @@ public void testAppendToMultipleTuplesPerGroup()
Block valuesBlock = BlockAssertions.createLongsBlock(values);
Block hashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(BIGINT), valuesBlock);

GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(BIGINT), new int[] {0}, Optional.of(1), 100, JOIN_COMPILER, TYPE_OPERATOR_FACTORY, NOOP);
GroupByHash groupByHash = groupByHashType.createGroupByHash();
groupByHash.getGroupIds(new Page(valuesBlock, hashBlock)).process();
assertEquals(groupByHash.getGroupCount(), 50);

Expand All @@ -263,20 +300,20 @@ public void testAppendToMultipleTuplesPerGroup()
BlockAssertions.assertBlockEquals(BIGINT, outputPage.getBlock(0), BlockAssertions.createLongSequenceBlock(0, 50));
}

@Test
public void testContains()
@Test(dataProvider = "groupByHashType")
public void testContains(GroupByHashType groupByHashType)
{
Block valuesBlock = BlockAssertions.createDoubleSequenceBlock(0, 10);
Block hashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(DOUBLE), valuesBlock);
GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(DOUBLE), new int[] {0}, Optional.of(1), 100, JOIN_COMPILER, TYPE_OPERATOR_FACTORY, NOOP);
Block valuesBlock = BlockAssertions.createLongSequenceBlock(0, 10);
Block hashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(BIGINT), valuesBlock);
GroupByHash groupByHash = groupByHashType.createGroupByHash();
groupByHash.getGroupIds(new Page(valuesBlock, hashBlock)).process();

Block testBlock = BlockAssertions.createDoublesBlock((double) 3);
Block testHashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(DOUBLE), testBlock);
Block testBlock = BlockAssertions.createLongsBlock(3);
Block testHashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(BIGINT), testBlock);
assertTrue(groupByHash.contains(0, new Page(testBlock, testHashBlock), CONTAINS_CHANNELS));

testBlock = BlockAssertions.createDoublesBlock(11.0);
testHashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(DOUBLE), testBlock);
testBlock = BlockAssertions.createLongsBlock(11);
testHashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(BIGINT), testBlock);
assertFalse(groupByHash.contains(0, new Page(testBlock, testHashBlock), CONTAINS_CHANNELS));
}

Expand All @@ -296,15 +333,15 @@ public void testContainsMultipleColumns()
assertTrue(groupByHash.contains(0, new Page(testValuesBlock, testStringValuesBlock, testHashBlock), hashChannels));
}

@Test
public void testForceRehash()
@Test(dataProvider = "groupByHashType")
public void testForceRehash(GroupByHashType groupByHashType)
{
// Create a page with positionCount >> expected size of groupByHash
Block valuesBlock = BlockAssertions.createStringSequenceBlock(0, 100);
Block hashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(VARCHAR), valuesBlock);
Block valuesBlock = BlockAssertions.createLongSequenceBlock(0, 100);
Block hashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(BIGINT), valuesBlock);

// Create group by hash with extremely small size
GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(VARCHAR), new int[] {0}, Optional.of(1), 4, JOIN_COMPILER, TYPE_OPERATOR_FACTORY, NOOP);
GroupByHash groupByHash = groupByHashType.createGroupByHash(4, NOOP);
groupByHash.getGroupIds(new Page(valuesBlock, hashBlock)).process();

// Ensure that all groups are present in group by hash
Expand Down Expand Up @@ -431,16 +468,16 @@ else if (type == BIGINT) {
assertEquals(currentQuota.get() / 3, yields);
}

@Test
public void testMemoryReservationYieldWithDictionary()
@Test(dataProvider = "groupByHashType")
public void testMemoryReservationYieldWithDictionary(GroupByHashType groupByHashType)
{
// Create a page with positionCount >> expected size of groupByHash
int dictionaryLength = 1_000;
int length = 2_000_000;
int[] ids = IntStream.range(0, dictionaryLength).toArray();
DictionaryId dictionaryId = randomDictionaryId();
Block valuesBlock = new DictionaryBlock(dictionaryLength, createStringSequenceBlock(0, length), ids, dictionaryId);
Block hashBlock = new DictionaryBlock(dictionaryLength, getHashBlock(ImmutableList.of(VARCHAR), valuesBlock), ids, dictionaryId);
Block valuesBlock = new DictionaryBlock(dictionaryLength, createLongSequenceBlock(0, length), ids, dictionaryId);
Block hashBlock = new DictionaryBlock(dictionaryLength, getHashBlock(ImmutableList.of(BIGINT), valuesBlock), ids, dictionaryId);
Page page = new Page(valuesBlock, hashBlock);
AtomicInteger currentQuota = new AtomicInteger(0);
AtomicInteger allowedQuota = new AtomicInteger(3);
Expand All @@ -454,7 +491,7 @@ public void testMemoryReservationYieldWithDictionary()
int yields = 0;

// test addPage
GroupByHash groupByHash = createGroupByHash(ImmutableList.of(VARCHAR), new int[] {0}, Optional.of(1), 1, true, JOIN_COMPILER, TYPE_OPERATOR_FACTORY, updateMemory);
GroupByHash groupByHash = groupByHashType.createGroupByHash(1, updateMemory);

boolean finish = false;
Work<?> addPageWork = groupByHash.addPage(page);
Expand Down Expand Up @@ -482,7 +519,7 @@ public void testMemoryReservationYieldWithDictionary()
currentQuota.set(0);
allowedQuota.set(3);
yields = 0;
groupByHash = createGroupByHash(ImmutableList.of(VARCHAR), new int[] {0}, Optional.of(1), 1, true, JOIN_COMPILER, TYPE_OPERATOR_FACTORY, updateMemory);
groupByHash = groupByHashType.createGroupByHash(1, updateMemory);

finish = false;
Work<GroupByIdBlock> getGroupIdsWork = groupByHash.getGroupIds(page);
Expand Down

0 comments on commit bbe9441

Please sign in to comment.