Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JNI: Rewrite growBuffersAndRows to accelerate the HostColumnBuilder #10025

Merged
merged 11 commits into from
Jan 31, 2022
207 changes: 136 additions & 71 deletions java/src/main/java/ai/rapids/cudf/HostColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ public static HostColumnVector timestampNanoSecondsFromBoxedLongs(Long... values
* Build
*/

public static final class ColumnBuilder implements AutoCloseable {
public static final class ColumnBuilder implements AutoCloseable {

private DType type;
private HostMemoryBuffer data;
Expand All @@ -869,13 +869,13 @@ public static final class ColumnBuilder implements AutoCloseable {
private boolean nullable;
private long rows;
private long estimatedRows;
private long rowCapacity = 0L;
private boolean built = false;
private List<ColumnBuilder> childBuilders = new ArrayList<>();

private int currentIndex = 0;
private int currentByteIndex = 0;


public ColumnBuilder(HostColumnVector.DataType type, long estimatedRows) {
this.type = type.getType();
this.nullable = type.isNullable();
Expand Down Expand Up @@ -929,74 +929,127 @@ public ColumnBuilder appendStructValues(StructData... inputList) {
}

/**
* A method that is responsible for growing the buffers as needed
* and incrementing the row counts when we append values or nulls.
* @param hasNull indicates whether the validity buffer needs to be considered, as the
* nullcount may not have been fully calculated yet
* @param length used for strings
* Grows valid buffer lazily. The valid buffer won't be materialized until the first null
* value appended. This method reuses the rowCapacity to track the sizes of column.
* Therefore, please call specific growBuffer method to update rowCapacity before calling
* this method.
*/
private void growValidBuffer() {
revans2 marked this conversation as resolved.
Show resolved Hide resolved
long maskBytes = byteSizeOfNullMask((int) rowCapacity);
revans2 marked this conversation as resolved.
Show resolved Hide resolved
if (valid == null) {
valid = HostMemoryBuffer.allocate(maskBytes);
valid.setMemory(0, valid.length, (byte) 0xFF);
} else {
HostMemoryBuffer newValid = HostMemoryBuffer.allocate(maskBytes);
newValid.setMemory(0, newValid.length, (byte) 0xFF);
valid = copyBuffer(newValid, valid);
}
}

/**
* A method automatically grows data buffer for fixed-width columns as needed along with
* incrementing the row counts. Please call this method before appending any value or null.
*/
private void growBuffersAndRows(boolean hasNull, int length) {
private void growFixedWidthBuffersAndRows() {
assert rows + 1 <= Integer.MAX_VALUE : "Row count cannot go over Integer.MAX_VALUE";
rows++;
long targetDataSize = 0;

if (!type.isNestedType()) {
if (type.equals(DType.STRING)) {
targetDataSize = data == null ? length : currentByteIndex + length;
} else {
targetDataSize = data == null ? estimatedRows * type.getSizeInBytes() : rows * type.getSizeInBytes();
}
if (data == null) {
data = HostMemoryBuffer.allocate(estimatedRows * type.getSizeInBytes());
rowCapacity = estimatedRows;
} else if (rows > rowCapacity) {
long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 1);
data = copyBuffer(HostMemoryBuffer.allocate(newCap * type.getSizeInBytes()), data);
rowCapacity = newCap;
}
}

if (targetDataSize > 0) {
if (data == null) {
data = HostMemoryBuffer.allocate(targetDataSize);
} else {
long maxLen;
if (type.equals(DType.STRING)) {
maxLen = Integer.MAX_VALUE;
} else {
maxLen = Integer.MAX_VALUE * (long) type.getSizeInBytes();
}
long oldLen = data.getLength();
long newDataLen = Math.max(1, oldLen);
while (targetDataSize > newDataLen) {
newDataLen = newDataLen * 2;
}
if (newDataLen != oldLen) {
newDataLen = Math.min(newDataLen, maxLen);
if (newDataLen < targetDataSize) {
throw new IllegalStateException("A data buffer for strings is not supported over 2GB in size");
}
HostMemoryBuffer newData = HostMemoryBuffer.allocate(newDataLen);
data = copyBuffer(newData, data);
}
}
/**
* A method automatically grows offsets buffer for list columns as needed along with
* incrementing the row counts. Please call this method before appending any value or null.
*/
private void growListBuffersAndRows() {
assert rows + 2 <= Integer.MAX_VALUE : "Row count cannot go over Integer.MAX_VALUE";
rows++;

if (offsets == null) {
offsets = HostMemoryBuffer.allocate((estimatedRows + 1) * OFFSET_SIZE);
offsets.setInt(0, 0);
rowCapacity = estimatedRows;
} else if (rows > rowCapacity) {
long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 2);
offsets = copyBuffer(HostMemoryBuffer.allocate((newCap + 1) * OFFSET_SIZE), offsets);
rowCapacity = newCap;
}
if (type.equals(DType.LIST) || type.equals(DType.STRING)) {
if (offsets == null) {
offsets = HostMemoryBuffer.allocate((estimatedRows + 1) * OFFSET_SIZE);
offsets.setInt(0, 0);
} else if ((rows +1) * OFFSET_SIZE > offsets.length) {
long newOffsetLen = offsets.length * 2;
HostMemoryBuffer newOffsets = HostMemoryBuffer.allocate(newOffsetLen);
offsets = copyBuffer(newOffsets, offsets);
}
}

/**
* A method automatically grows offsets and data buffer for string columns as needed along with
* incrementing the row counts. Please call this method before appending any value or null.
*
* @param stringLength number of bytes required by the next row
*/
private void growStringBuffersAndRows(int stringLength) {
assert rows + 2 <= Integer.MAX_VALUE : "Row count cannot go over Integer.MAX_VALUE";
rows++;

if (offsets == null) {
// Initialize data buffer with at least 64 bytes to avoid growing too frequently.
data = HostMemoryBuffer.allocate(Math.max(64, stringLength));
offsets = HostMemoryBuffer.allocate((estimatedRows + 1) * OFFSET_SIZE);
offsets.setInt(0, 0);
rowCapacity = estimatedRows;
return;
}
if (hasNull || nullCount > 0) {
if (valid == null) {
long targetValidSize = ColumnView.getNativeValidPointerSize((int)estimatedRows);
valid = HostMemoryBuffer.allocate(targetValidSize);
valid.setMemory(0, targetValidSize, (byte) 0xFF);
} else if (valid.length < ColumnView.getNativeValidPointerSize((int)rows)) {
long newValidLen = valid.length * 2;
HostMemoryBuffer newValid = HostMemoryBuffer.allocate(newValidLen);
newValid.setMemory(0, newValidLen, (byte) 0xFF);
valid = copyBuffer(newValid, valid);
}

if (rows > rowCapacity) {
long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 2);
offsets = copyBuffer(HostMemoryBuffer.allocate((newCap + 1) * OFFSET_SIZE), offsets);
rowCapacity = newCap;
}

long currentLength = currentByteIndex + stringLength;
long requiredLength = data.length;
while (currentLength > requiredLength) {
requiredLength = requiredLength * 2;
}
if (requiredLength > data.length) {
data = copyBuffer(HostMemoryBuffer.allocate(requiredLength), data);
}
}

/**
* For struct columns, we only need to update rows and rowCapacity (for the growth of
* valid buffer), because struct columns hold no buffer itself.
* Please call this method before appending any value or null.
*/
private void growStructBuffersAndRows() {
assert rows + 1 <= Integer.MAX_VALUE : "Row count cannot go over Integer.MAX_VALUE";
rows++;

if (rowCapacity == 0) {
rowCapacity = estimatedRows;
} else if (rows > rowCapacity) {
rowCapacity = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 1);
}
}

/**
* The Java substitution of native method `ColumnView.getNativeValidPointerSize`.
* Ideally, this method can speed up growValidBuffer by eliminating the JNI call.
*/
private static long byteSizeOfNullMask(int numRows) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be nice to look at replacing the JNI call entirely with this. I don't see a lot of reason to have this hidden here when we could make it common.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// number of bytes required = Math.ceil(number of bits / 8)
int actualBytes = (numRows >> 3) + Math.min(numRows % 8, 1);

// padding to multiply of 64 bytes, just as cuDF default padding boundary
long padding = actualBytes >> 6;
if (actualBytes % 64 > 0) padding++;
padding = padding << 6;
revans2 marked this conversation as resolved.
Show resolved Hide resolved

return padding;
}

private HostMemoryBuffer copyBuffer(HostMemoryBuffer targetBuffer, HostMemoryBuffer buffer) {
try {
targetBuffer.copyFromHostBuffer(0, buffer, 0, buffer.length);
Expand All @@ -1021,7 +1074,19 @@ private void setNullAt(int index) {
}

public final ColumnBuilder appendNull() {
growBuffersAndRows(true, 0);
// Increments row number. And update offsets and data buffer along with the valid buffer.
// NOTE: The growth of valid buffer must happen after rowCapacity updated.
if (type.typeId == DType.DTypeEnum.LIST) {
revans2 marked this conversation as resolved.
Show resolved Hide resolved
growListBuffersAndRows();
} else if (type.typeId == DType.DTypeEnum.STRING) {
growStringBuffersAndRows(0);
} else if (type.getSizeInBytes() > 0) {
growFixedWidthBuffersAndRows();
} else {
growStructBuffersAndRows();
}
growValidBuffer();
revans2 marked this conversation as resolved.
Show resolved Hide resolved

setNullAt(currentIndex);
currentIndex++;
currentByteIndex += type.getSizeInBytes();
Expand Down Expand Up @@ -1081,7 +1146,7 @@ public ColumnBuilder endStruct() {
assert type.equals(DType.STRUCT) : "This only works for structs";
assert allChildrenHaveSameIndex() : "Appending structs data appears to be off " +
childBuilders + " should all have the same currentIndex " + type;
growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes());
growStructBuffersAndRows();
currentIndex++;
return this;
}
Expand All @@ -1095,7 +1160,7 @@ assert allChildrenHaveSameIndex() : "Appending structs data appears to be off "
*/
public ColumnBuilder endList() {
assert type.equals(DType.LIST);
growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes());
growListBuffersAndRows();
currentIndex++;
offsets.setInt(currentIndex * OFFSET_SIZE, childBuilders.get(0).getCurrentIndex());
return this;
Expand Down Expand Up @@ -1161,7 +1226,7 @@ public int getCurrentByteIndex() {
}

public final ColumnBuilder append(byte value) {
growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes());
growFixedWidthBuffersAndRows();
assert type.isBackedByByte();
assert currentIndex < rows;
data.setByte(currentIndex * type.getSizeInBytes(), value);
Expand All @@ -1171,7 +1236,7 @@ public final ColumnBuilder append(byte value) {
}

public final ColumnBuilder append(short value) {
growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes());
growFixedWidthBuffersAndRows();
assert type.isBackedByShort();
assert currentIndex < rows;
data.setShort(currentIndex * type.getSizeInBytes(), value);
Expand All @@ -1181,7 +1246,7 @@ public final ColumnBuilder append(short value) {
}

public final ColumnBuilder append(int value) {
growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes());
growFixedWidthBuffersAndRows();
assert type.isBackedByInt();
assert currentIndex < rows;
data.setInt(currentIndex * type.getSizeInBytes(), value);
Expand All @@ -1191,7 +1256,7 @@ public final ColumnBuilder append(int value) {
}

public final ColumnBuilder append(long value) {
growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes());
growFixedWidthBuffersAndRows();
assert type.isBackedByLong();
assert currentIndex < rows;
data.setLong(currentIndex * type.getSizeInBytes(), value);
Expand All @@ -1201,7 +1266,7 @@ public final ColumnBuilder append(long value) {
}

public final ColumnBuilder append(float value) {
growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes());
growFixedWidthBuffersAndRows();
assert type.equals(DType.FLOAT32);
assert currentIndex < rows;
data.setFloat(currentIndex * type.getSizeInBytes(), value);
Expand All @@ -1211,7 +1276,7 @@ public final ColumnBuilder append(float value) {
}

public final ColumnBuilder append(double value) {
growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes());
growFixedWidthBuffersAndRows();
assert type.equals(DType.FLOAT64);
assert currentIndex < rows;
data.setDouble(currentIndex * type.getSizeInBytes(), value);
Expand All @@ -1221,7 +1286,7 @@ public final ColumnBuilder append(double value) {
}

public final ColumnBuilder append(boolean value) {
growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes());
growFixedWidthBuffersAndRows();
assert type.equals(DType.BOOL8);
assert currentIndex < rows;
data.setBoolean(currentIndex * type.getSizeInBytes(), value);
Expand All @@ -1231,7 +1296,7 @@ public final ColumnBuilder append(boolean value) {
}

public final ColumnBuilder append(BigDecimal value) {
growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes());
growFixedWidthBuffersAndRows();
assert currentIndex < rows;
// Rescale input decimal with UNNECESSARY policy, which accepts no precision loss.
BigInteger unscaledVal = value.setScale(-type.getScale(), RoundingMode.UNNECESSARY).unscaledValue();
Expand Down Expand Up @@ -1268,7 +1333,7 @@ public ColumnBuilder appendUTF8String(byte[] value, int srcOffset, int length) {
assert value.length + srcOffset <= length;
assert type.equals(DType.STRING) : " type " + type + " is not String";
currentIndex++;
growBuffersAndRows(false, length);
growStringBuffersAndRows(length);
assert currentIndex < rows + 1;
if (length > 0) {
data.setBytes(currentByteIndex, value, srcOffset, length);
Expand Down
1 change: 0 additions & 1 deletion java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3559,7 +3559,6 @@ void testCastDecimal64ToString() {
for (int scale : new int[]{-5, -2, -1, 0, 1, 2, 5}) {
for (int i = 0; i < strDecimalValues.length; i++) {
strDecimalValues[i] = dumpDecimal(unScaledValues[i], scale);
System.out.println(strDecimalValues[i]);
}

testCastFixedWidthToStringsAndBack(DType.create(DType.DTypeEnum.DECIMAL64, scale),
Expand Down