Optimize validity buffer concat. #2626
base: branch-25.02
cc @jlowe I ran some benchmarks and didn't notice much performance improvement.
Signed-off-by: liurenjie1024 <[email protected]>
04d78bd to 3cff316
build
// Extract appendCount bits from srcByte, starting from curSrcBitIdx
byte mask = (byte) (((1 << appendCount) - 1) & 0xFF);
srcByte = (byte) ((srcByte >>> curSrcBitIdx) & mask);
int totalRowCount = toIntExact(sliceInfo.getRowCount() + sliceInfo.getValidityBufferInfo().getBeginBit());
This variable's name makes the code confusing to read. It's not a total row count, as the name implies; it's the end index (the ending row), IIUC.
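To illustrate the naming point, here is a small sketch, assuming the copy loop starts at the slice's begin bit (which `while (curSrcIdx < totalRowCount)` suggests). The helper `endBitIndex` and the name `srcEndBitIdx` are hypothetical; the sum of row count and begin bit acts as an exclusive end index into the source validity buffer, while the number of bits actually visited equals the row count.

```java
class ValidityRange {
    // Hypothetical helper: rowCount and beginBit stand in for sliceInfo.getRowCount() and
    // sliceInfo.getValidityBufferInfo().getBeginBit() from the diff.
    static int endBitIndex(long rowCount, long beginBit) {
        // The slice covers source bits [beginBit, beginBit + rowCount), so the sum is an
        // exclusive end index into the source validity buffer, not a number of rows.
        return Math.toIntExact(rowCount + beginBit);
    }

    public static void main(String[] args) {
        long rowCount = 10;
        long beginBit = 3;
        int srcEndBitIdx = endBitIndex(rowCount, beginBit); // a clearer name than totalRowCount
        int visited = 0;
        // Assumes the copy loop starts at beginBit, mirroring `while (curSrcIdx < totalRowCount)`.
        for (int curSrcIdx = (int) beginBit; curSrcIdx < srcEndBitIdx; curSrcIdx++) {
            visited++;
        }
        System.out.println("end index " + srcEndBitIdx + ", rows visited " + visited);
    }
}
```

With a begin bit of 3 and 10 rows, the bound is 13 while only 10 bits are visited.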
// Sets the bits in destination buffer starting from curDestBitIdx to 0
byte destByte = dest.getByte(curDestByteIdx);
destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1));
if (dest.getLength() >= (curDestOffset + Integer.BYTES)) {
Conditionals should not be in the body of the while loop. The while loop should be as simple as possible, since that's expected to be the hotspot. IMO the code should be structured into three parts similar to the following:
if (curSrcIdx % 8 != 0) {
// read an int from the buffer
// mask off the unused bits
// count the bits
// shift and store the bits
}
while (whole_ints_left_in_buffer) {
// read int from buffer
// count bits
// shift and store the bits
}
if (leftover bits) {
// read an int from the buffer (leverage padded buffer here)
// mask off the unused bits
// count the bits
// shift and store the bits
}
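The three-part structure above can be sketched in Java roughly as follows. This is a minimal sketch, not the PR's code: the buffers are modeled as little-endian `java.nio.ByteBuffer`s, and the names `appendValidity`, `readWord`, and `writeBits` are hypothetical. It assumes the source is padded so that an int read at an int-aligned offset stays in bounds, and that the destination is padded so an 8-byte read-modify-write at an int-aligned offset does too. Because it aligns the source to 32-bit words, the head check uses `% Integer.SIZE` rather than `% 8`. "Count the bits" is interpreted as counting set (valid) bits, from which a null count can be derived.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

class ValidityConcat {
    // Mask with the low n bits set (0 <= n <= 32), computed in a long so n == 32 works.
    static long lowMask(int n) {
        return (1L << n) - 1;
    }

    // Reads the 32-bit little-endian word that contains bit srcBit.
    static int readWord(ByteBuffer src, int srcBit) {
        return src.getInt((srcBit / Integer.SIZE) * Integer.BYTES);
    }

    // Writes the low n bits of `bits` into dest starting at bit destBit, preserving all other
    // bits. Uses an 8-byte read-modify-write at the int-aligned offset, so dest must be padded.
    static void writeBits(ByteBuffer dest, int destBit, int bits, int n) {
        int off = (destBit / Integer.SIZE) * Integer.BYTES;
        int shift = destBit % Integer.SIZE;
        long mask = lowMask(n) << shift; // at most bits 0..62, so it fits in a long
        long word = dest.getLong(off);
        dest.putLong(off, (word & ~mask) | ((Integer.toUnsignedLong(bits) << shift) & mask));
    }

    // Appends `count` validity bits from src (starting at srcBit) to dest (starting at destBit).
    // Returns the number of set (valid) bits copied.
    static int appendValidity(ByteBuffer src, int srcBit, int count, ByteBuffer dest, int destBit) {
        int valid = 0;
        int remaining = count;

        // 1. Head: the source is not word-aligned; take bits up to the next 32-bit boundary.
        int headShift = srcBit % Integer.SIZE;
        if (headShift != 0 && remaining > 0) {
            int n = Math.min(Integer.SIZE - headShift, remaining);
            int bits = (int) ((Integer.toUnsignedLong(readWord(src, srcBit)) >>> headShift) & lowMask(n));
            valid += Integer.bitCount(bits);
            writeBits(dest, destBit, bits, n);
            srcBit += n;
            destBit += n;
            remaining -= n;
        }

        // 2. Body: whole 32-bit source words, no branches inside the loop.
        while (remaining >= Integer.SIZE) {
            int bits = readWord(src, srcBit);
            valid += Integer.bitCount(bits);
            writeBits(dest, destBit, bits, Integer.SIZE);
            srcBit += Integer.SIZE;
            destBit += Integer.SIZE;
            remaining -= Integer.SIZE;
        }

        // 3. Tail: fewer than 32 bits left; the padded source makes the full int read safe.
        if (remaining > 0) {
            int bits = (int) (Integer.toUnsignedLong(readWord(src, srcBit)) & lowMask(remaining));
            valid += Integer.bitCount(bits);
            writeBits(dest, destBit, bits, remaining);
        }
        return valid;
    }

    // Bit-at-a-time reference used only to check the result.
    static boolean getBit(ByteBuffer buf, int bit) {
        return ((buf.get(bit / Byte.SIZE) >>> (bit % Byte.SIZE)) & 1) != 0;
    }

    public static void main(String[] args) {
        ByteBuffer src = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN);
        for (int i = 0; i < 16; i++) {
            src.put(i, (byte) (i * 37 + 11));
        }
        ByteBuffer dest = ByteBuffer.allocate(24).order(ByteOrder.LITTLE_ENDIAN);
        // Concatenate two slices: source bits [3, 43) followed by source bits [50, 100).
        int valid = appendValidity(src, 3, 40, dest, 0) + appendValidity(src, 50, 50, dest, 40);
        int expectedValid = 0;
        for (int i = 0; i < 40; i++) {
            if (getBit(dest, i) != getBit(src, 3 + i)) throw new AssertionError("bit " + i);
            expectedValid += getBit(src, 3 + i) ? 1 : 0;
        }
        for (int i = 0; i < 50; i++) {
            if (getBit(dest, 40 + i) != getBit(src, 50 + i)) throw new AssertionError("bit " + (40 + i));
            expectedValid += getBit(src, 50 + i) ? 1 : 0;
        }
        if (valid != expectedValid) throw new AssertionError("valid count");
        System.out.println("copied 90 bits, " + valid + " valid");
    }
}
```

Keeping the masking and bit counting in straight-line helpers leaves the body loop branch-free, which is the property the review asks for in the hotspot.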
byte destByte = dest.getByte(curDestByteIdx);
destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1));
if (dest.getLength() >= (curDestOffset + Integer.BYTES)) {
// We have enough room to get an int
Don't we always have enough space to get an integer from the destination buffer because it's at least 4-byte padded?
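The padding argument can be checked with a short sketch, assuming the buffer length is rounded up to a multiple of 4 bytes and that reads happen at int-aligned offsets, as `curDestOffset = (curDestIdx / 32) * Integer.BYTES` computes them in the diff. `padTo4` and `lastWordReadable` are hypothetical helpers. Under those assumptions, the int read that covers the last valid bit always ends at or before the padded length, so the `dest.getLength()` check would never fail.

```java
class PaddedRead {
    // Hypothetical helper: round a byte length up to the next multiple of 4.
    static long padTo4(long len) {
        return (len + 3) & ~3L;
    }

    // True if an int read at the int-aligned offset holding the last bit stays in bounds.
    static boolean lastWordReadable(long bitCount, long bufferLen) {
        long lastBit = bitCount - 1;
        long off = (lastBit / Integer.SIZE) * Integer.BYTES;
        return off + Integer.BYTES <= bufferLen;
    }

    public static void main(String[] args) {
        for (long bits = 1; bits <= 1000; bits++) {
            long rawLen = (bits + 7) / 8; // bytes strictly needed to hold the bits
            if (!lastWordReadable(bits, padTo4(rawLen))) {
                throw new AssertionError("out of bounds for " + bits + " bits");
            }
        }
        System.out.println("all int-aligned reads stay in bounds");
    }
}
```

The guarantee depends on the offset being int-aligned; a byte-granular offset such as `curDestIdx / Byte.SIZE` could start a 4-byte read up to 3 bytes later and would need more padding.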
byte[] destBytes = new byte[4];
dest.getBytes(destBytes, 0, curDestOffset, destBufRemBytes);
int destInt = ByteBuffer.wrap(destBytes).order(ByteOrder.LITTLE_ENDIAN).getInt();
Curious why we're doing byte-at-a-time access and endian handling here when we don't above? This is still grabbing 4 bytes like the code above. Byte-at-a-time is usually a lot slower; it might be faster to read the int from the buffer (we're loading 4 bytes anyway) and call Integer.reverseBytes (a HotSpot intrinsic candidate) if ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN.
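A minimal sketch of the single-read approach, assuming the buffer's `getInt` returns the 4 bytes in native byte order (modeled here with a native-order `ByteBuffer`; `nativeGetInt` and `readValidityWord` are hypothetical names). Since validity bits are laid out little-endian, this sketch applies `Integer.reverseBytes` only when the native order is big-endian: on a little-endian host a native-order read already yields bit `i` of the word at bit position `i`.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

class NativeIntRead {
    // Evaluated once; the JIT can fold this constant so the hot path has no real branch.
    static final boolean NEEDS_SWAP = ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN;

    // Models a native-order getInt on the memory buffer (assumption for this sketch).
    static int nativeGetInt(ByteBuffer mem, int offset) {
        return mem.order(ByteOrder.nativeOrder()).getInt(offset);
    }

    // Returns the 4 bytes at offset as a little-endian int: bit i of the result is
    // validity bit i of that word. Integer.reverseBytes is a HotSpot intrinsic candidate.
    static int readValidityWord(ByteBuffer mem, int offset) {
        int v = nativeGetInt(mem, offset);
        return NEEDS_SWAP ? Integer.reverseBytes(v) : v;
    }

    public static void main(String[] args) {
        // Bytes 01 00 00 80 mark validity bits 0 and 31.
        ByteBuffer mem = ByteBuffer.wrap(new byte[] {0x01, 0x00, 0x00, (byte) 0x80});
        System.out.println(Integer.toHexString(readValidityWord(mem, 0))); // prints 80000001
    }
}
```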
ByteBuffer.wrap(destBytes).order(ByteOrder.LITTLE_ENDIAN).putInt(destInt);
dest.setBytes(curDestOffset, destBytes, 0, destBufRemBytes);
Similar to the above, we can call setInt here and use ByteOrder to determine whether we need to swap bytes.
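The write side can be sketched the same way, again assuming a native-order int accessor modeled with `java.nio.ByteBuffer`. `writeValidityWord` and `mergeAt` are hypothetical names; `mergeAt` mirrors the diff's pattern of keeping the bits below `curDestBitIdx` and OR-ing the new bits in above them, but on a whole int with one read and one write instead of a 4-byte copy.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

class NativeIntWrite {
    // Validity words are little-endian; native-order access only needs a swap on big-endian hosts.
    static final boolean NEEDS_SWAP = ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN;

    // Stand-in for a native-order getInt (assumption for this sketch).
    static int readValidityWord(ByteBuffer mem, int offset) {
        int v = mem.order(ByteOrder.nativeOrder()).getInt(offset);
        return NEEDS_SWAP ? Integer.reverseBytes(v) : v;
    }

    // Stand-in for a native-order setInt: swap first if needed so the bytes land little-endian.
    static void writeValidityWord(ByteBuffer mem, int offset, int leWord) {
        int v = NEEDS_SWAP ? Integer.reverseBytes(leWord) : leWord;
        mem.order(ByteOrder.nativeOrder()).putInt(offset, v);
    }

    // Keeps the low keepBits (0..31) bits of the word and ORs newBits in above them.
    // Bits shifted past bit 31 are dropped here; a real copy would carry them into the next word.
    static void mergeAt(ByteBuffer mem, int offset, int keepBits, int newBits) {
        int cur = readValidityWord(mem, offset);
        int keepMask = (int) ((1L << keepBits) - 1);
        writeValidityWord(mem, offset, (cur & keepMask) | (newBits << keepBits));
    }

    public static void main(String[] args) {
        ByteBuffer mem = ByteBuffer.wrap(new byte[] {(byte) 0xFF, 0, 0, 0});
        // Keep bits 0..2, clear the rest, then append 1011 starting at bit 3.
        mergeAt(mem, 0, 3, 0b1011);
        System.out.println(Integer.toHexString(mem.get(0) & 0xFF)); // prints 5f
    }
}
```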
Please update the PR description; it will be displayed in the commit log. Simply saying "closes XXX" will show only those words, which makes it difficult to track changes through the commit log.
To be fair, we do not (yet) have automation that uses the PR description when checking in PRs. I follow the convention of copying the PR description into the commit message, but it's not mandated and not widely followed in the NVIDIA/spark* repos.
Could we copy by long instead of int? We can avoid using
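As a rough illustration of copying by long, here is a sketch of only the whole-word body loop with 64-bit words, which moves 64 bits per iteration instead of 32. It assumes the source is already word-aligned (head and tail handled separately) and models both buffers as `long[]` of little-endian validity words; `copyWords` is a hypothetical name. The destination may start at any bit, and the loop stays branch-free even when that bit is word-aligned.

```java
class LongWordCopy {
    // Appends nWords whole 64-bit source words to dest starting at bit destBit (any alignment).
    // dest needs one spare word past the last one written. Returns the number of set bits copied.
    static int copyWords(long[] src, int srcWord, int nWords, long[] dest, int destBit) {
        int d = destBit / Long.SIZE;
        int shift = destBit % Long.SIZE;
        long keepMask = (1L << shift) - 1; // 0 when shift == 0
        int valid = 0;
        for (int i = 0; i < nWords; i++) {
            long w = src[srcWord + i];
            valid += Long.bitCount(w);
            // Low part: goes above the `shift` bits already present in the current dest word.
            dest[d + i] = (dest[d + i] & keepMask) | (w << shift);
            // High part: spills into the next word. Written as (w >>> 1) >>> (63 - shift) so that
            // shift == 0 yields 0 (Java masks shift distances, so w >>> 64 would return w).
            dest[d + i + 1] = (w >>> 1) >>> (63 - shift);
        }
        return valid;
    }

    public static void main(String[] args) {
        long[] src = {-1L, 1L};
        long[] dest = new long[4];
        dest[0] = 0b101; // three bits already present in dest
        int valid = copyWords(src, 0, 2, dest, 3);
        System.out.println(valid + " " + Long.toHexString(dest[0]) + " " + Long.toHexString(dest[1]));
    }
}
```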
// Sets the bits in destination buffer starting from curDestBitIdx to 0
byte destByte = dest.getByte(curDestByteIdx);
destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1) & 0xFF);
while (curSrcIdx < totalRowCount) {
So does this concatenate the validity buffers one word (32 bits) at a time, serially, in Java?
Can't we use C++/CUDA to accelerate this?
There's a separate effort by @nvdbaranec to provide GPU kernels for Kudo buffer concatenation. We need both CPU and GPU implementations. In some cases, the GPU will be the bottleneck for a stage, and CPUs will be idle while waiting for the GPU semaphore. It will be more efficient to concatenate on the CPU, even if it's slower than the GPU, when the CPU can complete the operation within the time otherwise spent waiting for the GPU.
I've never seen any parallel Java code in our plugin/JNI. Assuming it is doable, would we consider it for operations like this?
// Update destination byte with the bits from source byte
destByte = (byte) ((destByte | (srcByte << curDestBitIdx)) & 0xFF);
dest.setByte(curDestByteIdx, destByte);
int curDestOffset = (curDestIdx / 32) * Integer.BYTES;
Consider using mnemonic constants:
- int curDestOffset = (curDestIdx / 32) * Integer.BYTES;
+ int curDestOffset = (curDestIdx / Integer.SIZE) * Integer.BYTES;
or even:
- int curDestOffset = (curDestIdx / 32) * Integer.BYTES;
+ int curDestOffset = curDestIdx / Byte.SIZE;
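Note that the two suggested forms are not interchangeable. A quick sketch comparing them (the helper names are hypothetical): `(curDestIdx / Integer.SIZE) * Integer.BYTES` gives the int-aligned byte offset of the word containing the bit, while `curDestIdx / Byte.SIZE` gives the offset of the byte containing it. They agree only when `curDestIdx` is a multiple of 32, so switching to the byte offset would also require the in-word bit position to become `curDestIdx % Byte.SIZE` instead of `curDestIdx % Integer.SIZE`.

```java
class DestOffset {
    // Byte offset of the 4-byte word that contains bit curDestIdx (int-aligned).
    static int intAlignedOffset(int curDestIdx) {
        return (curDestIdx / Integer.SIZE) * Integer.BYTES;
    }

    // Byte offset of the byte that contains bit curDestIdx (not int-aligned in general).
    static int byteOffset(int curDestIdx) {
        return curDestIdx / Byte.SIZE;
    }

    public static void main(String[] args) {
        for (int idx : new int[] {0, 31, 32, 40, 63, 64}) {
            System.out.println(idx + ": " + intAlignedOffset(idx) + " vs " + byteOffset(idx));
        }
    }
}
```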
Close #2579