From 13b57645621efc6ba3e4cce81bb11c369eec813d Mon Sep 17 00:00:00 2001 From: skrzypo987 Date: Thu, 9 Mar 2023 10:43:06 +0100 Subject: [PATCH 1/2] Move static methods to the end of the class --- .../trino/operator/join/BigintPagesHash.java | 28 +++++----- .../trino/operator/join/DefaultPagesHash.java | 28 +++++----- .../trino/operator/join/JoinHashSupplier.java | 52 +++++++++---------- 3 files changed, 54 insertions(+), 54 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java b/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java index eda1f49cc5fa..b4c36d600c4a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java @@ -58,20 +58,6 @@ public final class BigintPagesHash private final long[] values; private final long size; - public static long getEstimatedRetainedSizeInBytes( - int positionCount, - HashArraySizeSupplier hashArraySizeSupplier, - LongArrayList addresses, - List> channels, - long blocksSizeInBytes) - { - return sizeOf(addresses.elements()) + - (channels.size() > 0 ? sizeOf(channels.get(0).elements()) * channels.size() : 0) + - blocksSizeInBytes + - sizeOfIntArray(hashArraySizeSupplier.getHashArraySize(positionCount)) + - sizeOfLongArray(positionCount); - } - public BigintPagesHash( LongArrayList addresses, PagesHashStrategy pagesHashStrategy, @@ -267,4 +253,18 @@ private boolean isPositionNull(int position) return joinChannelBlocks.get(blockIndex).isNull(blockPosition); } + + public static long getEstimatedRetainedSizeInBytes( + int positionCount, + HashArraySizeSupplier hashArraySizeSupplier, + LongArrayList addresses, + List> channels, + long blocksSizeInBytes) + { + return sizeOf(addresses.elements()) + + (channels.size() > 0 ? sizeOf(channels.get(0).elements()) * channels.size() : 0) + + blocksSizeInBytes + + sizeOfIntArray(hashArraySizeSupplier.getHashArraySize(positionCount)) + + sizeOfLongArray(positionCount); + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java b/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java index 92e562376e71..4acdbdca673b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java @@ -58,20 +58,6 @@ public final class DefaultPagesHash // and there is no performance gain from storing full hashes private final byte[] positionToHashes; - public static long getEstimatedRetainedSizeInBytes( - int positionCount, - HashArraySizeSupplier hashArraySizeSupplier, - LongArrayList addresses, - List> channels, - long blocksSizeInBytes) - { - return sizeOf(addresses.elements()) + - (channels.size() > 0 ? sizeOf(channels.get(0).elements()) * channels.size() : 0) + - blocksSizeInBytes + - sizeOfIntArray(hashArraySizeSupplier.getHashArraySize(positionCount)) + - sizeOfByteArray(positionCount); - } - public DefaultPagesHash( LongArrayList addresses, PagesHashStrategy pagesHashStrategy, @@ -300,4 +286,18 @@ private boolean positionEqualsPositionIgnoreNulls(int leftPosition, int rightPos return pagesHashStrategy.positionEqualsPositionIgnoreNulls(leftBlockIndex, leftBlockPosition, rightBlockIndex, rightBlockPosition); } + + public static long getEstimatedRetainedSizeInBytes( + int positionCount, + HashArraySizeSupplier hashArraySizeSupplier, + LongArrayList addresses, + List> channels, + long blocksSizeInBytes) + { + return sizeOf(addresses.elements()) + + (channels.size() > 0 ? sizeOf(channels.get(0).elements()) * channels.size() : 0) + + blocksSizeInBytes + + sizeOfIntArray(hashArraySizeSupplier.getHashArraySize(positionCount)) + + sizeOfByteArray(positionCount); + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/join/JoinHashSupplier.java b/core/trino-main/src/main/java/io/trino/operator/join/JoinHashSupplier.java index 19e60084f540..7e43669be489 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/JoinHashSupplier.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/JoinHashSupplier.java @@ -53,32 +53,6 @@ public class JoinHashSupplier private final Optional filterFunctionFactory; private final List searchFunctionFactories; - public static long getEstimatedRetainedSizeInBytes( - int positionCount, - LongArrayList addresses, - List> channels, - long blocksSizeInBytes, - Optional sortChannel, - OptionalInt singleBigintJoinChannel, - HashArraySizeSupplier hashArraySizeSupplier) - { - long result = 0; - if (sortChannel.isPresent()) { - result += SortedPositionLinks.getEstimatedRetainedSizeInBytes(positionCount); - } - else { - result += ArrayPositionLinks.getEstimatedRetainedSizeInBytes(positionCount); - } - result += getPageInstancesRetainedSizeInBytes(channels); - if (singleBigintJoinChannel.isPresent() && addresses.size() <= JOIN_POSITIONS_ARRAY_CUTOFF) { - result += BigintPagesHash.getEstimatedRetainedSizeInBytes(positionCount, hashArraySizeSupplier, addresses, channels, blocksSizeInBytes); - } - else { - result += DefaultPagesHash.getEstimatedRetainedSizeInBytes(positionCount, hashArraySizeSupplier, addresses, channels, blocksSizeInBytes); - } - return result; - } - public JoinHashSupplier( Session session, PagesHashStrategy pagesHashStrategy, @@ -146,6 +120,32 @@ public JoinHash get() pageInstancesRetainedSizeInBytes); } + public static long getEstimatedRetainedSizeInBytes( + int positionCount, + LongArrayList addresses, + List> channels, + long blocksSizeInBytes, + Optional sortChannel, + OptionalInt singleBigintJoinChannel, + HashArraySizeSupplier hashArraySizeSupplier) + { + long result = 0; + if (sortChannel.isPresent()) { + result += SortedPositionLinks.getEstimatedRetainedSizeInBytes(positionCount); + } + else { + result += ArrayPositionLinks.getEstimatedRetainedSizeInBytes(positionCount); + } + result += getPageInstancesRetainedSizeInBytes(channels); + if (singleBigintJoinChannel.isPresent() && addresses.size() <= JOIN_POSITIONS_ARRAY_CUTOFF) { + result += BigintPagesHash.getEstimatedRetainedSizeInBytes(positionCount, hashArraySizeSupplier, addresses, channels, blocksSizeInBytes); + } + else { + result += DefaultPagesHash.getEstimatedRetainedSizeInBytes(positionCount, hashArraySizeSupplier, addresses, channels, blocksSizeInBytes); + } + return result; + } + private static long getPageInstancesRetainedSizeInBytes(List> channels) { if (channels.isEmpty()) { From e2b8a2950ea4702b376b86222105e99fc7e21cce Mon Sep 17 00:00:00 2001 From: skrzypo987 Date: Thu, 9 Mar 2023 11:18:59 +0100 Subject: [PATCH 2/2] Extract methods for better readability JFR profiling will also be easier when code is split into more methods --- .../trino/operator/join/BigintPagesHash.java | 131 ++++++++++------ .../trino/operator/join/DefaultPagesHash.java | 145 ++++++++++++------ 2 files changed, 176 insertions(+), 100 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java b/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java index b4c36d600c4a..e7bf4a15cbbc 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java @@ -91,42 +91,52 @@ public BigintPagesHash( int stepEndPosition = Math.min((step + 1) * positionsInStep, addresses.size()); int stepSize = stepEndPosition - stepBeginPosition; - // index pages - for (int batchIndex = 0; batchIndex < stepSize; batchIndex++) { - int addressIndex = batchIndex + stepBeginPosition; - if (isPositionNull(addressIndex)) { - continue; - } + indexPages(addresses, positionLinks, stepBeginPosition, stepSize); + } - long address = addresses.getLong(addressIndex); - int blockIndex = decodeSliceIndex(address); - int blockPosition = decodePosition(address); - long value = joinChannelBlocks.get(blockIndex).getLong(blockPosition, 0); - - int pos = getHashPosition(value, mask); - - // look for an empty slot or a slot containing this key - while (keys[pos] != -1) { - int currentKey = keys[pos]; - if (value == values[currentKey]) { - // found a slot for this key - // link the new key position to the current key position - addressIndex = positionLinks.link(addressIndex, currentKey); - - // key[pos] updated outside of this loop - break; - } - // increment position and mask to handler wrap around - pos = (pos + 1) & mask; - } + size = sizeOf(addresses.elements()) + pagesHashStrategy.getSizeInBytes() + + sizeOf(keys) + sizeOf(values); + } - keys[pos] = addressIndex; - values[addressIndex] = value; + private void indexPages(LongArrayList addresses, PositionLinks.FactoryBuilder positionLinks, int stepBeginPosition, int stepSize) + { + // index pages + for (int batchIndex = 0; batchIndex < stepSize; batchIndex++) { + int addressIndex = batchIndex + stepBeginPosition; + if (isPositionNull(addressIndex)) { + continue; } + + long address = addresses.getLong(addressIndex); + int blockIndex = decodeSliceIndex(address); + int blockPosition = decodePosition(address); + long value = joinChannelBlocks.get(blockIndex).getLong(blockPosition, 0); + + int pos = getHashPosition(value, mask); + + insertValue(positionLinks, addressIndex, value, pos); } + } - size = sizeOf(addresses.elements()) + pagesHashStrategy.getSizeInBytes() + - sizeOf(keys) + sizeOf(values); + private void insertValue(PositionLinks.FactoryBuilder positionLinks, int addressIndex, long value, int pos) + { + // look for an empty slot or a slot containing this key + while (keys[pos] != -1) { + int currentKey = keys[pos]; + if (value == values[currentKey]) { + // found a slot for this key + // link the new key position to the current key position + addressIndex = positionLinks.link(addressIndex, currentKey); + + // key[pos] updated outside of this loop + break; + } + // increment position and mask to handler wrap around + pos = (pos + 1) & mask; + } + + keys[pos] = addressIndex; + values[addressIndex] = value; } @Override @@ -178,10 +188,7 @@ public int[] getAddressIndex(int[] positions, Page hashChannelsPage) long[] incomingValues = new long[positionCount]; int[] hashPositions = new int[positionCount]; - for (int i = 0; i < positionCount; i++) { - incomingValues[i] = hashChannelsPage.getBlock(0).getLong(positions[i], 0); - hashPositions[i] = getHashPosition(incomingValues[i], mask); - } + extractAndHashValues(positions, hashChannelsPage, positionCount, incomingValues, hashPositions); int[] found = new int[positionCount]; int foundCount = 0; @@ -191,9 +198,7 @@ public int[] getAddressIndex(int[] positions, Page hashChannelsPage) // Search for positions in the hash array. This is the most CPU-consuming part as // it relies on random memory accesses - for (int i = 0; i < positionCount; i++) { - foundKeys[i] = keys[hashPositions[i]]; - } + findPositions(positionCount, hashPositions, foundKeys); // Found positions are put into `found` array for (int i = 0; i < positionCount; i++) { if (foundKeys[i] != -1) { @@ -203,21 +208,18 @@ public int[] getAddressIndex(int[] positions, Page hashChannelsPage) // At this step we determine if the found keys were indeed the proper ones or it is a hash collision. // The result array is updated for the found ones, while the collisions land into `remaining` array. + int remainingCount = checkFoundPositions(incomingValues, found, foundCount, result, foundKeys); int[] remaining = found; // Rename for readability - int remainingCount = 0; - - for (int i = 0; i < foundCount; i++) { - int index = found[i]; - if (values[foundKeys[index]] == incomingValues[index]) { - result[index] = foundKeys[index]; - } - else { - remaining[remainingCount++] = index; - } - } // At this point for any reasoable load factor of a hash array (< .75), there is no more than // 10 - 15% of positions left. We search for them in a sequential order and update the result array. + findRemainingPositions(incomingValues, hashPositions, result, remaining, remainingCount); + + return result; + } + + private void findRemainingPositions(long[] incomingValues, int[] hashPositions, int[] result, int[] remaining, int remainingCount) + { for (int i = 0; i < remainingCount; i++) { int index = remaining[i]; int position = (hashPositions[index] + 1) & mask; // hashPositions[index] position has already been checked @@ -231,8 +233,37 @@ public int[] getAddressIndex(int[] positions, Page hashChannelsPage) position = (position + 1) & mask; } } + } - return result; + private int checkFoundPositions(long[] incomingValues, int[] found, int foundCount, int[] result, int[] foundKeys) + { + int[] remaining = found; // Rename for readability + int remainingCount = 0; + for (int i = 0; i < foundCount; i++) { + int index = found[i]; + if (values[foundKeys[index]] == incomingValues[index]) { + result[index] = foundKeys[index]; + } + else { + remaining[remainingCount++] = index; + } + } + return remainingCount; + } + + private void findPositions(int positionCount, int[] hashPositions, int[] foundKeys) + { + for (int i = 0; i < positionCount; i++) { + foundKeys[i] = keys[hashPositions[i]]; + } + } + + private void extractAndHashValues(int[] positions, Page hashChannelsPage, int positionCount, long[] incomingValues, int[] hashPositions) + { + for (int i = 0; i < positionCount; i++) { + incomingValues[i] = hashChannelsPage.getBlock(0).getLong(positions[i], 0); + hashPositions[i] = getHashPosition(incomingValues[i], mask); + } } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java b/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java index 4acdbdca673b..fbac2137741f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java @@ -88,44 +88,59 @@ public DefaultPagesHash( // First extract all hashes from blocks to native array. // Somehow having this as a separate loop is much faster compared // to extracting hashes on the fly in the loop below. - for (int batchIndex = 0; batchIndex < stepSize; batchIndex++) { - int addressIndex = batchIndex + stepBeginPosition; - long hash = readHashPosition(addressIndex); - positionToFullHashes[batchIndex] = hash; - positionToHashes[addressIndex] = (byte) hash; - } + extractHashes(positionToFullHashes, stepBeginPosition, stepSize); // index pages - for (int position = 0; position < stepSize; position++) { - int realPosition = position + stepBeginPosition; - if (isPositionNull(realPosition)) { - continue; - } + indexPages(positionLinks, positionToFullHashes, stepBeginPosition, stepSize); + } - long hash = positionToFullHashes[position]; - int pos = getHashPosition(hash, mask); - - // look for an empty slot or a slot containing this key - while (keys[pos] != -1) { - int currentKey = keys[pos]; - if (((byte) hash) == positionToHashes[currentKey] && positionEqualsPositionIgnoreNulls(currentKey, realPosition)) { - // found a slot for this key - // link the new key position to the current key position - realPosition = positionLinks.link(realPosition, currentKey); - - // key[pos] updated outside of this loop - break; - } - // increment position and mask to handler wrap around - pos = (pos + 1) & mask; - } + size = sizeOf(addresses.elements()) + pagesHashStrategy.getSizeInBytes() + + sizeOf(keys) + sizeOf(positionToHashes); + } - keys[pos] = realPosition; + private void extractHashes(long[] positionToFullHashes, int stepBeginPosition, int stepSize) + { + for (int batchIndex = 0; batchIndex < stepSize; batchIndex++) { + int addressIndex = batchIndex + stepBeginPosition; + long hash = readHashPosition(addressIndex); + positionToFullHashes[batchIndex] = hash; + positionToHashes[addressIndex] = (byte) hash; + } + } + + private void indexPages(PositionLinks.FactoryBuilder positionLinks, long[] positionToFullHashes, int stepBeginPosition, int stepSize) + { + for (int position = 0; position < stepSize; position++) { + int realPosition = position + stepBeginPosition; + if (isPositionNull(realPosition)) { + continue; } + + long hash = positionToFullHashes[position]; + int pos = getHashPosition(hash, mask); + + insertValue(positionLinks, realPosition, (byte) hash, pos); } + } - size = sizeOf(addresses.elements()) + pagesHashStrategy.getSizeInBytes() + - sizeOf(keys) + sizeOf(positionToHashes); + private void insertValue(PositionLinks.FactoryBuilder positionLinks, int realPosition, byte hash, int pos) + { + // look for an empty slot or a slot containing this key + while (keys[pos] != -1) { + int currentKey = keys[pos]; + if (hash == positionToHashes[currentKey] && positionEqualsPositionIgnoreNulls(currentKey, realPosition)) { + // found a slot for this key + // link the new key position to the current key position + realPosition = positionLinks.link(realPosition, currentKey); + + // key[pos] updated outside of this loop + break; + } + // increment position and mask to handler wrap around + pos = (pos + 1) & mask; + } + + keys[pos] = realPosition; } @Override @@ -176,11 +191,7 @@ public int[] getAddressIndex(int[] positions, Page hashChannelsPage) public int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawHashes) { int positionCount = positions.length; - int[] hashPositions = new int[positionCount]; - - for (int i = 0; i < positionCount; i++) { - hashPositions[i] = getHashPosition(rawHashes[positions[i]], mask); - } + int[] hashPositions = calculateHashPositions(positions, rawHashes, positionCount); int[] found = new int[positionCount]; int foundCount = 0; @@ -190,9 +201,7 @@ public int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawH // Search for positions in the hash array. This is the most CPU-consuming part as // it relies on random memory accesses - for (int i = 0; i < positionCount; i++) { - foundKeys[i] = keys[hashPositions[i]]; - } + findPositions(positionCount, hashPositions, foundKeys); // Found positions are put into `found` array for (int i = 0; i < positionCount; i++) { if (foundKeys[i] != -1) { @@ -202,20 +211,18 @@ public int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawH // At this step we determine if the found keys were indeed the proper ones or it is a hash collision. // The result array is updated for the found ones, while the collisions land into `remaining` array. + int remainingCount = checkFoundPositions(positions, hashChannelsPage, rawHashes, found, foundCount, result, foundKeys); int[] remaining = found; // Rename for readability - int remainingCount = 0; - for (int i = 0; i < foundCount; i++) { - int index = found[i]; - if (positionEqualsCurrentRowIgnoreNulls(foundKeys[index], (byte) rawHashes[positions[index]], positions[index], hashChannelsPage)) { - result[index] = foundKeys[index]; - } - else { - remaining[remainingCount++] = index; - } - } // At this point for any reasoable load factor of a hash array (< .75), there is no more than // 10 - 15% of positions left. We search for them in a sequential order and update the result array. + findRemainingPositions(positions, hashChannelsPage, rawHashes, hashPositions, result, remainingCount, remaining); + + return result; + } + + private void findRemainingPositions(int[] positions, Page hashChannelsPage, long[] rawHashes, int[] hashPositions, int[] result, int remainingCount, int[] remaining) + { for (int i = 0; i < remainingCount; i++) { int index = remaining[i]; int position = (hashPositions[index] + 1) & mask; // hashPositions[index] position has already been checked @@ -229,8 +236,46 @@ public int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawH position = (position + 1) & mask; } } + } - return result; + private int checkFoundPositions( + int[] positions, + Page hashChannelsPage, + long[] rawHashes, + int[] found, + int foundCount, + int[] result, + int[] foundKeys) + { + int[] remaining = found; // Rename for readability + int remainingCount = 0; + for (int i = 0; i < foundCount; i++) { + int index = found[i]; + if (positionEqualsCurrentRowIgnoreNulls(foundKeys[index], (byte) rawHashes[positions[index]], positions[index], hashChannelsPage)) { + result[index] = foundKeys[index]; + } + else { + remaining[remainingCount++] = index; + } + } + return remainingCount; + } + + private void findPositions(int positionCount, int[] hashPositions, int[] foundKeys) + { + for (int i = 0; i < positionCount; i++) { + foundKeys[i] = keys[hashPositions[i]]; + } + } + + private int[] calculateHashPositions(int[] positions, long[] rawHashes, int positionCount) + { + int[] hashPositions = new int[positionCount]; + + for (int i = 0; i < positionCount; i++) { + hashPositions[i] = getHashPosition(rawHashes[positions[i]], mask); + } + return hashPositions; } @Override