Skip to content

Commit

Permalink
Merge pull request #2568 from rapidsai/branch-0.9
Browse files Browse the repository at this point in the history
[gpuCI] Auto-merge branch-0.9 to branch-0.10 [skip ci]
  • Loading branch information
GPUtester authored Aug 13, 2019
2 parents 34d6523 + f642c2c commit 080c05d
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 50 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@
- PR #2557 fix cudautils import in string.py
- PR #2521 Fix casting datetimes from/to the same resolution
- PR #2560 Remove duplicate `dlpack` definition in conda recipe
- PR #2567 Fix ColumnVector.fromScalar issues while dealing with null scalars


# cuDF 0.8.0 (27 June 2019)
Expand Down
116 changes: 67 additions & 49 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -1056,7 +1056,8 @@ public ColumnVector[] slice(ColumnVector indices) {
* @param scalar - Scalar value to replace row with
* @throws IllegalArgumentException
*/
private void fill(Scalar scalar) throws IllegalArgumentException {
// package private for testing
void fill(Scalar scalar) throws IllegalArgumentException {
assert scalar.getType() == this.getType();

if (this.getType() == DType.STRING || this.getType() == DType.STRING_CATEGORY){
Expand All @@ -1067,45 +1068,73 @@ private void fill(Scalar scalar) throws IllegalArgumentException {
return; // no rows to fill
}

checkHasDeviceData();

BufferEncapsulator<DeviceMemoryBuffer> newDeviceData = null;
boolean needsCleanup = true;

try {
if (!scalar.isValid() && this.offHeap.getDeviceData().valid == null) {
// scalar is null, we allow filling with nulls
// create validity mask, since we didn't have one before
long validitySizeInBytes = BitVectorHelper.getValidityAllocationSizeInBytes(rows);
newDeviceData = new BufferEncapsulator<DeviceMemoryBuffer>(
this.offHeap.getDeviceData().data,
DeviceMemoryBuffer.allocate(validitySizeInBytes),
null);
Cuda.memset(newDeviceData.valid.getAddress(), (byte) 0x00, validitySizeInBytes);
} else if (this.offHeap.getDeviceData() != null) {
newDeviceData = this.offHeap.getDeviceData();
needsCleanup = false; // the data came from upstream
}
if (!scalar.isValid()) {
if (getNullCount() == getRowCount()) {
//current vector has all nulls, and we are trying to set it to null.
return;
}
this.nullCount = rows;

if (offHeap.getDeviceData().valid == null) {
long validitySizeInBytes = BitVectorHelper.getValidityAllocationSizeInBytes(rows);
// scalar is null and vector doesn't have a validity mask. Create a validity mask.
newDeviceData = new BufferEncapsulator<DeviceMemoryBuffer>(
this.offHeap.getDeviceData().data,
DeviceMemoryBuffer.allocate(validitySizeInBytes),
null);
this.offHeap.setDeviceData(newDeviceData);
} else {
newDeviceData = this.offHeap.getDeviceData();
}

// set the validity pointer
cudfColumnViewAugmented(
this.getNativeCudfColumnAddress(),
newDeviceData.data.address,
newDeviceData.valid != null ?
newDeviceData.valid.address : 0,
(int) this.getRowCount(),
this.getType().nativeId,
(int) this.getNullCount(),
this.getTimeUnit().getNativeId());
// the buffer encapsulator is the owner of newDeviceData, no need to clear
needsCleanup = false;

this.offHeap.setDeviceData(newDeviceData);
Cuda.memset(newDeviceData.valid.getAddress(), (byte) 0x00,
BitVectorHelper.getValidityLengthInBytes(rows));

// the column vector has the reference the BufferEncapsulator
// and can close later in case of failure
needsCleanup = false;
// set the validity pointer
cudfColumnViewAugmented(
this.getNativeCudfColumnAddress(),
newDeviceData.data.address,
newDeviceData.valid.address,
(int) this.getRowCount(),
this.getType().nativeId,
(int) this.nullCount,
this.getTimeUnit().getNativeId());

Cudf.fill(this, scalar);
} else {
this.nullCount = 0;
newDeviceData = this.offHeap.getDeviceData();
needsCleanup = false; // the data came from upstream

// the null_count could have changed, set it java-side
this.nullCount = getNullCount(getNativeCudfColumnAddress());
// if we are now setting the vector to a non-null, we need to
// close out the validity vector
if (newDeviceData.valid != null){
newDeviceData.valid.close();
newDeviceData = new BufferEncapsulator<DeviceMemoryBuffer>(
newDeviceData.data, null, null);
this.offHeap.setDeviceData(newDeviceData);
}

// set the validity pointer
cudfColumnViewAugmented(
this.getNativeCudfColumnAddress(),
newDeviceData.data.address,
0,
(int) this.getRowCount(),
this.getType().nativeId,
(int) nullCount,
this.getTimeUnit().getNativeId());

Cudf.fill(this, scalar);
}

// at this stage, host offHeap is no longer meaningful
// if we had hostData, reset it with a fresh copy from device
Expand Down Expand Up @@ -2030,40 +2059,32 @@ public static ColumnVector fromScalar(Scalar scalar, int rows) {
throw new IllegalArgumentException("STRING and STRING_CATEGORY are not supported scalars");
}
DeviceMemoryBuffer dataBuffer = null;
DeviceMemoryBuffer validityBuffer = null;
ColumnVector cv = null;
boolean needsCleanup = true;

try {
dataBuffer = DeviceMemoryBuffer.allocate(scalar.type.sizeInBytes * rows);

long validitySizeInBytes = BitVectorHelper.getValidityAllocationSizeInBytes(rows);

if (!scalar.isValid()) {
validityBuffer = DeviceMemoryBuffer.allocate(validitySizeInBytes);
// ensure this is all valid before calling cudf::fill, as before that
// the column is all valid
Cuda.memset(validityBuffer.getAddress(), (byte)0xFF, validitySizeInBytes);
}

cv = new ColumnVector(
scalar.getType(),
scalar.getTimeUnit(),
rows,
0,
dataBuffer,
validityBuffer,
null,
null, false);

// null this out as cv is the owner, and will be closed
// when cv closes in case of failure
dataBuffer = null;

cudfColumnViewAugmented(
cv.getNativeCudfColumnAddress(),
cv.offHeap.getDeviceData().data.address,
cv.offHeap.getDeviceData().valid != null ?
cv.offHeap.getDeviceData().valid.address :
0,
0,
(int) cv.getRowCount(),
cv.getType().nativeId,
(int) cv.getNullCount(),
0,
cv.getTimeUnit().getNativeId());

cv.fill(scalar);
Expand All @@ -2075,9 +2096,6 @@ public static ColumnVector fromScalar(Scalar scalar, int rows) {
if (dataBuffer != null) {
dataBuffer.close();
}
if (validityBuffer != null) {
validityBuffer.close();
}
if (cv != null) {
cv.close();
}
Expand Down
41 changes: 40 additions & 1 deletion java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ void testFromScalarFloat() {
}

@Test
void testFromScalarInteger() {
void testFromNullScalarInteger() {
assumeTrue(Cuda.isEnvCompatibleForTesting());
try (ColumnVector input = ColumnVector.fromScalar(Scalar.fromNull(DType.INT32), 6);
ColumnVector expected = ColumnVector.fromBoxedInts(null, null, null, null, null, null)) {
Expand All @@ -168,6 +168,45 @@ void testFromScalarInteger() {
}
}

@Test
void testSetToNullScalarInteger() {
assumeTrue(Cuda.isEnvCompatibleForTesting());
try (ColumnVector input = ColumnVector.fromScalar(Scalar.fromInt(123), 6);
ColumnVector expected = ColumnVector.fromBoxedInts(null, null, null, null, null, null)) {
input.fill(Scalar.fromNull(DType.INT32));
assertEquals(input.getNullCount(), expected.getNullCount());
assertColumnsAreEqual(input, expected);
}
}

@Test
void testSetToNullScalarByte() {
assumeTrue(Cuda.isEnvCompatibleForTesting());
int numNulls = 3000;
try (ColumnVector input = ColumnVector.fromScalar(Scalar.fromNull(DType.INT8), numNulls)) {
assertEquals(input.getNullCount(), numNulls);
input.ensureOnHost();
for (int i = 0; i < numNulls; i++){
assertTrue(input.isNull(i));
}
}
}

@Test
void testSetToNullThenBackScalarByte() {
assumeTrue(Cuda.isEnvCompatibleForTesting());
int numNulls = 3000;
try (ColumnVector input = ColumnVector.fromScalar(Scalar.fromNull(DType.INT8), numNulls)) {
assertEquals(input.getNullCount(), numNulls);
input.fill(Scalar.fromByte((byte)5));
assertEquals(input.getNullCount(), 0);
input.ensureOnHost();
for (int i = 0; i < numNulls; i++){
assertFalse(input.isNull(i));
}
}
}

@Test
void testFromScalarStringThrows() {
assumeTrue(Cuda.isEnvCompatibleForTesting());
Expand Down

0 comments on commit 080c05d

Please sign in to comment.