From 77119a922948674ef1a9b4defa48e062f88fc45e Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Wed, 12 Jan 2022 13:14:26 +0800 Subject: [PATCH 1/9] draft --- .../java/ai/rapids/cudf/HostColumnVector.java | 202 +++++++++++------- 1 file changed, 130 insertions(+), 72 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index e21a4ac81c6..56382e47604 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -858,7 +858,7 @@ public static HostColumnVector timestampNanoSecondsFromBoxedLongs(Long... values * Build */ - public static final class ColumnBuilder implements AutoCloseable { + public static final class ColumnBuilder implements AutoCloseable { private DType type; private HostMemoryBuffer data; @@ -868,8 +868,10 @@ public static final class ColumnBuilder implements AutoCloseable { //TODO nullable currently not used private boolean nullable; private long rows; - private long estimatedRows; + private long initialRows; + private long rowCapacity = 0L; private boolean built = false; + private boolean fixedLenBuffers = false; private List childBuilders = new ArrayList<>(); private int currentIndex = 0; @@ -877,12 +879,32 @@ public static final class ColumnBuilder implements AutoCloseable { public ColumnBuilder(HostColumnVector.DataType type, long estimatedRows) { + new ColumnBuilder(type, estimatedRows, false); + } + + public ColumnBuilder(HostColumnVector.DataType type, + long initialRows, + boolean useFixedLenBufferIfPossible) { this.type = type.getType(); this.nullable = type.isNullable(); this.rows = 0; - this.estimatedRows = estimatedRows; + this.initialRows = initialRows; + + if (useFixedLenBufferIfPossible && this.type.typeId != DType.DTypeEnum.STRING) { + this.fixedLenBuffers = true; + // Pre-Allocate fixed length buffers in case of time-consuming buffer growing. + if (this.type.typeId == DType.DTypeEnum.STRUCT) { + growStructBuffersAndRows(); + } else if (this.type.typeId == DType.DTypeEnum.LIST) { + growListBuffersAndRows(); + } else { + growFixedWidthBuffersAndRows(); + } + } + for (int i = 0; i < type.getNumChildren(); i++) { - childBuilders.add(new ColumnBuilder(type.getChild(i), estimatedRows)); + childBuilders.add(new ColumnBuilder( + type.getChild(i), initialRows, useFixedLenBufferIfPossible)); } } @@ -928,75 +950,102 @@ public ColumnBuilder appendStructValues(StructData... inputList) { return this; } + private void reallocateValidBuffer(long desiredRows) { + long desiredMaskBytes = byteSizeOfNullMask((int) desiredRows); + if (valid == null) { + valid = HostMemoryBuffer.allocate(desiredMaskBytes); + } else { + valid = copyBuffer(HostMemoryBuffer.allocate(desiredMaskBytes), valid); + } + } + /** * A method that is responsible for growing the buffers as needed * and incrementing the row counts when we append values or nulls. - * @param hasNull indicates whether the validity buffer needs to be considered, as the - * nullcount may not have been fully calculated yet - * @param length used for strings */ - private void growBuffersAndRows(boolean hasNull, int length) { + private void growFixedWidthBuffersAndRows() { assert rows + 1 <= Integer.MAX_VALUE : "Row count cannot go over Integer.MAX_VALUE"; rows++; - long targetDataSize = 0; - if (!type.isNestedType()) { - if (type.equals(DType.STRING)) { - targetDataSize = data == null ? length : currentByteIndex + length; - } else { - targetDataSize = data == null ? estimatedRows * type.getSizeInBytes() : rows * type.getSizeInBytes(); - } + if (data == null) { + data = HostMemoryBuffer.allocate(initialRows * type.getSizeInBytes()); + if (nullable) reallocateValidBuffer(initialRows); + rowCapacity = initialRows; + } else if (rows > rowCapacity) { + long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 1); + data = copyBuffer(HostMemoryBuffer.allocate(newCap * type.getSizeInBytes()), data); + if (nullable) reallocateValidBuffer(newCap); + rowCapacity = newCap; } + } - if (targetDataSize > 0) { - if (data == null) { - data = HostMemoryBuffer.allocate(targetDataSize); - } else { - long maxLen; - if (type.equals(DType.STRING)) { - maxLen = Integer.MAX_VALUE; - } else { - maxLen = Integer.MAX_VALUE * (long) type.getSizeInBytes(); - } - long oldLen = data.getLength(); - long newDataLen = Math.max(1, oldLen); - while (targetDataSize > newDataLen) { - newDataLen = newDataLen * 2; - } - if (newDataLen != oldLen) { - newDataLen = Math.min(newDataLen, maxLen); - if (newDataLen < targetDataSize) { - throw new IllegalStateException("A data buffer for strings is not supported over 2GB in size"); - } - HostMemoryBuffer newData = HostMemoryBuffer.allocate(newDataLen); - data = copyBuffer(newData, data); - } - } + private void growListBuffersAndRows() { + assert rows + 2 <= Integer.MAX_VALUE : "Row count cannot go over Integer.MAX_VALUE"; + rows++; + + if (data == null) { + offsets = HostMemoryBuffer.allocate((initialRows + 1) * OFFSET_SIZE); + offsets.setInt(0, 0); + if (nullable) reallocateValidBuffer(initialRows); + rowCapacity = initialRows; + } else if (rows > rowCapacity) { + long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 2); + offsets = copyBuffer(HostMemoryBuffer.allocate((newCap + 1) * OFFSET_SIZE), offsets); + if (nullable) reallocateValidBuffer(newCap); + rowCapacity = newCap; } - if (type.equals(DType.LIST) || type.equals(DType.STRING)) { - if (offsets == null) { - offsets = HostMemoryBuffer.allocate((estimatedRows + 1) * OFFSET_SIZE); - offsets.setInt(0, 0); - } else if ((rows +1) * OFFSET_SIZE > offsets.length) { - long newOffsetLen = offsets.length * 2; - HostMemoryBuffer newOffsets = HostMemoryBuffer.allocate(newOffsetLen); - offsets = copyBuffer(newOffsets, offsets); - } + } + + private void growStringBuffersAndRows(int stringLength) { + assert rows + 2 <= Integer.MAX_VALUE : "Row count cannot go over Integer.MAX_VALUE"; + rows++; + + if (data == null) { + data = HostMemoryBuffer.allocate(stringLength); + offsets = HostMemoryBuffer.allocate((initialRows + 1) * OFFSET_SIZE); + offsets.setInt(0, 0); + if (nullable) reallocateValidBuffer(initialRows); + rowCapacity = initialRows; + return; } - if (hasNull || nullCount > 0) { - if (valid == null) { - long targetValidSize = ColumnView.getNativeValidPointerSize((int)estimatedRows); - valid = HostMemoryBuffer.allocate(targetValidSize); - valid.setMemory(0, targetValidSize, (byte) 0xFF); - } else if (valid.length < ColumnView.getNativeValidPointerSize((int)rows)) { - long newValidLen = valid.length * 2; - HostMemoryBuffer newValid = HostMemoryBuffer.allocate(newValidLen); - newValid.setMemory(0, newValidLen, (byte) 0xFF); - valid = copyBuffer(newValid, valid); - } + if (rows > rowCapacity) { + long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 2); + offsets = copyBuffer(HostMemoryBuffer.allocate((newCap + 1) * OFFSET_SIZE), offsets); + if (nullable) reallocateValidBuffer(newCap); + rowCapacity = newCap; + } + if (currentByteIndex + stringLength > data.length) { + data = copyBuffer(HostMemoryBuffer.allocate(data.length * 2), data); + } + } + + private void growStructBuffersAndRows() { + assert rows + 1 <= Integer.MAX_VALUE : "Row count cannot go over Integer.MAX_VALUE"; + rows++; + + if (!nullable) return; + if (valid == null) { + reallocateValidBuffer(initialRows); + rowCapacity = initialRows; + } else if (rows > rowCapacity) { + long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 1); + reallocateValidBuffer(newCap); + rowCapacity = newCap; } } + private static long byteSizeOfNullMask(int numRows) { + // number of bytes required = Math.ceil(number of bits / 8) + int actualBytes = (numRows >> 3) + Math.min(numRows % 8, 1); + + // padding to multiply of 64 bytes, just as cuDF default padding boundary + long padding = actualBytes >> 6; + if (actualBytes % 64 > 0) padding++; + padding = padding << 6; + + return padding; + } + private HostMemoryBuffer copyBuffer(HostMemoryBuffer targetBuffer, HostMemoryBuffer buffer) { try { targetBuffer.copyFromHostBuffer(0, buffer, 0, buffer.length); @@ -1021,7 +1070,16 @@ private void setNullAt(int index) { } public final ColumnBuilder appendNull() { - growBuffersAndRows(true, 0); + if (!fixedLenBuffers) { + if (type.hasOffsets()) { + // We can reuse growListBuffers on StringType, if we ensure data buffer won't grow. + growListBuffersAndRows(); + } else if (type.getSizeInBytes() > 0) { + growFixedWidthBuffersAndRows(); + } else { + growStructBuffersAndRows(); + } + } setNullAt(currentIndex); currentIndex++; currentByteIndex += type.getSizeInBytes(); @@ -1081,7 +1139,7 @@ public ColumnBuilder endStruct() { assert type.equals(DType.STRUCT) : "This only works for structs"; assert allChildrenHaveSameIndex() : "Appending structs data appears to be off " + childBuilders + " should all have the same currentIndex " + type; - growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); + if (!fixedLenBuffers) growStructBuffersAndRows(); currentIndex++; return this; } @@ -1095,7 +1153,7 @@ assert allChildrenHaveSameIndex() : "Appending structs data appears to be off " */ public ColumnBuilder endList() { assert type.equals(DType.LIST); - growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); + if (!fixedLenBuffers) growListBuffersAndRows(); currentIndex++; offsets.setInt(currentIndex * OFFSET_SIZE, childBuilders.get(0).getCurrentIndex()); return this; @@ -1161,7 +1219,7 @@ public int getCurrentByteIndex() { } public final ColumnBuilder append(byte value) { - growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); + if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); assert type.isBackedByByte(); assert currentIndex < rows; data.setByte(currentIndex * type.getSizeInBytes(), value); @@ -1171,7 +1229,7 @@ public final ColumnBuilder append(byte value) { } public final ColumnBuilder append(short value) { - growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); + if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); assert type.isBackedByShort(); assert currentIndex < rows; data.setShort(currentIndex * type.getSizeInBytes(), value); @@ -1181,7 +1239,7 @@ public final ColumnBuilder append(short value) { } public final ColumnBuilder append(int value) { - growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); + if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); assert type.isBackedByInt(); assert currentIndex < rows; data.setInt(currentIndex * type.getSizeInBytes(), value); @@ -1191,7 +1249,7 @@ public final ColumnBuilder append(int value) { } public final ColumnBuilder append(long value) { - growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); + if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); assert type.isBackedByLong(); assert currentIndex < rows; data.setLong(currentIndex * type.getSizeInBytes(), value); @@ -1201,7 +1259,7 @@ public final ColumnBuilder append(long value) { } public final ColumnBuilder append(float value) { - growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); + if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); assert type.equals(DType.FLOAT32); assert currentIndex < rows; data.setFloat(currentIndex * type.getSizeInBytes(), value); @@ -1211,7 +1269,7 @@ public final ColumnBuilder append(float value) { } public final ColumnBuilder append(double value) { - growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); + if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); assert type.equals(DType.FLOAT64); assert currentIndex < rows; data.setDouble(currentIndex * type.getSizeInBytes(), value); @@ -1221,7 +1279,7 @@ public final ColumnBuilder append(double value) { } public final ColumnBuilder append(boolean value) { - growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); + if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); assert type.equals(DType.BOOL8); assert currentIndex < rows; data.setBoolean(currentIndex * type.getSizeInBytes(), value); @@ -1231,7 +1289,7 @@ public final ColumnBuilder append(boolean value) { } public final ColumnBuilder append(BigDecimal value) { - growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); + if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); assert currentIndex < rows; // Rescale input decimal with UNNECESSARY policy, which accepts no precision loss. BigInteger unscaledVal = value.setScale(-type.getScale(), RoundingMode.UNNECESSARY).unscaledValue(); @@ -1268,7 +1326,7 @@ public ColumnBuilder appendUTF8String(byte[] value, int srcOffset, int length) { assert value.length + srcOffset <= length; assert type.equals(DType.STRING) : " type " + type + " is not String"; currentIndex++; - growBuffersAndRows(false, length); + growStringBuffersAndRows(length); assert currentIndex < rows + 1; if (length > 0) { data.setBytes(currentByteIndex, value, srcOffset, length); @@ -1326,7 +1384,7 @@ public String toString() { ", valid=" + valid + ", currentIndex=" + currentIndex + ", nullCount=" + nullCount + - ", estimatedRows=" + estimatedRows + + ", estimatedRows=" + initialRows + ", populatedRows=" + rows + ", built=" + built + '}'; From 5e57c4b82fde4b1c0cf07b1117cf145585257e3c Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Wed, 12 Jan 2022 18:04:52 +0800 Subject: [PATCH 2/9] rewrite the growBuffersAndRows of HostColumnVector.ColumnBuilder Signed-off-by: sperlingxx --- .../java/ai/rapids/cudf/HostColumnVector.java | 157 +++++++++--------- .../java/ai/rapids/cudf/ColumnVectorTest.java | 1 - 2 files changed, 82 insertions(+), 76 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index 56382e47604..3bb3eb77984 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -868,43 +868,21 @@ public static final class ColumnBuilder implements AutoCloseable { //TODO nullable currently not used private boolean nullable; private long rows; - private long initialRows; + private long estimatedRows; private long rowCapacity = 0L; private boolean built = false; - private boolean fixedLenBuffers = false; private List childBuilders = new ArrayList<>(); private int currentIndex = 0; private int currentByteIndex = 0; - public ColumnBuilder(HostColumnVector.DataType type, long estimatedRows) { - new ColumnBuilder(type, estimatedRows, false); - } - - public ColumnBuilder(HostColumnVector.DataType type, - long initialRows, - boolean useFixedLenBufferIfPossible) { this.type = type.getType(); this.nullable = type.isNullable(); this.rows = 0; - this.initialRows = initialRows; - - if (useFixedLenBufferIfPossible && this.type.typeId != DType.DTypeEnum.STRING) { - this.fixedLenBuffers = true; - // Pre-Allocate fixed length buffers in case of time-consuming buffer growing. - if (this.type.typeId == DType.DTypeEnum.STRUCT) { - growStructBuffersAndRows(); - } else if (this.type.typeId == DType.DTypeEnum.LIST) { - growListBuffersAndRows(); - } else { - growFixedWidthBuffersAndRows(); - } - } - + this.estimatedRows = estimatedRows; for (int i = 0; i < type.getNumChildren(); i++) { - childBuilders.add(new ColumnBuilder( - type.getChild(i), initialRows, useFixedLenBufferIfPossible)); + childBuilders.add(new ColumnBuilder(type.getChild(i), estimatedRows)); } } @@ -950,90 +928,116 @@ public ColumnBuilder appendStructValues(StructData... inputList) { return this; } - private void reallocateValidBuffer(long desiredRows) { - long desiredMaskBytes = byteSizeOfNullMask((int) desiredRows); + /** + * Grows valid buffer lazily. The valid buffer won't be materialized until the first null + * value appended. This method reuses the rowCapacity to track the sizes of column. + * Therefore, please call specific growBuffer method to update rowCapacity before calling + * this method. + */ + private void growValidBuffer() { + long maskBytes = byteSizeOfNullMask((int) rowCapacity); if (valid == null) { - valid = HostMemoryBuffer.allocate(desiredMaskBytes); + valid = HostMemoryBuffer.allocate(maskBytes); + valid.setMemory(0, valid.length, (byte) 0xFF); } else { - valid = copyBuffer(HostMemoryBuffer.allocate(desiredMaskBytes), valid); + HostMemoryBuffer newValid = HostMemoryBuffer.allocate(maskBytes); + newValid.setMemory(0, newValid.length, (byte) 0xFF); + valid = copyBuffer(newValid, valid); } } /** - * A method that is responsible for growing the buffers as needed - * and incrementing the row counts when we append values or nulls. + * A method automatically grows data buffer for fixed-width columns as needed along with + * incrementing the row counts. Please call this method before appending any value or null. */ private void growFixedWidthBuffersAndRows() { assert rows + 1 <= Integer.MAX_VALUE : "Row count cannot go over Integer.MAX_VALUE"; rows++; if (data == null) { - data = HostMemoryBuffer.allocate(initialRows * type.getSizeInBytes()); - if (nullable) reallocateValidBuffer(initialRows); - rowCapacity = initialRows; + data = HostMemoryBuffer.allocate(estimatedRows * type.getSizeInBytes()); + rowCapacity = estimatedRows; } else if (rows > rowCapacity) { long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 1); data = copyBuffer(HostMemoryBuffer.allocate(newCap * type.getSizeInBytes()), data); - if (nullable) reallocateValidBuffer(newCap); rowCapacity = newCap; } } + /** + * A method automatically grows offsets buffer for list columns as needed along with + * incrementing the row counts. Please call this method before appending any value or null. + */ private void growListBuffersAndRows() { assert rows + 2 <= Integer.MAX_VALUE : "Row count cannot go over Integer.MAX_VALUE"; rows++; - if (data == null) { - offsets = HostMemoryBuffer.allocate((initialRows + 1) * OFFSET_SIZE); + if (offsets == null) { + offsets = HostMemoryBuffer.allocate((estimatedRows + 1) * OFFSET_SIZE); offsets.setInt(0, 0); - if (nullable) reallocateValidBuffer(initialRows); - rowCapacity = initialRows; + rowCapacity = estimatedRows; } else if (rows > rowCapacity) { long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 2); offsets = copyBuffer(HostMemoryBuffer.allocate((newCap + 1) * OFFSET_SIZE), offsets); - if (nullable) reallocateValidBuffer(newCap); rowCapacity = newCap; } } + /** + * A method automatically grows offsets and data buffer for string columns as needed along with + * incrementing the row counts. Please call this method before appending any value or null. + * + * @param stringLength number of bytes required by the next row + */ private void growStringBuffersAndRows(int stringLength) { assert rows + 2 <= Integer.MAX_VALUE : "Row count cannot go over Integer.MAX_VALUE"; rows++; - if (data == null) { - data = HostMemoryBuffer.allocate(stringLength); - offsets = HostMemoryBuffer.allocate((initialRows + 1) * OFFSET_SIZE); + if (offsets == null) { + // Initialize data buffer with at least 64 bytes to avoid growing too frequently. + data = HostMemoryBuffer.allocate(Math.max(64, stringLength)); + offsets = HostMemoryBuffer.allocate((estimatedRows + 1) * OFFSET_SIZE); offsets.setInt(0, 0); - if (nullable) reallocateValidBuffer(initialRows); - rowCapacity = initialRows; + rowCapacity = estimatedRows; return; } + if (rows > rowCapacity) { long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 2); offsets = copyBuffer(HostMemoryBuffer.allocate((newCap + 1) * OFFSET_SIZE), offsets); - if (nullable) reallocateValidBuffer(newCap); rowCapacity = newCap; } - if (currentByteIndex + stringLength > data.length) { - data = copyBuffer(HostMemoryBuffer.allocate(data.length * 2), data); + + long currentLength = currentByteIndex + stringLength; + long requiredLength = data.length; + while (currentLength > requiredLength) { + requiredLength = requiredLength * 2; + } + if (requiredLength > data.length) { + data = copyBuffer(HostMemoryBuffer.allocate(requiredLength), data); } } + /** + * For struct columns, we only need to update rows and rowCapacity (for the growth of + * valid buffer), because struct columns hold no buffer itself. + * Please call this method before appending any value or null. + */ private void growStructBuffersAndRows() { assert rows + 1 <= Integer.MAX_VALUE : "Row count cannot go over Integer.MAX_VALUE"; rows++; - if (!nullable) return; - if (valid == null) { - reallocateValidBuffer(initialRows); - rowCapacity = initialRows; + if (rowCapacity == 0) { + rowCapacity = estimatedRows; } else if (rows > rowCapacity) { - long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 1); - reallocateValidBuffer(newCap); - rowCapacity = newCap; + rowCapacity = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 1); } } + /** + * The Java substitution of native method `ColumnView.getNativeValidPointerSize`. + * Ideally, this method can speed up growValidBuffer by eliminating the JNI call. + */ private static long byteSizeOfNullMask(int numRows) { // number of bytes required = Math.ceil(number of bits / 8) int actualBytes = (numRows >> 3) + Math.min(numRows % 8, 1); @@ -1070,16 +1074,19 @@ private void setNullAt(int index) { } public final ColumnBuilder appendNull() { - if (!fixedLenBuffers) { - if (type.hasOffsets()) { - // We can reuse growListBuffers on StringType, if we ensure data buffer won't grow. - growListBuffersAndRows(); - } else if (type.getSizeInBytes() > 0) { - growFixedWidthBuffersAndRows(); - } else { - growStructBuffersAndRows(); - } + // Increments row number. And update offsets and data buffer along with the valid buffer. + // NOTE: The growth of valid buffer must happen after rowCapacity updated. + if (type.typeId == DType.DTypeEnum.LIST) { + growListBuffersAndRows(); + } else if (type.typeId == DType.DTypeEnum.STRING) { + growStringBuffersAndRows(0); + } else if (type.getSizeInBytes() > 0) { + growFixedWidthBuffersAndRows(); + } else { + growStructBuffersAndRows(); } + growValidBuffer(); + setNullAt(currentIndex); currentIndex++; currentByteIndex += type.getSizeInBytes(); @@ -1139,7 +1146,7 @@ public ColumnBuilder endStruct() { assert type.equals(DType.STRUCT) : "This only works for structs"; assert allChildrenHaveSameIndex() : "Appending structs data appears to be off " + childBuilders + " should all have the same currentIndex " + type; - if (!fixedLenBuffers) growStructBuffersAndRows(); + growStructBuffersAndRows(); currentIndex++; return this; } @@ -1153,7 +1160,7 @@ assert allChildrenHaveSameIndex() : "Appending structs data appears to be off " */ public ColumnBuilder endList() { assert type.equals(DType.LIST); - if (!fixedLenBuffers) growListBuffersAndRows(); + growListBuffersAndRows(); currentIndex++; offsets.setInt(currentIndex * OFFSET_SIZE, childBuilders.get(0).getCurrentIndex()); return this; @@ -1219,7 +1226,7 @@ public int getCurrentByteIndex() { } public final ColumnBuilder append(byte value) { - if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); + growFixedWidthBuffersAndRows(); assert type.isBackedByByte(); assert currentIndex < rows; data.setByte(currentIndex * type.getSizeInBytes(), value); @@ -1229,7 +1236,7 @@ public final ColumnBuilder append(byte value) { } public final ColumnBuilder append(short value) { - if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); + growFixedWidthBuffersAndRows(); assert type.isBackedByShort(); assert currentIndex < rows; data.setShort(currentIndex * type.getSizeInBytes(), value); @@ -1239,7 +1246,7 @@ public final ColumnBuilder append(short value) { } public final ColumnBuilder append(int value) { - if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); + growFixedWidthBuffersAndRows(); assert type.isBackedByInt(); assert currentIndex < rows; data.setInt(currentIndex * type.getSizeInBytes(), value); @@ -1249,7 +1256,7 @@ public final ColumnBuilder append(int value) { } public final ColumnBuilder append(long value) { - if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); + growFixedWidthBuffersAndRows(); assert type.isBackedByLong(); assert currentIndex < rows; data.setLong(currentIndex * type.getSizeInBytes(), value); @@ -1259,7 +1266,7 @@ public final ColumnBuilder append(long value) { } public final ColumnBuilder append(float value) { - if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); + growFixedWidthBuffersAndRows(); assert type.equals(DType.FLOAT32); assert currentIndex < rows; data.setFloat(currentIndex * type.getSizeInBytes(), value); @@ -1269,7 +1276,7 @@ public final ColumnBuilder append(float value) { } public final ColumnBuilder append(double value) { - if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); + growFixedWidthBuffersAndRows(); assert type.equals(DType.FLOAT64); assert currentIndex < rows; data.setDouble(currentIndex * type.getSizeInBytes(), value); @@ -1279,7 +1286,7 @@ public final ColumnBuilder append(double value) { } public final ColumnBuilder append(boolean value) { - if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); + growFixedWidthBuffersAndRows(); assert type.equals(DType.BOOL8); assert currentIndex < rows; data.setBoolean(currentIndex * type.getSizeInBytes(), value); @@ -1289,7 +1296,7 @@ public final ColumnBuilder append(boolean value) { } public final ColumnBuilder append(BigDecimal value) { - if (!fixedLenBuffers) growFixedWidthBuffersAndRows(); + growFixedWidthBuffersAndRows(); assert currentIndex < rows; // Rescale input decimal with UNNECESSARY policy, which accepts no precision loss. BigInteger unscaledVal = value.setScale(-type.getScale(), RoundingMode.UNNECESSARY).unscaledValue(); @@ -1384,7 +1391,7 @@ public String toString() { ", valid=" + valid + ", currentIndex=" + currentIndex + ", nullCount=" + nullCount + - ", estimatedRows=" + initialRows + + ", estimatedRows=" + estimatedRows + ", populatedRows=" + rows + ", built=" + built + '}'; diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 8d4bbff1542..312e91145fb 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -3559,7 +3559,6 @@ void testCastDecimal64ToString() { for (int scale : new int[]{-5, -2, -1, 0, 1, 2, 5}) { for (int i = 0; i < strDecimalValues.length; i++) { strDecimalValues[i] = dumpDecimal(unScaledValues[i], scale); - System.out.println(strDecimalValues[i]); } testCastFixedWidthToStringsAndBack(DType.create(DType.DTypeEnum.DECIMAL64, scale), From 1777542d3135b625706425edc436bd966261728e Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Thu, 13 Jan 2022 17:00:34 +0800 Subject: [PATCH 3/9] update --- .../main/java/ai/rapids/cudf/ColumnView.java | 22 ++++--- .../java/ai/rapids/cudf/HostColumnVector.java | 63 +++++++++++-------- java/src/main/native/src/ColumnViewJni.cpp | 12 +--- 3 files changed, 51 insertions(+), 46 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index a2e080e02f6..b6c5aa86e90 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -3234,7 +3234,7 @@ public final ColumnVector listIndexOf(Scalar key, FindOptions findOption) { * The index is set to null if one of the following is true: * 1. The search key row is null. * 2. The list row is null. - * @param key ColumnView of search keys. + * @param keys ColumnView of search keys. * @param findOption Whether to find the first index of the key, or the last. * @return The resultant column of int32 indices */ @@ -3270,6 +3270,17 @@ public final Scalar getScalarElement(int index) { return new Scalar(getType(), getElement(getNativeView(), index)); } + /** + * Get the number of bytes needed to allocate a validity buffer for the given number of rows. + * According to cudf::bitmask_allocation_size_bytes, the padding boundary for null mask is 64 bytes. + */ + public static long getValidityBufferSize(int numRows) { + // number of bytes required = Math.ceil(number of bits / 8) + long actualBytes = ((long) numRows + 7) >> 3; + // padding to the multiplies of the padding boundary(64 bytes) + return ((actualBytes + 63) >> 6) << 6; + } + ///////////////////////////////////////////////////////////////////////////// // INTERNAL/NATIVE ACCESS ///////////////////////////////////////////////////////////////////////////// @@ -3866,11 +3877,6 @@ private static native long bitwiseMergeAndSetValidity(long baseHandle, long[] vi private static native long copyWithBooleanColumnAsValidity(long exemplarViewHandle, long boolColumnViewHandle) throws CudfException; - /** - * Get the number of bytes needed to allocate a validity buffer for the given number of rows. - */ - static native long getNativeValidPointerSize(int size); - //////// // Native cudf::column_view life cycle and metadata access methods. Life cycle methods // should typically only be called from the OffHeap inner class. @@ -3960,7 +3966,7 @@ static ColumnVector createColumnVector(DType type, int rows, HostMemoryBuffer da DeviceMemoryBuffer mainValidDevBuff = null; DeviceMemoryBuffer mainOffsetsDevBuff = null; if (mainColValid != null) { - long validLen = getNativeValidPointerSize(mainColRows); + long validLen = getValidityBufferSize(mainColRows); mainValidDevBuff = DeviceMemoryBuffer.allocate(validLen); mainValidDevBuff.copyFromHostBuffer(mainColValid, 0, validLen); } @@ -4069,7 +4075,7 @@ private static NestedColumnVector createNestedColumnVector(DType type, long rows data.copyFromHostBuffer(dataBuffer, 0, dataLen); } if (validityBuffer != null) { - long validLen = getNativeValidPointerSize((int)rows); + long validLen = getValidityBufferSize((int)rows); valid = DeviceMemoryBuffer.allocate(validLen); valid.copyFromHostBuffer(validityBuffer, 0, validLen); } diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index 3bb3eb77984..e45c76afd8d 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -199,7 +199,7 @@ public ColumnVector copyToDevice() { } HostMemoryBuffer hvalid = this.offHeap.valid; if (hvalid != null) { - long validLen = ColumnView.getNativeValidPointerSize((int) rows); + long validLen = ColumnView.getValidityBufferSize((int) rows); valid = DeviceMemoryBuffer.allocate(validLen); valid.copyFromHostBuffer(hvalid, 0, validLen); } @@ -870,11 +870,12 @@ public static final class ColumnBuilder implements AutoCloseable { private long rows; private long estimatedRows; private long rowCapacity = 0L; + private long validCapacity = 0L; private boolean built = false; private List childBuilders = new ArrayList<>(); private int currentIndex = 0; - private int currentByteIndex = 0; + private long currentByteIndex = 0; public ColumnBuilder(HostColumnVector.DataType type, long estimatedRows) { this.type = type.getType(); @@ -891,6 +892,10 @@ public HostColumnVector build() { for (ColumnBuilder childBuilder : childBuilders) { hostColumnVectorCoreList.add(childBuilder.buildNestedInternal()); } + // Aligns the valid buffer size with other buffers in terms of row size, because it grows lazily. + if (valid != null) { + growValidBuffer(); + } HostColumnVector hostColumnVector = new HostColumnVector(type, rows, Optional.of(nullCount), data, valid, offsets, hostColumnVectorCoreList); built = true; @@ -902,6 +907,10 @@ private HostColumnVectorCore buildNestedInternal() { for (ColumnBuilder childBuilder : childBuilders) { hostColumnVectorCoreList.add(childBuilder.buildNestedInternal()); } + // Aligns the valid buffer size with other buffers in terms of row size, because it grows lazily. + if (valid != null) { + growValidBuffer(); + } return new HostColumnVectorCore(type, rows, Optional.of(nullCount), data, valid, offsets, hostColumnVectorCoreList); } @@ -935,14 +944,19 @@ public ColumnBuilder appendStructValues(StructData... inputList) { * this method. */ private void growValidBuffer() { - long maskBytes = byteSizeOfNullMask((int) rowCapacity); if (valid == null) { + long maskBytes = ColumnView.getValidityBufferSize((int) rowCapacity); valid = HostMemoryBuffer.allocate(maskBytes); valid.setMemory(0, valid.length, (byte) 0xFF); - } else { + validCapacity = rowCapacity; + return; + } + if (validCapacity < rowCapacity) { + long maskBytes = ColumnView.getValidityBufferSize((int) rowCapacity); HostMemoryBuffer newValid = HostMemoryBuffer.allocate(maskBytes); newValid.setMemory(0, newValid.length, (byte) 0xFF); valid = copyBuffer(newValid, valid); + validCapacity = rowCapacity; } } @@ -1040,14 +1054,9 @@ private void growStructBuffersAndRows() { */ private static long byteSizeOfNullMask(int numRows) { // number of bytes required = Math.ceil(number of bits / 8) - int actualBytes = (numRows >> 3) + Math.min(numRows % 8, 1); - + int actualBytes = (numRows + 7) >> 3; // padding to multiply of 64 bytes, just as cuDF default padding boundary - long padding = actualBytes >> 6; - if (actualBytes % 64 > 0) padding++; - padding = padding << 6; - - return padding; + return ((actualBytes + 63) >> 6) << 6; } private HostMemoryBuffer copyBuffer(HostMemoryBuffer targetBuffer, HostMemoryBuffer buffer) { @@ -1092,10 +1101,10 @@ public final ColumnBuilder appendNull() { currentByteIndex += type.getSizeInBytes(); if (type.hasOffsets()) { if (type.equals(DType.LIST)) { - offsets.setInt(currentIndex * OFFSET_SIZE, childBuilders.get(0).getCurrentIndex()); + offsets.setInt((long) currentIndex * OFFSET_SIZE, childBuilders.get(0).getCurrentIndex()); } else { // It is a String - offsets.setInt(currentIndex * OFFSET_SIZE, currentByteIndex); + offsets.setInt((long) currentIndex * OFFSET_SIZE, (int) currentByteIndex); } } else if (type.equals(DType.STRUCT)) { // structs propagate nulls to children and even further down if needed @@ -1221,7 +1230,7 @@ public int getCurrentIndex() { return currentIndex; } - public int getCurrentByteIndex() { + public long getCurrentByteIndex() { return currentByteIndex; } @@ -1229,7 +1238,7 @@ public final ColumnBuilder append(byte value) { growFixedWidthBuffersAndRows(); assert type.isBackedByByte(); assert currentIndex < rows; - data.setByte(currentIndex * type.getSizeInBytes(), value); + data.setByte(currentByteIndex, value); currentIndex++; currentByteIndex += type.getSizeInBytes(); return this; @@ -1239,7 +1248,7 @@ public final ColumnBuilder append(short value) { growFixedWidthBuffersAndRows(); assert type.isBackedByShort(); assert currentIndex < rows; - data.setShort(currentIndex * type.getSizeInBytes(), value); + data.setShort(currentByteIndex, value); currentIndex++; currentByteIndex += type.getSizeInBytes(); return this; @@ -1249,7 +1258,7 @@ public final ColumnBuilder append(int value) { growFixedWidthBuffersAndRows(); assert type.isBackedByInt(); assert currentIndex < rows; - data.setInt(currentIndex * type.getSizeInBytes(), value); + data.setInt(currentByteIndex, value); currentIndex++; currentByteIndex += type.getSizeInBytes(); return this; @@ -1259,7 +1268,7 @@ public final ColumnBuilder append(long value) { growFixedWidthBuffersAndRows(); assert type.isBackedByLong(); assert currentIndex < rows; - data.setLong(currentIndex * type.getSizeInBytes(), value); + data.setLong(currentByteIndex, value); currentIndex++; currentByteIndex += type.getSizeInBytes(); return this; @@ -1269,7 +1278,7 @@ public final ColumnBuilder append(float value) { growFixedWidthBuffersAndRows(); assert type.equals(DType.FLOAT32); assert currentIndex < rows; - data.setFloat(currentIndex * type.getSizeInBytes(), value); + data.setFloat(currentByteIndex, value); currentIndex++; currentByteIndex += type.getSizeInBytes(); return this; @@ -1279,7 +1288,7 @@ public final ColumnBuilder append(double value) { growFixedWidthBuffersAndRows(); assert type.equals(DType.FLOAT64); assert currentIndex < rows; - data.setDouble(currentIndex * type.getSizeInBytes(), value); + data.setDouble(currentByteIndex, value); currentIndex++; currentByteIndex += type.getSizeInBytes(); return this; @@ -1289,7 +1298,7 @@ public final ColumnBuilder append(boolean value) { growFixedWidthBuffersAndRows(); assert type.equals(DType.BOOL8); assert currentIndex < rows; - data.setBoolean(currentIndex * type.getSizeInBytes(), value); + data.setBoolean(currentByteIndex, value); currentIndex++; currentByteIndex += type.getSizeInBytes(); return this; @@ -1301,14 +1310,14 @@ public final ColumnBuilder append(BigDecimal value) { // Rescale input decimal with UNNECESSARY policy, which accepts no precision loss. BigInteger unscaledVal = value.setScale(-type.getScale(), RoundingMode.UNNECESSARY).unscaledValue(); if (type.typeId == DType.DTypeEnum.DECIMAL32) { - data.setInt(currentIndex * type.getSizeInBytes(), unscaledVal.intValueExact()); + data.setInt(currentByteIndex, unscaledVal.intValueExact()); } else if (type.typeId == DType.DTypeEnum.DECIMAL64) { - data.setLong(currentIndex * type.getSizeInBytes(), unscaledVal.longValueExact()); + data.setLong(currentByteIndex, unscaledVal.longValueExact()); } else if (type.typeId == DType.DTypeEnum.DECIMAL128) { assert currentIndex < rows; byte[] unscaledValueBytes = value.unscaledValue().toByteArray(); byte[] result = convertDecimal128FromJavaToCudf(unscaledValueBytes); - data.setBytes(currentIndex*DType.DTypeEnum.DECIMAL128.sizeInBytes, result, 0, result.length); + data.setBytes(currentByteIndex, result, 0, result.length); } else { throw new IllegalStateException(type + " is not a supported decimal type."); } @@ -1339,7 +1348,7 @@ public ColumnBuilder appendUTF8String(byte[] value, int srcOffset, int length) { data.setBytes(currentByteIndex, value, srcOffset, length); } currentByteIndex += length; - offsets.setInt(currentIndex * OFFSET_SIZE, currentByteIndex); + offsets.setInt((long) currentIndex * OFFSET_SIZE, (int) currentByteIndex); return this; } @@ -1883,7 +1892,7 @@ public final Builder append(HostColumnVector columnVector) { } private void allocateBitmaskAndSetDefaultValues() { - long bitmaskSize = ColumnView.getNativeValidPointerSize((int) rows); + long bitmaskSize = ColumnView.getValidityBufferSize((int) rows); valid = HostMemoryBuffer.allocate(bitmaskSize); valid.setMemory(0, bitmaskSize, (byte) 0xFF); } diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 73ea49c18d9..828201ebbe2 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1937,16 +1937,6 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getNativeValidityLength(J CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getNativeValidPointerSize(JNIEnv *env, - jobject j_object, - jint size) { - try { - cudf::jni::auto_set_device(env); - return static_cast(cudf::bitmask_allocation_size_bytes(size)); - } - CATCH_STD(env, 0); -} - JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getDeviceMemorySize(JNIEnv *env, jclass, jlong handle) { JNI_NULL_CHECK(env, handle, "native handle is null", 0); From 73a3a62bf43310ffc4bf28fe43b495e0531cc880 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 14 Jan 2022 11:36:20 +0800 Subject: [PATCH 4/9] small fix --- .../main/java/ai/rapids/cudf/HostColumnVector.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index e45c76afd8d..3f10bd7890c 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -1008,8 +1008,8 @@ private void growStringBuffersAndRows(int stringLength) { rows++; if (offsets == null) { - // Initialize data buffer with at least 64 bytes to avoid growing too frequently. - data = HostMemoryBuffer.allocate(Math.max(64, stringLength)); + // Initialize data buffer with at least 1 byte in case the first appended value is null. + data = HostMemoryBuffer.allocate(Math.max(1, stringLength)); offsets = HostMemoryBuffer.allocate((estimatedRows + 1) * OFFSET_SIZE); offsets.setInt(0, 0); rowCapacity = estimatedRows; @@ -1023,11 +1023,11 @@ private void growStringBuffersAndRows(int stringLength) { } long currentLength = currentByteIndex + stringLength; - long requiredLength = data.length; - while (currentLength > requiredLength) { - requiredLength = requiredLength * 2; - } - if (requiredLength > data.length) { + if (currentLength > data.length) { + long requiredLength = data.length; + do { + requiredLength = requiredLength * 2; + } while (currentLength > requiredLength); data = copyBuffer(HostMemoryBuffer.allocate(requiredLength), data); } } From 767b237a2f326a6b984c284aa93b6c34abeac076 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 14 Jan 2022 18:05:09 +0800 Subject: [PATCH 5/9] update --- .../java/ai/rapids/cudf/HostColumnVector.java | 83 ++++++++++--------- 1 file changed, 43 insertions(+), 40 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index 3f10bd7890c..78f971d7e50 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -873,6 +873,7 @@ public static final class ColumnBuilder implements AutoCloseable { private long validCapacity = 0L; private boolean built = false; private List childBuilders = new ArrayList<>(); + private Runnable nullHandler; private int currentIndex = 0; private long currentByteIndex = 0; @@ -882,11 +883,52 @@ public ColumnBuilder(HostColumnVector.DataType type, long estimatedRows) { this.nullable = type.isNullable(); this.rows = 0; this.estimatedRows = estimatedRows; + + // initialize the null handler according to the data type + this.setupNullHandler(); + for (int i = 0; i < type.getNumChildren(); i++) { childBuilders.add(new ColumnBuilder(type.getChild(i), estimatedRows)); } } + private void setupNullHandler() { + if (this.type == DType.LIST) { + this.nullHandler = () -> { + this.growListBuffersAndRows(); + this.growValidBuffer(); + setNullAt(currentIndex++); + currentByteIndex += this.type.getSizeInBytes(); + offsets.setInt((long) currentIndex * OFFSET_SIZE, childBuilders.get(0).getCurrentIndex()); + }; + } else if (this.type == DType.STRING) { + this.nullHandler = () -> { + this.growStringBuffersAndRows(0); + this.growValidBuffer(); + setNullAt(currentIndex++); + currentByteIndex += this.type.getSizeInBytes(); + offsets.setInt((long) currentIndex * OFFSET_SIZE, (int) currentByteIndex); + }; + } else if (this.type == DType.STRUCT) { + this.nullHandler = () -> { + this.growStructBuffersAndRows(); + this.growValidBuffer(); + setNullAt(currentIndex++); + currentByteIndex += this.type.getSizeInBytes(); + for (ColumnBuilder childBuilder : childBuilders) { + childBuilder.appendNull(); + } + }; + } else { + this.nullHandler = () -> { + this.growFixedWidthBuffersAndRows(); + this.growValidBuffer(); + setNullAt(currentIndex++); + currentByteIndex += this.type.getSizeInBytes(); + }; + } + } + public HostColumnVector build() { List hostColumnVectorCoreList = new ArrayList<>(); for (ColumnBuilder childBuilder : childBuilders) { @@ -1048,17 +1090,6 @@ private void growStructBuffersAndRows() { } } - /** - * The Java substitution of native method `ColumnView.getNativeValidPointerSize`. - * Ideally, this method can speed up growValidBuffer by eliminating the JNI call. - */ - private static long byteSizeOfNullMask(int numRows) { - // number of bytes required = Math.ceil(number of bits / 8) - int actualBytes = (numRows + 7) >> 3; - // padding to multiply of 64 bytes, just as cuDF default padding boundary - return ((actualBytes + 63) >> 6) << 6; - } - private HostMemoryBuffer copyBuffer(HostMemoryBuffer targetBuffer, HostMemoryBuffer buffer) { try { targetBuffer.copyFromHostBuffer(0, buffer, 0, buffer.length); @@ -1083,35 +1114,7 @@ private void setNullAt(int index) { } public final ColumnBuilder appendNull() { - // Increments row number. And update offsets and data buffer along with the valid buffer. - // NOTE: The growth of valid buffer must happen after rowCapacity updated. - if (type.typeId == DType.DTypeEnum.LIST) { - growListBuffersAndRows(); - } else if (type.typeId == DType.DTypeEnum.STRING) { - growStringBuffersAndRows(0); - } else if (type.getSizeInBytes() > 0) { - growFixedWidthBuffersAndRows(); - } else { - growStructBuffersAndRows(); - } - growValidBuffer(); - - setNullAt(currentIndex); - currentIndex++; - currentByteIndex += type.getSizeInBytes(); - if (type.hasOffsets()) { - if (type.equals(DType.LIST)) { - offsets.setInt((long) currentIndex * OFFSET_SIZE, childBuilders.get(0).getCurrentIndex()); - } else { - // It is a String - offsets.setInt((long) currentIndex * OFFSET_SIZE, (int) currentByteIndex); - } - } else if (type.equals(DType.STRUCT)) { - // structs propagate nulls to children and even further down if needed - for (ColumnBuilder childBuilder : childBuilders) { - childBuilder.appendNull(); - } - } + nullHandler.run(); return this; } From c6e9d86193f0ed30c951feaa4d8da810564473a0 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Mon, 17 Jan 2022 16:25:48 +0800 Subject: [PATCH 6/9] update --- .../main/java/ai/rapids/cudf/ColumnView.java | 4 +- .../java/ai/rapids/cudf/HostColumnVector.java | 95 ++++++++----------- 2 files changed, 42 insertions(+), 57 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index b6c5aa86e90..afc53fdfbca 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -3274,7 +3274,7 @@ public final Scalar getScalarElement(int index) { * Get the number of bytes needed to allocate a validity buffer for the given number of rows. * According to cudf::bitmask_allocation_size_bytes, the padding boundary for null mask is 64 bytes. */ - public static long getValidityBufferSize(int numRows) { + static long getValidityBufferSize(int numRows) { // number of bytes required = Math.ceil(number of bits / 8) long actualBytes = ((long) numRows + 7) >> 3; // padding to the multiplies of the padding boundary(64 bytes) @@ -3697,7 +3697,7 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat * Native method to find the first (or last) index of each search key in the specified column, * in each row of a list column. * @param nativeView the column view handle of the list - * @param scalarColumnHandle handle to the search key column + * @param keyColumnHandle handle to the search key column * @param isFindFirst Whether to find the first index of the key, or the last. * @return column handle of the resultant column of int32 indices */ diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index 78f971d7e50..13434c38ebf 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -875,14 +875,21 @@ public static final class ColumnBuilder implements AutoCloseable { private List childBuilders = new ArrayList<>(); private Runnable nullHandler; - private int currentIndex = 0; - private long currentByteIndex = 0; + // The value of currentIndex can't exceed Int32.Max. Storing currentIndex as a long is to + // adapt HostMemoryBuffer.setXXX, which requires a long offset. + private long currentIndex = 0; + // Only for Strings: pointer of the byte (data) buffer + private int currentStringByteIndex = 0; + // Use bit shift instead of multiply to transform row offset to byte offset + private int bitShiftBySize = 0; + private static final int bitShiftByOffset = (int)(Math.log(OFFSET_SIZE) / Math.log(2)); public ColumnBuilder(HostColumnVector.DataType type, long estimatedRows) { this.type = type.getType(); this.nullable = type.isNullable(); this.rows = 0; - this.estimatedRows = estimatedRows; + this.estimatedRows = Math.max(estimatedRows, 1L); + this.bitShiftBySize = (int)(Math.log(this.type.getSizeInBytes()) / Math.log(2)); // initialize the null handler according to the data type this.setupNullHandler(); @@ -898,23 +905,20 @@ private void setupNullHandler() { this.growListBuffersAndRows(); this.growValidBuffer(); setNullAt(currentIndex++); - currentByteIndex += this.type.getSizeInBytes(); - offsets.setInt((long) currentIndex * OFFSET_SIZE, childBuilders.get(0).getCurrentIndex()); + offsets.setInt(currentIndex << bitShiftByOffset, childBuilders.get(0).getCurrentIndex()); }; } else if (this.type == DType.STRING) { this.nullHandler = () -> { this.growStringBuffersAndRows(0); this.growValidBuffer(); setNullAt(currentIndex++); - currentByteIndex += this.type.getSizeInBytes(); - offsets.setInt((long) currentIndex * OFFSET_SIZE, (int) currentByteIndex); + offsets.setInt(currentIndex << bitShiftByOffset, currentStringByteIndex); }; } else if (this.type == DType.STRUCT) { this.nullHandler = () -> { this.growStructBuffersAndRows(); this.growValidBuffer(); setNullAt(currentIndex++); - currentByteIndex += this.type.getSizeInBytes(); for (ColumnBuilder childBuilder : childBuilders) { childBuilder.appendNull(); } @@ -924,7 +928,6 @@ private void setupNullHandler() { this.growFixedWidthBuffersAndRows(); this.growValidBuffer(); setNullAt(currentIndex++); - currentByteIndex += this.type.getSizeInBytes(); }; } } @@ -1011,11 +1014,11 @@ private void growFixedWidthBuffersAndRows() { rows++; if (data == null) { - data = HostMemoryBuffer.allocate(estimatedRows * type.getSizeInBytes()); + data = HostMemoryBuffer.allocate(estimatedRows << bitShiftBySize); rowCapacity = estimatedRows; } else if (rows > rowCapacity) { long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 1); - data = copyBuffer(HostMemoryBuffer.allocate(newCap * type.getSizeInBytes()), data); + data = copyBuffer(HostMemoryBuffer.allocate(newCap << bitShiftBySize), data); rowCapacity = newCap; } } @@ -1029,12 +1032,12 @@ private void growListBuffersAndRows() { rows++; if (offsets == null) { - offsets = HostMemoryBuffer.allocate((estimatedRows + 1) * OFFSET_SIZE); + offsets = HostMemoryBuffer.allocate((estimatedRows + 1) << bitShiftByOffset); offsets.setInt(0, 0); rowCapacity = estimatedRows; } else if (rows > rowCapacity) { long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 2); - offsets = copyBuffer(HostMemoryBuffer.allocate((newCap + 1) * OFFSET_SIZE), offsets); + offsets = copyBuffer(HostMemoryBuffer.allocate((newCap + 1) << bitShiftByOffset), offsets); rowCapacity = newCap; } } @@ -1052,7 +1055,7 @@ private void growStringBuffersAndRows(int stringLength) { if (offsets == null) { // Initialize data buffer with at least 1 byte in case the first appended value is null. data = HostMemoryBuffer.allocate(Math.max(1, stringLength)); - offsets = HostMemoryBuffer.allocate((estimatedRows + 1) * OFFSET_SIZE); + offsets = HostMemoryBuffer.allocate((estimatedRows + 1) << bitShiftByOffset); offsets.setInt(0, 0); rowCapacity = estimatedRows; return; @@ -1060,11 +1063,11 @@ private void growStringBuffersAndRows(int stringLength) { if (rows > rowCapacity) { long newCap = Math.min(rowCapacity * 2, Integer.MAX_VALUE - 2); - offsets = copyBuffer(HostMemoryBuffer.allocate((newCap + 1) * OFFSET_SIZE), offsets); + offsets = copyBuffer(HostMemoryBuffer.allocate((newCap + 1) << bitShiftByOffset), offsets); rowCapacity = newCap; } - long currentLength = currentByteIndex + stringLength; + long currentLength = currentStringByteIndex + stringLength; if (currentLength > data.length) { long requiredLength = data.length; do { @@ -1108,7 +1111,7 @@ private HostMemoryBuffer copyBuffer(HostMemoryBuffer targetBuffer, HostMemoryBuf * Method that sets the null bit in the validity vector * @param index the row index at which the null is marked */ - private void setNullAt(int index) { + private void setNullAt(long index) { assert index < rows : "Index for null value should fit the column with " + rows + " rows"; nullCount += BitVectorHelper.setNullAt(valid, index); } @@ -1173,8 +1176,7 @@ assert allChildrenHaveSameIndex() : "Appending structs data appears to be off " public ColumnBuilder endList() { assert type.equals(DType.LIST); growListBuffersAndRows(); - currentIndex++; - offsets.setInt(currentIndex * OFFSET_SIZE, childBuilders.get(0).getCurrentIndex()); + offsets.setInt(++currentIndex << bitShiftByOffset, childBuilders.get(0).getCurrentIndex()); return this; } @@ -1230,20 +1232,19 @@ public void incrCurrentIndex() { } public int getCurrentIndex() { - return currentIndex; + return (int) currentIndex; } - public long getCurrentByteIndex() { - return currentByteIndex; + @Deprecated + public int getCurrentByteIndex() { + return currentStringByteIndex; } public final ColumnBuilder append(byte value) { growFixedWidthBuffersAndRows(); assert type.isBackedByByte(); assert currentIndex < rows; - data.setByte(currentByteIndex, value); - currentIndex++; - currentByteIndex += type.getSizeInBytes(); + data.setByte(currentIndex++ << bitShiftBySize, value); return this; } @@ -1251,9 +1252,7 @@ public final ColumnBuilder append(short value) { growFixedWidthBuffersAndRows(); assert type.isBackedByShort(); assert currentIndex < rows; - data.setShort(currentByteIndex, value); - currentIndex++; - currentByteIndex += type.getSizeInBytes(); + data.setShort(currentIndex++ << bitShiftBySize, value); return this; } @@ -1261,9 +1260,7 @@ public final ColumnBuilder append(int value) { growFixedWidthBuffersAndRows(); assert type.isBackedByInt(); assert currentIndex < rows; - data.setInt(currentByteIndex, value); - currentIndex++; - currentByteIndex += type.getSizeInBytes(); + data.setInt(currentIndex++ << bitShiftBySize, value); return this; } @@ -1271,9 +1268,7 @@ public final ColumnBuilder append(long value) { growFixedWidthBuffersAndRows(); assert type.isBackedByLong(); assert currentIndex < rows; - data.setLong(currentByteIndex, value); - currentIndex++; - currentByteIndex += type.getSizeInBytes(); + data.setLong(currentIndex++ << bitShiftBySize, value); return this; } @@ -1281,9 +1276,7 @@ public final ColumnBuilder append(float value) { growFixedWidthBuffersAndRows(); assert type.equals(DType.FLOAT32); assert currentIndex < rows; - data.setFloat(currentByteIndex, value); - currentIndex++; - currentByteIndex += type.getSizeInBytes(); + data.setFloat(currentIndex++ << bitShiftBySize, value); return this; } @@ -1291,9 +1284,7 @@ public final ColumnBuilder append(double value) { growFixedWidthBuffersAndRows(); assert type.equals(DType.FLOAT64); assert currentIndex < rows; - data.setDouble(currentByteIndex, value); - currentIndex++; - currentByteIndex += type.getSizeInBytes(); + data.setDouble(currentIndex++ << bitShiftBySize, value); return this; } @@ -1301,9 +1292,7 @@ public final ColumnBuilder append(boolean value) { growFixedWidthBuffersAndRows(); assert type.equals(DType.BOOL8); assert currentIndex < rows; - data.setBoolean(currentByteIndex, value); - currentIndex++; - currentByteIndex += type.getSizeInBytes(); + data.setBoolean(currentIndex++ << bitShiftBySize, value); return this; } @@ -1313,19 +1302,16 @@ public final ColumnBuilder append(BigDecimal value) { // Rescale input decimal with UNNECESSARY policy, which accepts no precision loss. BigInteger unscaledVal = value.setScale(-type.getScale(), RoundingMode.UNNECESSARY).unscaledValue(); if (type.typeId == DType.DTypeEnum.DECIMAL32) { - data.setInt(currentByteIndex, unscaledVal.intValueExact()); + data.setInt(currentIndex++ << 2, unscaledVal.intValueExact()); } else if (type.typeId == DType.DTypeEnum.DECIMAL64) { - data.setLong(currentByteIndex, unscaledVal.longValueExact()); + data.setLong(currentIndex++ << 3, unscaledVal.longValueExact()); } else if (type.typeId == DType.DTypeEnum.DECIMAL128) { - assert currentIndex < rows; byte[] unscaledValueBytes = value.unscaledValue().toByteArray(); byte[] result = convertDecimal128FromJavaToCudf(unscaledValueBytes); - data.setBytes(currentByteIndex, result, 0, result.length); - } else { + data.setBytes(currentIndex++ << 4, result, 0, result.length); + } else { throw new IllegalStateException(type + " is not a supported decimal type."); } - currentIndex++; - currentByteIndex += type.getSizeInBytes(); return this; } @@ -1344,14 +1330,13 @@ public ColumnBuilder appendUTF8String(byte[] value, int srcOffset, int length) { assert length >= 0; assert value.length + srcOffset <= length; assert type.equals(DType.STRING) : " type " + type + " is not String"; - currentIndex++; growStringBuffersAndRows(length); - assert currentIndex < rows + 1; + assert currentIndex < rows; if (length > 0) { - data.setBytes(currentByteIndex, value, srcOffset, length); + data.setBytes(currentStringByteIndex, value, srcOffset, length); } - currentByteIndex += length; - offsets.setInt((long) currentIndex * OFFSET_SIZE, (int) currentByteIndex); + currentStringByteIndex += length; + offsets.setInt(++currentIndex << bitShiftByOffset, currentStringByteIndex); return this; } From 74f798ce1abea1e4098077d6895f1561c5d19343 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Mon, 17 Jan 2022 16:32:24 +0800 Subject: [PATCH 7/9] small fix --- java/src/main/java/ai/rapids/cudf/HostColumnVector.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index 13434c38ebf..69371aa62c0 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -1302,13 +1302,13 @@ public final ColumnBuilder append(BigDecimal value) { // Rescale input decimal with UNNECESSARY policy, which accepts no precision loss. BigInteger unscaledVal = value.setScale(-type.getScale(), RoundingMode.UNNECESSARY).unscaledValue(); if (type.typeId == DType.DTypeEnum.DECIMAL32) { - data.setInt(currentIndex++ << 2, unscaledVal.intValueExact()); + data.setInt(currentIndex++ << bitShiftBySize, unscaledVal.intValueExact()); } else if (type.typeId == DType.DTypeEnum.DECIMAL64) { - data.setLong(currentIndex++ << 3, unscaledVal.longValueExact()); + data.setLong(currentIndex++ << bitShiftBySize, unscaledVal.longValueExact()); } else if (type.typeId == DType.DTypeEnum.DECIMAL128) { byte[] unscaledValueBytes = value.unscaledValue().toByteArray(); byte[] result = convertDecimal128FromJavaToCudf(unscaledValueBytes); - data.setBytes(currentIndex++ << 4, result, 0, result.length); + data.setBytes(currentIndex++ << bitShiftBySize, result, 0, result.length); } else { throw new IllegalStateException(type + " is not a supported decimal type."); } From 590868526a10e7088d75008b9b91651eaf3d5edc Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Tue, 18 Jan 2022 16:48:15 +0800 Subject: [PATCH 8/9] add more tests for ColumnBuilder Signed-off-by: sperlingxx --- java/src/main/java/ai/rapids/cudf/DType.java | 4 +- .../ai/rapids/cudf/ByteColumnVectorTest.java | 106 ++++++++---- .../ai/rapids/cudf/ColumnBuilderHelper.java | 158 ++++++++++++++++++ .../rapids/cudf/DecimalColumnVectorTest.java | 64 +++++-- .../rapids/cudf/DoubleColumnVectorTest.java | 54 ++++-- .../ai/rapids/cudf/IntColumnVectorTest.java | 82 ++++++--- .../ai/rapids/cudf/LongColumnVectorTest.java | 82 ++++++--- 7 files changed, 443 insertions(+), 107 deletions(-) create mode 100644 java/src/test/java/ai/rapids/cudf/ColumnBuilderHelper.java diff --git a/java/src/main/java/ai/rapids/cudf/DType.java b/java/src/main/java/ai/rapids/cudf/DType.java index 742501be375..2e5b0202dc5 100644 --- a/java/src/main/java/ai/rapids/cudf/DType.java +++ b/java/src/main/java/ai/rapids/cudf/DType.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -307,7 +307,7 @@ public static DType fromJavaBigDecimal(BigDecimal dec) { return new DType(DTypeEnum.DECIMAL128, -dec.scale()); } throw new IllegalArgumentException("Precision " + dec.precision() + - " exceeds max precision cuDF can support " + DECIMAL64_MAX_PRECISION); + " exceeds max precision cuDF can support " + DECIMAL128_MAX_PRECISION); } /** diff --git a/java/src/test/java/ai/rapids/cudf/ByteColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ByteColumnVectorTest.java index a26dbec4907..7b476c31b95 100644 --- a/java/src/test/java/ai/rapids/cudf/ByteColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ByteColumnVectorTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; import java.util.Random; +import java.util.function.Consumer; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -39,21 +40,34 @@ public void testCreateColumnVectorBuilder() { @Test public void testArrayAllocation() { - try (HostColumnVector byteColumnVector = HostColumnVector.fromBytes(new byte[]{2, 3, 5})) { - assertFalse(byteColumnVector.hasNulls()); - assertEquals(byteColumnVector.getByte(0), 2); - assertEquals(byteColumnVector.getByte(1), 3); - assertEquals(byteColumnVector.getByte(2), 5); + Consumer verify = (cv) -> { + assertFalse(cv.hasNulls()); + assertEquals(cv.getByte(0), 2); + assertEquals(cv.getByte(1), 3); + assertEquals(cv.getByte(2), 5); + }; + try (HostColumnVector bcv = HostColumnVector.fromBytes(new byte[]{2, 3, 5})) { + verify.accept(bcv); + } + try (HostColumnVector bcv = ColumnBuilderHelper.fromBytes(true, new byte[]{2, 3, 5})) { + verify.accept(bcv); } } @Test public void testUnsignedArrayAllocation() { - try (HostColumnVector v = HostColumnVector.fromUnsignedBytes(new byte[]{(byte)0xff, (byte)128, 5})) { - assertFalse(v.hasNulls()); - assertEquals(0xff, Byte.toUnsignedInt(v.getByte(0)), 0xff); - assertEquals(128, Byte.toUnsignedInt(v.getByte(1)), 128); - assertEquals(5, Byte.toUnsignedInt(v.getByte(2)), 5); + Consumer verify = (cv) -> { + assertFalse(cv.hasNulls()); + assertEquals(0xff, Byte.toUnsignedInt(cv.getByte(0)), 0xff); + assertEquals(128, Byte.toUnsignedInt(cv.getByte(1)), 128); + assertEquals(5, Byte.toUnsignedInt(cv.getByte(2)), 5); + }; + try (HostColumnVector bcv = HostColumnVector.fromUnsignedBytes(new byte[]{(byte)0xff, (byte)128, 5})) { + verify.accept(bcv); + } + try (HostColumnVector bcv = ColumnBuilderHelper.fromBytes(false, + new byte[]{(byte)0xff, (byte)128, 5})) { + verify.accept(bcv); } } @@ -70,47 +84,73 @@ public void testAppendRepeatingValues() { @Test public void testUpperIndexOutOfBoundsException() { - try (HostColumnVector byteColumnVector = HostColumnVector.fromBytes(new byte[]{2, 3, 5})) { - assertThrows(AssertionError.class, () -> byteColumnVector.getByte(3)); - assertFalse(byteColumnVector.hasNulls()); + Consumer verify = (cv) -> { + assertThrows(AssertionError.class, () -> cv.getByte(3)); + assertFalse(cv.hasNulls()); + }; + try (HostColumnVector bcv = HostColumnVector.fromBytes(new byte[]{2, 3, 5})) { + verify.accept(bcv); + } + try (HostColumnVector bcv = ColumnBuilderHelper.fromBytes(true, new byte[]{2, 3, 5})) { + verify.accept(bcv); } } @Test public void testLowerIndexOutOfBoundsException() { - try (HostColumnVector byteColumnVector = HostColumnVector.fromBytes(new byte[]{2, 3, 5})) { - assertFalse(byteColumnVector.hasNulls()); - assertThrows(AssertionError.class, () -> byteColumnVector.getByte(-1)); + Consumer verify = (cv) -> { + assertFalse(cv.hasNulls()); + assertThrows(AssertionError.class, () -> cv.getByte(-1)); + }; + try (HostColumnVector bcv = HostColumnVector.fromBytes(new byte[]{2, 3, 5})) { + verify.accept(bcv); + } + try (HostColumnVector bcv = ColumnBuilderHelper.fromBytes(true, new byte[]{2, 3, 5})) { + verify.accept(bcv); } } @Test public void testAddingNullValues() { - try (HostColumnVector byteColumnVector = HostColumnVector.fromBoxedBytes( - new Byte[]{2, 3, 4, 5, 6, 7, null, null})) { - assertTrue(byteColumnVector.hasNulls()); - assertEquals(2, byteColumnVector.getNullCount()); + Consumer verify = (cv) -> { + assertTrue(cv.hasNulls()); + assertEquals(2, cv.getNullCount()); for (int i = 0; i < 6; i++) { - assertFalse(byteColumnVector.isNull(i)); + assertFalse(cv.isNull(i)); } - assertTrue(byteColumnVector.isNull(6)); - assertTrue(byteColumnVector.isNull(7)); + assertTrue(cv.isNull(6)); + assertTrue(cv.isNull(7)); + }; + try (HostColumnVector bcv = HostColumnVector.fromBoxedBytes( + new Byte[]{2, 3, 4, 5, 6, 7, null, null})) { + verify.accept(bcv); + } + try (HostColumnVector bcv = ColumnBuilderHelper.fromBoxedBytes(true, + new Byte[]{2, 3, 4, 5, 6, 7, null, null})) { + verify.accept(bcv); } } @Test public void testAddingUnsignedNullValues() { - try (HostColumnVector byteColumnVector = HostColumnVector.fromBoxedUnsignedBytes( - new Byte[]{2, 3, 4, 5, (byte)128, (byte)254, null, null})) { - assertTrue(byteColumnVector.hasNulls()); - assertEquals(2, byteColumnVector.getNullCount()); + Consumer verify = (cv) -> { + assertTrue(cv.hasNulls()); + assertEquals(2, cv.getNullCount()); for (int i = 0; i < 6; i++) { - assertFalse(byteColumnVector.isNull(i)); + assertFalse(cv.isNull(i)); } - assertEquals(128, Byte.toUnsignedInt(byteColumnVector.getByte(4))); - assertEquals(254, Byte.toUnsignedInt(byteColumnVector.getByte(5))); - assertTrue(byteColumnVector.isNull(6)); - assertTrue(byteColumnVector.isNull(7)); + assertEquals(128, Byte.toUnsignedInt(cv.getByte(4))); + assertEquals(254, Byte.toUnsignedInt(cv.getByte(5))); + assertTrue(cv.isNull(6)); + assertTrue(cv.isNull(7)); + }; + try (HostColumnVector bcv = HostColumnVector.fromBoxedUnsignedBytes( + new Byte[]{2, 3, 4, 5, (byte)128, (byte)254, null, null})) { + verify.accept(bcv); + } + try (HostColumnVector bcv = ColumnBuilderHelper.fromBoxedBytes(false, + new Byte[]{2, 3, 4, 5, (byte)128, (byte)254, null, null})) { + verify.accept(bcv); } } diff --git a/java/src/test/java/ai/rapids/cudf/ColumnBuilderHelper.java b/java/src/test/java/ai/rapids/cudf/ColumnBuilderHelper.java new file mode 100644 index 00000000000..263244b2413 --- /dev/null +++ b/java/src/test/java/ai/rapids/cudf/ColumnBuilderHelper.java @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Objects; +import java.util.function.Consumer; + +/** + * ColumnBuilderHelper helps to test ColumnBuilder with existed ColumnVector tests. + */ +public class ColumnBuilderHelper { + + public static HostColumnVector build( + HostColumnVector.DataType type, + int rows, + Consumer init) { + try (HostColumnVector.ColumnBuilder b = new HostColumnVector.ColumnBuilder(type, rows)) { + init.accept(b); + return b.build(); + } + } + + public static ColumnVector buildOnDevice( + HostColumnVector.DataType type, + int rows, + Consumer init) { + try (HostColumnVector.ColumnBuilder b = new HostColumnVector.ColumnBuilder(type, rows)) { + init.accept(b); + return b.buildAndPutOnDevice(); + } + } + + public static HostColumnVector fromBoxedBytes(boolean signed, Byte... values) { + DType dt = signed ? DType.INT8 : DType.UINT8; + return ColumnBuilderHelper.build( + new HostColumnVector.BasicType(true, dt), + values.length, + (b) -> { + for (Byte v : values) + if (v == null) b.appendNull(); + else b.append(v); + }); + } + + public static HostColumnVector fromBoxedDoubles(Double... values) { + return ColumnBuilderHelper.build( + new HostColumnVector.BasicType(true, DType.FLOAT64), + values.length, + (b) -> { + for (Double v : values) + if (v == null) b.appendNull(); + else b.append(v); + }); + } + + public static HostColumnVector fromBoxedInts(boolean signed, Integer... values) { + DType dt = signed ? DType.INT32 : DType.UINT32; + return ColumnBuilderHelper.build( + new HostColumnVector.BasicType(true, dt), + values.length, + (b) -> { + for (Integer v : values) + if (v == null) b.appendNull(); + else b.append(v); + }); + } + + public static HostColumnVector fromBoxedLongs(boolean signed, Long... values) { + DType dt = signed ? DType.INT64 : DType.UINT64; + return ColumnBuilderHelper.build( + new HostColumnVector.BasicType(true, dt), + values.length, + (b) -> { + for (Long v : values) + if (v == null) b.appendNull(); + else b.append(v); + }); + } + + public static HostColumnVector fromBytes(boolean signed, byte... values) { + DType dt = signed ? DType.INT8 : DType.UINT8; + return ColumnBuilderHelper.build( + new HostColumnVector.BasicType(false, dt), + values.length, + (b) -> { + for (byte v : values) b.append(v); + }); + } + + public static HostColumnVector fromDecimals(BigDecimal... values) { + // Simply copy from HostColumnVector.fromDecimals + BigDecimal maxDec = Arrays.stream(values).filter(Objects::nonNull) + .max(Comparator.comparingInt(BigDecimal::precision)) + .orElse(BigDecimal.ZERO); + int maxScale = Arrays.stream(values).filter(Objects::nonNull) + .map(decimal -> decimal.scale()) + .max(Comparator.naturalOrder()) + .orElse(0); + maxDec = maxDec.setScale(maxScale, RoundingMode.UNNECESSARY); + + return ColumnBuilderHelper.build( + new HostColumnVector.BasicType(true, DType.fromJavaBigDecimal(maxDec)), + values.length, + (b) -> { + for (BigDecimal v : values) + if (v == null) b.appendNull(); + else b.append(v); + }); + } + + public static HostColumnVector fromDoubles(double... values) { + return ColumnBuilderHelper.build( + new HostColumnVector.BasicType(false, DType.FLOAT64), + values.length, + (b) -> { + for (double v : values) b.append(v); + }); + } + + public static HostColumnVector fromInts(boolean signed, int... values) { + DType dt = signed ? DType.INT32 : DType.UINT32; + return ColumnBuilderHelper.build( + new HostColumnVector.BasicType(false, dt), + values.length, + (b) -> { + for (int v : values) b.append(v); + }); + } + + public static HostColumnVector fromLongs(boolean signed, long... values) { + DType dt = signed ? DType.INT64 : DType.UINT64; + return ColumnBuilderHelper.build( + new HostColumnVector.BasicType(false, dt), + values.length, + (b) -> { + for (long v : values) b.append(v); + }); + } +} diff --git a/java/src/test/java/ai/rapids/cudf/DecimalColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/DecimalColumnVectorTest.java index c2772520f57..994066c5df0 100644 --- a/java/src/test/java/ai/rapids/cudf/DecimalColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/DecimalColumnVectorTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,10 +22,12 @@ import org.junit.jupiter.api.Test; import java.math.BigDecimal; +import java.math.BigInteger; import java.math.RoundingMode; import java.util.Arrays; import java.util.Objects; import java.util.Random; +import java.util.function.Consumer; import static org.junit.jupiter.api.Assertions.*; @@ -33,9 +35,11 @@ public class DecimalColumnVectorTest extends CudfTestBase { private static final Random rdSeed = new Random(1234); private static final int dec32Scale = 4; private static final int dec64Scale = 10; + private static final int dec128Scale = 30; private static final BigDecimal[] decimal32Zoo = new BigDecimal[20]; private static final BigDecimal[] decimal64Zoo = new BigDecimal[20]; + private static final BigDecimal[] decimal128Zoo = new BigDecimal[20]; private static final int[] unscaledDec32Zoo = new int[decimal32Zoo.length]; private static final long[] unscaledDec64Zoo = new long[decimal64Zoo.length]; @@ -45,6 +49,9 @@ public class DecimalColumnVectorTest extends CudfTestBase { private final BigDecimal[] boundaryDecimal64 = new BigDecimal[]{ new BigDecimal("999999999999999999"), new BigDecimal("-999999999999999999")}; + private final BigDecimal[] boundaryDecimal128 = new BigDecimal[]{ + new BigDecimal("99999999999999999999999999999999999999"), new BigDecimal("-99999999999999999999999999999999999999")}; + private final BigDecimal[] overflowDecimal32 = new BigDecimal[]{ BigDecimal.valueOf(Integer.MAX_VALUE), BigDecimal.valueOf(Integer.MIN_VALUE)}; @@ -72,6 +79,12 @@ public static void setup() { } else { decimal64Zoo[i] = null; } + if (rdSeed.nextBoolean()) { + BigInteger unscaledVal = BigInteger.valueOf(rdSeed.nextLong()).multiply(BigInteger.valueOf(rdSeed.nextLong())); + decimal128Zoo[i] = new BigDecimal(unscaledVal, dec128Scale); + } else { + decimal128Zoo[i] = null; + } } } @@ -190,27 +203,44 @@ public void testDecimalGeneral() { @Test public void testDecimalFromDecimals() { - DecimalColumnVectorTest.testDecimalImpl(false, dec32Scale, decimal32Zoo); - DecimalColumnVectorTest.testDecimalImpl(true, dec64Scale, decimal64Zoo); - DecimalColumnVectorTest.testDecimalImpl(false, 0, boundaryDecimal32); - DecimalColumnVectorTest.testDecimalImpl(true, 0, boundaryDecimal64); + DecimalColumnVectorTest.testDecimalImpl(DType.DTypeEnum.DECIMAL32, dec32Scale, decimal32Zoo); + DecimalColumnVectorTest.testDecimalImpl(DType.DTypeEnum.DECIMAL64, dec64Scale, decimal64Zoo); + DecimalColumnVectorTest.testDecimalImpl(DType.DTypeEnum.DECIMAL128, dec128Scale, decimal128Zoo); + DecimalColumnVectorTest.testDecimalImpl(DType.DTypeEnum.DECIMAL32, 0, boundaryDecimal32); + DecimalColumnVectorTest.testDecimalImpl(DType.DTypeEnum.DECIMAL64, 0, boundaryDecimal64); + DecimalColumnVectorTest.testDecimalImpl(DType.DTypeEnum.DECIMAL128, 0, boundaryDecimal128); } - private static void testDecimalImpl(boolean isInt64, int scale, BigDecimal[] decimalZoo) { - try (ColumnVector cv = ColumnVector.fromDecimals(decimalZoo)) { - try (HostColumnVector hcv = cv.copyToHost()) { - assertEquals(-scale, hcv.getType().getScale()); - assertEquals(isInt64, hcv.getType().typeId == DType.DTypeEnum.DECIMAL64); - assertEquals(decimalZoo.length, hcv.rows); - for (int i = 0; i < decimalZoo.length; i++) { - assertEquals(decimalZoo[i] == null, hcv.isNull(i)); - if (decimalZoo[i] != null) { - assertEquals(decimalZoo[i].floatValue(), hcv.getBigDecimal(i).floatValue()); - long backValue = isInt64 ? hcv.getLong(i) : hcv.getInt(i); - assertEquals(decimalZoo[i].setScale(scale, RoundingMode.UNNECESSARY), BigDecimal.valueOf(backValue, scale)); + private static void testDecimalImpl(DType.DTypeEnum decimalType, int scale, BigDecimal[] decimalZoo) { + Consumer assertions = (hcv) -> { + assertEquals(-scale, hcv.getType().getScale()); + assertEquals(hcv.getType().typeId, decimalType); + assertEquals(decimalZoo.length, hcv.rows); + for (int i = 0; i < decimalZoo.length; i++) { + assertEquals(decimalZoo[i] == null, hcv.isNull(i)); + if (decimalZoo[i] != null) { + BigDecimal actual; + switch (decimalType) { + case DECIMAL32: + actual = BigDecimal.valueOf(hcv.getInt(i), scale); + break; + case DECIMAL64: + actual = BigDecimal.valueOf(hcv.getLong(i), scale); + break; + default: + actual = hcv.getBigDecimal(i); } + assertEquals(decimalZoo[i].subtract(actual).longValueExact(), 0L); } } + }; + try (ColumnVector cv = ColumnVector.fromDecimals(decimalZoo)) { + try (HostColumnVector hcv = cv.copyToHost()) { + assertions.accept(hcv); + } + } + try (HostColumnVector hcv = ColumnBuilderHelper.fromDecimals(decimalZoo)) { + assertions.accept(hcv); } } diff --git a/java/src/test/java/ai/rapids/cudf/DoubleColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/DoubleColumnVectorTest.java index d82565e1d2d..fa34429685e 100644 --- a/java/src/test/java/ai/rapids/cudf/DoubleColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/DoubleColumnVectorTest.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import org.junit.jupiter.api.Test; import java.util.Random; +import java.util.function.Consumer; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -40,34 +41,51 @@ public void testCreateColumnVectorBuilder() { @Test public void testArrayAllocation() { - try (HostColumnVector doubleColumnVector = HostColumnVector.fromDoubles(2.1, 3.02, 5.003)) { - assertFalse(doubleColumnVector.hasNulls()); - assertEqualsWithinPercentage(doubleColumnVector.getDouble(0), 2.1, 0.01); - assertEqualsWithinPercentage(doubleColumnVector.getDouble(1), 3.02, 0.01); - assertEqualsWithinPercentage(doubleColumnVector.getDouble(2), 5.003, 0.001); + Consumer verify = (cv) -> { + assertFalse(cv.hasNulls()); + assertEqualsWithinPercentage(cv.getDouble(0), 2.1, 0.01); + assertEqualsWithinPercentage(cv.getDouble(1), 3.02, 0.01); + assertEqualsWithinPercentage(cv.getDouble(2), 5.003, 0.001); + }; + try (HostColumnVector dcv = HostColumnVector.fromDoubles(2.1, 3.02, 5.003)) { + verify.accept(dcv); + } + try (HostColumnVector dcv = ColumnBuilderHelper.fromDoubles(2.1, 3.02, 5.003)) { + verify.accept(dcv); } } @Test public void testUpperIndexOutOfBoundsException() { - try (HostColumnVector doubleColumnVector = HostColumnVector.fromDoubles(2.1, 3.02, 5.003)) { - assertThrows(AssertionError.class, () -> doubleColumnVector.getDouble(3)); - assertFalse(doubleColumnVector.hasNulls()); + Consumer verify = (cv) -> { + assertThrows(AssertionError.class, () -> cv.getDouble(3)); + assertFalse(cv.hasNulls()); + }; + try (HostColumnVector dcv = HostColumnVector.fromDoubles(2.1, 3.02, 5.003)) { + verify.accept(dcv); + } + try (HostColumnVector dcv = ColumnBuilderHelper.fromDoubles(2.1, 3.02, 5.003)) { + verify.accept(dcv); } } @Test public void testLowerIndexOutOfBoundsException() { - try (HostColumnVector doubleColumnVector = HostColumnVector.fromDoubles(2.1, 3.02, 5.003)) { - assertFalse(doubleColumnVector.hasNulls()); - assertThrows(AssertionError.class, () -> doubleColumnVector.getDouble(-1)); + Consumer verify = (cv) -> { + assertFalse(cv.hasNulls()); + assertThrows(AssertionError.class, () -> cv.getDouble(-1)); + }; + try (HostColumnVector dcv = HostColumnVector.fromDoubles(2.1, 3.02, 5.003)) { + verify.accept(dcv); + } + try (HostColumnVector dcv = ColumnBuilderHelper.fromDoubles(2.1, 3.02, 5.003)) { + verify.accept(dcv); } } @Test public void testAddingNullValues() { - try (HostColumnVector cv = - HostColumnVector.fromBoxedDoubles(2.0, 3.0, 4.0, 5.0, 6.0, 7.0, null, null)) { + Consumer verify = (cv) -> { assertTrue(cv.hasNulls()); assertEquals(2, cv.getNullCount()); for (int i = 0; i < 6; i++) { @@ -75,6 +93,14 @@ public void testAddingNullValues() { } assertTrue(cv.isNull(6)); assertTrue(cv.isNull(7)); + }; + try (HostColumnVector dcv = + HostColumnVector.fromBoxedDoubles(2.0, 3.0, 4.0, 5.0, 6.0, 7.0, null, null)) { + verify.accept(dcv); + } + try (HostColumnVector dcv = ColumnBuilderHelper.fromBoxedDoubles( + 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, null, null)) { + verify.accept(dcv); } } diff --git a/java/src/test/java/ai/rapids/cudf/IntColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/IntColumnVectorTest.java index 2fb8164534b..7d6311fb24c 100644 --- a/java/src/test/java/ai/rapids/cudf/IntColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/IntColumnVectorTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; import java.util.Random; +import java.util.function.Consumer; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -34,47 +35,75 @@ public void testCreateColumnVectorBuilder() { try (ColumnVector intColumnVector = ColumnVector.build(DType.INT32, 3, (b) -> b.append(1))) { assertFalse(intColumnVector.hasNulls()); } + try (ColumnVector intColumnVector = ColumnBuilderHelper.buildOnDevice( + new HostColumnVector.BasicType(true, DType.INT32), 3, (b) -> b.append(1))) { + assertFalse(intColumnVector.hasNulls()); + } } @Test public void testArrayAllocation() { - try (HostColumnVector intColumnVector = HostColumnVector.fromInts(2, 3, 5)) { - assertFalse(intColumnVector.hasNulls()); - assertEquals(intColumnVector.getInt(0), 2); - assertEquals(intColumnVector.getInt(1), 3); - assertEquals(intColumnVector.getInt(2), 5); + Consumer verify = (cv) -> { + assertFalse(cv.hasNulls()); + assertEquals(cv.getInt(0), 2); + assertEquals(cv.getInt(1), 3); + assertEquals(cv.getInt(2), 5); + }; + try (HostColumnVector cv = HostColumnVector.fromInts(2, 3, 5)) { + verify.accept(cv); + } + try (HostColumnVector cv = ColumnBuilderHelper.fromInts(true, 2, 3, 5)) { + verify.accept(cv); } } @Test public void testUnsignedArrayAllocation() { - try (HostColumnVector v = HostColumnVector.fromUnsignedInts(0xfedcba98, 0x80000000, 5)) { - assertFalse(v.hasNulls()); - assertEquals(0xfedcba98L, Integer.toUnsignedLong(v.getInt(0))); - assertEquals(0x80000000L, Integer.toUnsignedLong(v.getInt(1))); - assertEquals(5, Integer.toUnsignedLong(v.getInt(2))); + Consumer verify = (cv) -> { + assertFalse(cv.hasNulls()); + assertEquals(0xfedcba98L, Integer.toUnsignedLong(cv.getInt(0))); + assertEquals(0x80000000L, Integer.toUnsignedLong(cv.getInt(1))); + assertEquals(5, Integer.toUnsignedLong(cv.getInt(2))); + }; + try (HostColumnVector cv = HostColumnVector.fromUnsignedInts(0xfedcba98, 0x80000000, 5)) { + verify.accept(cv); + } + try (HostColumnVector cv = ColumnBuilderHelper.fromInts(false, 0xfedcba98, 0x80000000, 5)) { + verify.accept(cv); } } @Test public void testUpperIndexOutOfBoundsException() { - try (HostColumnVector intColumnVector = HostColumnVector.fromInts(2, 3, 5)) { - assertThrows(AssertionError.class, () -> intColumnVector.getInt(3)); - assertFalse(intColumnVector.hasNulls()); + Consumer verify = (cv) -> { + assertThrows(AssertionError.class, () -> cv.getInt(3)); + assertFalse(cv.hasNulls()); + }; + try (HostColumnVector icv = HostColumnVector.fromInts(2, 3, 5)) { + verify.accept(icv); + } + try (HostColumnVector icv = ColumnBuilderHelper.fromInts(true, 2, 3, 5)) { + verify.accept(icv); } } @Test public void testLowerIndexOutOfBoundsException() { - try (HostColumnVector intColumnVector = HostColumnVector.fromInts(2, 3, 5)) { - assertFalse(intColumnVector.hasNulls()); - assertThrows(AssertionError.class, () -> intColumnVector.getInt(-1)); + Consumer verify = (cv) -> { + assertFalse(cv.hasNulls()); + assertThrows(AssertionError.class, () -> cv.getInt(-1)); + }; + try (HostColumnVector icv = HostColumnVector.fromInts(2, 3, 5)) { + verify.accept(icv); + } + try (HostColumnVector icv = ColumnBuilderHelper.fromInts(true, 2, 3, 5)) { + verify.accept(icv); } } @Test public void testAddingNullValues() { - try (HostColumnVector cv = HostColumnVector.fromBoxedInts(2, 3, 4, 5, 6, 7, null, null)) { + Consumer verify = (cv) -> { assertTrue(cv.hasNulls()); assertEquals(2, cv.getNullCount()); for (int i = 0; i < 6; i++) { @@ -82,13 +111,18 @@ public void testAddingNullValues() { } assertTrue(cv.isNull(6)); assertTrue(cv.isNull(7)); + }; + try (HostColumnVector cv = HostColumnVector.fromBoxedInts(2, 3, 4, 5, 6, 7, null, null)) { + verify.accept(cv); + } + try (HostColumnVector cv = ColumnBuilderHelper.fromBoxedInts(true, 2, 3, 4, 5, 6, 7, null, null)) { + verify.accept(cv); } } @Test public void testAddingUnsignedNullValues() { - try (HostColumnVector cv = HostColumnVector.fromBoxedUnsignedInts( - 2, 3, 4, 5, 0xfedbca98, 0x80000000, null, null)) { + Consumer verify = (cv) -> { assertTrue(cv.hasNulls()); assertEquals(2, cv.getNullCount()); for (int i = 0; i < 6; i++) { @@ -98,6 +132,14 @@ public void testAddingUnsignedNullValues() { assertEquals(0x80000000L, Integer.toUnsignedLong(cv.getInt(5))); assertTrue(cv.isNull(6)); assertTrue(cv.isNull(7)); + }; + try (HostColumnVector cv = HostColumnVector.fromBoxedUnsignedInts( + 2, 3, 4, 5, 0xfedbca98, 0x80000000, null, null)) { + verify.accept(cv); + } + try (HostColumnVector cv = ColumnBuilderHelper.fromBoxedInts(false, + 2, 3, 4, 5, 0xfedbca98, 0x80000000, null, null)) { + verify.accept(cv); } } diff --git a/java/src/test/java/ai/rapids/cudf/LongColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/LongColumnVectorTest.java index 43c2b5a99c2..193992f5304 100644 --- a/java/src/test/java/ai/rapids/cudf/LongColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/LongColumnVectorTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; import java.util.Random; +import java.util.function.Consumer; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -38,46 +39,71 @@ public void testCreateColumnVectorBuilder() { @Test public void testArrayAllocation() { - try (HostColumnVector longColumnVector = HostColumnVector.fromLongs(2L, 3L, 5L)) { - assertFalse(longColumnVector.hasNulls()); - assertEquals(longColumnVector.getLong(0), 2); - assertEquals(longColumnVector.getLong(1), 3); - assertEquals(longColumnVector.getLong(2), 5); + Consumer verify = (cv) -> { + assertFalse(cv.hasNulls()); + assertEquals(cv.getLong(0), 2); + assertEquals(cv.getLong(1), 3); + assertEquals(cv.getLong(2), 5); + }; + try (HostColumnVector lcv = HostColumnVector.fromLongs(2L, 3L, 5L)) { + verify.accept(lcv); + } + try (HostColumnVector lcv = ColumnBuilderHelper.fromLongs(true,2L, 3L, 5L)) { + verify.accept(lcv); } } @Test public void testUnsignedArrayAllocation() { - try (HostColumnVector longColumnVector = HostColumnVector.fromUnsignedLongs( - 0xfedcba9876543210L, 0x8000000000000000L, 5L)) { - assertFalse(longColumnVector.hasNulls()); + Consumer verify = (cv) -> { + assertFalse(cv.hasNulls()); assertEquals(Long.toUnsignedString(0xfedcba9876543210L), - Long.toUnsignedString(longColumnVector.getLong(0))); + Long.toUnsignedString(cv.getLong(0))); assertEquals(Long.toUnsignedString(0x8000000000000000L), - Long.toUnsignedString(longColumnVector.getLong(1))); - assertEquals(5L, longColumnVector.getLong(2)); + Long.toUnsignedString(cv.getLong(1))); + assertEquals(5L, cv.getLong(2)); + }; + try (HostColumnVector lcv = HostColumnVector.fromUnsignedLongs( + 0xfedcba9876543210L, 0x8000000000000000L, 5L)) { + verify.accept(lcv); + } + try (HostColumnVector lcv = ColumnBuilderHelper.fromLongs(false, + 0xfedcba9876543210L, 0x8000000000000000L, 5L)) { + verify.accept(lcv); } } @Test public void testUpperIndexOutOfBoundsException() { - try (HostColumnVector longColumnVector = HostColumnVector.fromLongs(2L, 3L, 5L)) { - assertThrows(AssertionError.class, () -> longColumnVector.getLong(3)); - assertFalse(longColumnVector.hasNulls()); + Consumer verify = (cv) -> { + assertThrows(AssertionError.class, () -> cv.getLong(3)); + assertFalse(cv.hasNulls()); + }; + try (HostColumnVector lcv = HostColumnVector.fromLongs(2L, 3L, 5L)) { + verify.accept(lcv); + } + try (HostColumnVector lcv = ColumnBuilderHelper.fromLongs(true, 2L, 3L, 5L)) { + verify.accept(lcv); } } @Test public void testLowerIndexOutOfBoundsException() { - try (HostColumnVector longColumnVector = HostColumnVector.fromLongs(2L, 3L, 5L)) { - assertFalse(longColumnVector.hasNulls()); - assertThrows(AssertionError.class, () -> longColumnVector.getLong(-1)); + Consumer verify = (cv) -> { + assertFalse(cv.hasNulls()); + assertThrows(AssertionError.class, () -> cv.getLong(-1)); + }; + try (HostColumnVector lcv = HostColumnVector.fromLongs(2L, 3L, 5L)) { + verify.accept(lcv); + } + try (HostColumnVector lcv = ColumnBuilderHelper.fromLongs(true, 2L, 3L, 5L)) { + verify.accept(lcv); } } @Test public void testAddingNullValues() { - try (HostColumnVector cv = HostColumnVector.fromBoxedLongs(2L, 3L, 4L, 5L, 6L, 7L, null, null)) { + Consumer verify = (cv) -> { assertTrue(cv.hasNulls()); assertEquals(2, cv.getNullCount()); for (int i = 0; i < 6; i++) { @@ -85,13 +111,19 @@ public void testAddingNullValues() { } assertTrue(cv.isNull(6)); assertTrue(cv.isNull(7)); + }; + try (HostColumnVector lcv = HostColumnVector.fromBoxedLongs(2L, 3L, 4L, 5L, 6L, 7L, null, null)) { + verify.accept(lcv); + } + try (HostColumnVector lcv = ColumnBuilderHelper.fromBoxedLongs(true, + 2L, 3L, 4L, 5L, 6L, 7L, null, null)) { + verify.accept(lcv); } } @Test public void testAddingUnsignedNullValues() { - try (HostColumnVector cv = HostColumnVector.fromBoxedUnsignedLongs( - 2L, 3L, 4L, 5L, 0xfedcba9876543210L, 0x8000000000000000L, null, null)) { + Consumer verify = (cv) -> { assertTrue(cv.hasNulls()); assertEquals(2, cv.getNullCount()); for (int i = 0; i < 6; i++) { @@ -103,6 +135,14 @@ public void testAddingUnsignedNullValues() { Long.toUnsignedString(cv.getLong(5))); assertTrue(cv.isNull(6)); assertTrue(cv.isNull(7)); + }; + try (HostColumnVector lcv = HostColumnVector.fromBoxedUnsignedLongs( + 2L, 3L, 4L, 5L, 0xfedcba9876543210L, 0x8000000000000000L, null, null)) { + verify.accept(lcv); + } + try (HostColumnVector lcv = ColumnBuilderHelper.fromBoxedLongs(false, + 2L, 3L, 4L, 5L, 0xfedcba9876543210L, 0x8000000000000000L, null, null)) { + verify.accept(lcv); } } From 1f672109a0dd0040a25f9eec85e59e7a0168b1fd Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Mon, 24 Jan 2022 18:56:56 +0800 Subject: [PATCH 9/9] merge master --- java/src/main/java/ai/rapids/cudf/HostColumnVector.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index 18973b25abe..3abc6db385d 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -1303,14 +1303,13 @@ public ColumnBuilder append(BigDecimal value) { } public ColumnBuilder append(BigInteger unscaledVal) { - growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); + growFixedWidthBuffersAndRows(); assert currentIndex < rows; if (type.typeId == DType.DTypeEnum.DECIMAL32) { data.setInt(currentIndex++ << bitShiftBySize, unscaledVal.intValueExact()); } else if (type.typeId == DType.DTypeEnum.DECIMAL64) { data.setLong(currentIndex++ << bitShiftBySize, unscaledVal.longValueExact()); } else if (type.typeId == DType.DTypeEnum.DECIMAL128) { - assert currentIndex < rows; byte[] unscaledValueBytes = unscaledVal.toByteArray(); byte[] result = convertDecimal128FromJavaToCudf(unscaledValueBytes); data.setBytes(currentIndex++ << bitShiftBySize, result, 0, result.length);