Skip to content

Commit

Permalink
Optimize validity buffer concat.
Browse files Browse the repository at this point in the history
  • Loading branch information
liurenjie1024 committed Nov 26, 2024
1 parent a02c564 commit 04d78bd
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 90 deletions.
149 changes: 60 additions & 89 deletions src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,50 +16,36 @@

package com.nvidia.spark.rapids.jni.kudo;

import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET;
import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes;
import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment;
import static java.lang.Math.min;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

import ai.rapids.cudf.HostMemoryBuffer;
import ai.rapids.cudf.Schema;
import com.nvidia.spark.rapids.jni.Arms;
import com.nvidia.spark.rapids.jni.schema.Visitors;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.OptionalInt;

import static com.nvidia.spark.rapids.jni.Preconditions.ensure;
import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET;
import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes;
import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment;
import static java.lang.Math.min;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

/**
* Merges multiple {@link KudoTable}s into a single contiguous host buffer, returned as a
* {@link KudoHostMergeResult}, which can be easily converted to a {@link ai.rapids.cudf.ContiguousTable}.
*/
class KudoTableMerger extends MultiKudoTableVisitor<Void, Void, KudoHostMergeResult> {
// Lookup table: NUMBER_OF_ONES[b] is the number of 1 bits in the unsigned byte value b (0..255).
private static final int[] NUMBER_OF_ONES = new int[256];

static {
  for (int value = 0; value < NUMBER_OF_ONES.length; value += 1) {
    // For values in 0..255 only the low 8 bits can be set, so the standard population count
    // equals the result of testing each of those 8 bits by hand.
    NUMBER_OF_ONES[value] = Integer.bitCount(value);
  }
}

private final List<ColumnOffsetInfo> columnOffsets;
private final HostMemoryBuffer buffer;
private final List<ColumnViewInfo> colViewInfoList;

public KudoTableMerger(List<KudoTable> tables, HostMemoryBuffer buffer, List<ColumnOffsetInfo> columnOffsets) {
public KudoTableMerger(List<KudoTable> tables, HostMemoryBuffer buffer,
List<ColumnOffsetInfo> columnOffsets) {
super(tables);
requireNonNull(buffer, "buffer can't be null!");
ensure(columnOffsets != null, "column offsets cannot be null");
/*
 * copyValidityBuffer: copies the validity bits of one slice from src into dest, starting at
 * destination bit startBit, and returns the number of 0 (null) bits among the copied rows.
 *
 * NOTE(review): this span is a rendered diff without +/- markers, so the removed byte-at-a-time
 * body and the added int-at-a-time body are interleaved (totalRowCount, leftRowCount and
 * appendCount are each declared twice). The comments added below annotate the int-based lines.
 */
Expand Down Expand Up @@ -155,80 +141,64 @@ private static int copyValidityBuffer(HostMemoryBuffer dest, int startBit,
HostMemoryBuffer src, int srcOffset,
SliceInfo sliceInfo) {
int nullCount = 0;
int totalRowCount = sliceInfo.getRowCount();
int curIdx = 0;
int curSrcByteIdx = srcOffset;
int curSrcBitIdx = sliceInfo.getValidityBufferInfo().getBeginBit();
int curDestByteIdx = startBit / 8;
int curDestBitIdx = startBit % 8;

while (curIdx < totalRowCount) {
int leftRowCount = totalRowCount - curIdx;
int appendCount;
if (curDestBitIdx == 0) {
appendCount = min(8, leftRowCount);
} else {
appendCount = min(8 - curDestBitIdx, leftRowCount);
}

int leftBitsInCurSrcByte = 8 - curSrcBitIdx;
byte srcByte = src.getByte(curSrcByteIdx);
if (leftBitsInCurSrcByte >= appendCount) {
// Extract appendCount bits from srcByte, starting from curSrcBitIdx
byte mask = (byte) (((1 << appendCount) - 1) & 0xFF);
srcByte = (byte) ((srcByte >>> curSrcBitIdx) & mask);
// Exclusive end position in the source bitmap: the slice's row count plus the slice's first bit.
// NOTE(review): if both getters return int, the sum can wrap before toIntExact checks it - confirm types.
int totalRowCount = toIntExact(sliceInfo.getRowCount() + sliceInfo.getValidityBufferInfo().getBeginBit());
int curSrcIdx = sliceInfo.getValidityBufferInfo().getBeginBit();
int curDestIdx = startBit;

nullCount += (appendCount - NUMBER_OF_ONES[srcByte & 0xFF]);

// Sets the bits in destination buffer starting from curDestBitIdx to 0
byte destByte = dest.getByte(curDestByteIdx);
destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1) & 0xFF);
while (curSrcIdx < totalRowCount) {
int leftRowCount = totalRowCount - curSrcIdx;

// Update destination byte with the bits from source byte
destByte = (byte) ((destByte | (srcByte << curDestBitIdx)) & 0xFF);
dest.setByte(curDestByteIdx, destByte);
// Byte offset of the 32-bit word containing destination bit curDestIdx, and the bit position within that word.
int curDestOffset = (curDestIdx / 32) * Integer.BYTES;
int curDestBitIdx = curDestIdx % 32;

curSrcBitIdx += appendCount;
if (curSrcBitIdx == 8) {
curSrcBitIdx = 0;
curSrcByteIdx += 1;
}
} else {
// Extract appendCount bits from srcByte, starting from curSrcBitIdx
byte mask = (byte) (((1 << leftBitsInCurSrcByte) - 1) & 0xFF);
srcByte = (byte) ((srcByte >>> curSrcBitIdx) & mask);

byte nextSrcByte = src.getByte(curSrcByteIdx + 1);
byte nextSrcByteMask = (byte) ((1 << (appendCount - leftBitsInCurSrcByte)) - 1);
nextSrcByte = (byte) (nextSrcByte & nextSrcByteMask);
nextSrcByte = (byte) (nextSrcByte << leftBitsInCurSrcByte);
srcByte = (byte) (srcByte | nextSrcByte);
// Same word/bit split for the source position, relative to srcOffset.
int curSrcOffset = srcOffset + (curSrcIdx / 32) * Integer.BYTES;
int curSrcBitIdx = curSrcIdx % 32;

nullCount += (appendCount - NUMBER_OF_ONES[srcByte & 0xFF]);
// This is safe since we always have validity buffer 4 bytes padded
int srcInt = src.getInt(curSrcOffset);
// Shift out the bits below the current source position so the next bit to copy is bit 0.
srcInt = srcInt >>> curSrcBitIdx;

// Sets the bits in destination buffer starting from curDestBitIdx to 0
byte destByte = dest.getByte(curDestByteIdx);
destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1));
if (dest.getLength() >= (curDestOffset + Integer.BYTES)) {
// We have enough room to get an int
int destInt = dest.getInt(curDestOffset);
// Preserve destination bits below curDestBitIdx and zero everything from curDestBitIdx upward.
destInt &= (1 << curDestBitIdx) - 1;
// Place the source bits at curDestBitIdx; bits shifted past bit 31 are dropped.
destInt |= srcInt << curDestBitIdx;
dest.setInt(curDestOffset, destInt);

// Update destination byte with the bits from source byte
destByte = (byte) (destByte | (srcByte << curDestBitIdx));
dest.setByte(curDestByteIdx, destByte);

// Update the source byte index and bit index
curSrcByteIdx += 1;
curSrcBitIdx = appendCount - leftBitsInCurSrcByte;
}
// Bits usable this round: limited by the remaining rows and by whichever 32-bit word
// (source or destination) has fewer bits left.
int appendCount = min(leftRowCount, 32 - Math.max(curSrcBitIdx, curDestBitIdx));

curIdx += appendCount;

// Update the destination byte index and bit index
curDestBitIdx += appendCount;
if (curDestBitIdx == 8) {
curDestBitIdx = 0;
curDestByteIdx += 1;
curDestIdx += appendCount;
curSrcIdx += appendCount;
// A full word is counted without a mask because (1 << 32) evaluates to 1 in Java,
// which would make the mask 0.
if (appendCount == 32) {
nullCount += 32 - Integer.bitCount(srcInt);
} else {
int mask = (1 << appendCount) - 1;
nullCount += (appendCount - Integer.bitCount(srcInt & mask));
}
} else {
// Fewer than Integer.BYTES bytes remain in dest from curDestOffset, so go through a
// 4-byte scratch array instead of getInt/setInt.
int destBufRemBytes = toIntExact(dest.getLength() - curDestOffset);
byte[] destBytes = new byte[4];
dest.getBytes(destBytes, 0, curDestOffset, destBufRemBytes);
// NOTE(review): little-endian decoding presumably matches the byte order used by
// HostMemoryBuffer.getInt/setInt - confirm.
int destInt = ByteBuffer.wrap(destBytes).order(ByteOrder.LITTLE_ENDIAN).getInt();
destInt &= (1 << curDestBitIdx) - 1;
destInt |= srcInt << curDestBitIdx;

ByteBuffer.wrap(destBytes).order(ByteOrder.LITTLE_ENDIAN).putInt(destInt);
dest.setBytes(curDestOffset, destBytes, 0, destBufRemBytes);

// Here the destination limit is the remaining bytes of dest rather than a full 32-bit word.
int appendCount = min(leftRowCount, destBufRemBytes * 8 - Math.max(curSrcBitIdx, curDestBitIdx));

curDestIdx += appendCount;
curSrcIdx += appendCount;
int mask = (1 << appendCount) - 1;
nullCount += (appendCount - Integer.bitCount(srcInt & mask));
}
}

// Copy into an effectively final local so the message lambda below can capture it.
int srcIdx = curSrcIdx;
ensure(curSrcIdx == totalRowCount, () -> "Did not copy all of the validity buffer, total row count: " + totalRowCount +
" current src idx: " + srcIdx);
return nullCount;
}

/*
 * merge: allocates a host buffer sized by mergedInfo.getTotalDataLen(), builds a KudoTableMerger
 * over the tables and column offsets from mergedInfo, and returns the result of visiting schema
 * with that merger.
 *
 * NOTE(review): this span is a rendered diff without +/- markers; the merger declaration appears
 * both in its removed single-line form and in its added re-wrapped form.
 */
Expand Down Expand Up @@ -325,7 +295,8 @@ static KudoHostMergeResult merge(Schema schema, MergedInfoCalc mergedInfo) {
List<KudoTable> serializedTables = mergedInfo.getTables();
// NOTE(review): closeIfException presumably closes the freshly allocated buffer when the
// lambda throws - confirm against Arms.
return Arms.closeIfException(HostMemoryBuffer.allocate(mergedInfo.getTotalDataLen()),
buffer -> {
KudoTableMerger merger = new KudoTableMerger(serializedTables, buffer, mergedInfo.getColumnOffsets());
KudoTableMerger merger =
new KudoTableMerger(serializedTables, buffer, mergedInfo.getColumnOffsets());
// The merger is the schema visitor; whatever the visit produces is returned to the caller.
return Visitors.visitSchema(schema, merger);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ public class KudoSerializerTest {
public void testSerializeAndDeserializeTable() {
try(Table expected = buildTestTable()) {
int rowCount = toIntExact(expected.getRowCount());
for (int sliceSize = 1; sliceSize <= rowCount; sliceSize++) {
IntStream sliceSizes = IntStream.range(1, rowCount + 1);
for (int sliceSize: sliceSizes.toArray()) {
List<TableSlice> tableSlices = new ArrayList<>();
for (int startRow = 0; startRow < rowCount; startRow += sliceSize) {
tableSlices.add(new TableSlice(startRow, Math.min(sliceSize, rowCount - startRow), expected));
Expand Down

0 comments on commit 04d78bd

Please sign in to comment.