Skip to content

Commit

Permalink
Java support for casting of nested child columns (#7417)
Browse files Browse the repository at this point in the history
This PR adds a couple of very specialized methods that help us cast columns inside nested types.

Authors:
  - Raza Jafri (@razajafri)

Approvers:
  - Robert (Bobby) Evans (@revans2)
  - Jason Lowe (@jlowe)
  - MithunR (@mythrocks)

URL: #7417
  • Loading branch information
razajafri authored Mar 8, 2021
1 parent 9618a81 commit 9017f22
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 99 deletions.
10 changes: 5 additions & 5 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,16 @@ private static long getColumnViewFromColumn(long nativePointer) {
}
}


private static long initViewHandle(DType type, int rows, int nc, DeviceMemoryBuffer dataBuffer,
DeviceMemoryBuffer validityBuffer,
DeviceMemoryBuffer offsetBuffer, long[] childHandles) {
static long initViewHandle(DType type, int rows, int nc,
BaseDeviceMemoryBuffer dataBuffer,
BaseDeviceMemoryBuffer validityBuffer,
BaseDeviceMemoryBuffer offsetBuffer, long[] childHandles) {
long cd = dataBuffer == null ? 0 : dataBuffer.address;
long cdSize = dataBuffer == null ? 0 : dataBuffer.length;
long od = offsetBuffer == null ? 0 : offsetBuffer.address;
long vd = validityBuffer == null ? 0 : validityBuffer.address;
return makeCudfColumnView(type.typeId.getNativeId(), type.getScale(), cd, cdSize,
od, vd, nc, rows, childHandles) ;
od, vd, nc, rows, childHandles);
}

static ColumnVector fromViewWithContiguousAllocation(long columnViewAddress, DeviceMemoryBuffer buffer) {
Expand Down
144 changes: 141 additions & 3 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@

package ai.rapids.cudf;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.*;
import java.util.stream.IntStream;

import static ai.rapids.cudf.HostColumnVector.OFFSET_SIZE;

Expand Down Expand Up @@ -49,6 +48,65 @@ protected ColumnView(long address) {
this.nullCount = ColumnView.getNativeNullCount(viewHandle);
}

/**
 * Builds a non-owning view of a nested column on top of data that is already on the device.
 * No reference counts are incremented and none of the buffers passed in become owned by the
 * view, so the view is only usable while those buffers stay open. Closing a buffer before
 * closing this ColumnView results in undefined behavior.
 *
 * Call {@link ColumnView#copyToColumnVector} when an owning copy is required.
 *
 * @param type           the type of the column; asserted to be a nested type
 * @param rows           the number of rows in the column
 * @param nullCount      the number of nulls, or empty when it is not known
 * @param validityBuffer an optional validity buffer. Must be provided if nullCount != 0.
 *                       Ownership of this buffer does not change
 * @param offsetBuffer   the offsets buffer of the nested column. Ownership of this buffer
 *                       does not change
 * @param children       the views of the child columns
 */
public ColumnView(DType type, long rows, Optional<Long> nullCount,
    BaseDeviceMemoryBuffer validityBuffer,
    BaseDeviceMemoryBuffer offsetBuffer, ColumnView[] children) {
  // A nested column carries no data buffer of its own, hence the null data argument.
  this(type, (int) rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(),
      null, validityBuffer, offsetBuffer, children);
  assert type.isNestedType();
  // A known null count has to fit in an int; an unknown one is always acceptable.
  assert !nullCount.isPresent() || nullCount.get() <= Integer.MAX_VALUE;
}

/**
 * Builds a non-owning view of a non-nested column on top of data that is already on the
 * device. No reference counts are incremented and none of the buffers passed in become owned
 * by the view, so the view is only usable while those buffers stay open. Closing a buffer
 * before closing this ColumnView results in undefined behavior.
 *
 * Call {@link ColumnView#copyToColumnVector} when an owning copy is required.
 *
 * @param type           the type of the column; asserted not to be a nested type
 * @param rows           the number of rows in the column
 * @param nullCount      the number of nulls, or empty when it is not known
 * @param dataBuffer     the buffer holding the column's data. Ownership of this buffer does
 *                       not change
 * @param validityBuffer an optional validity buffer. Must be provided if nullCount != 0.
 *                       Ownership of this buffer does not change
 */
public ColumnView(DType type, long rows, Optional<Long> nullCount,
    BaseDeviceMemoryBuffer dataBuffer,
    BaseDeviceMemoryBuffer validityBuffer) {
  // Non-nested columns are created without an offsets buffer and without children.
  this(type, (int) rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(),
      dataBuffer, validityBuffer, null, null);
  assert !type.isNestedType();
  // A known null count has to fit in an int; an unknown one is always acceptable.
  assert !nullCount.isPresent() || nullCount.get() <= Integer.MAX_VALUE;
}

/**
 * Shared implementation behind the public buffer-based constructors: builds the native view
 * handle from the supplied buffers and child views and wraps it.
 *
 * @param type           the type of the column
 * @param rows           the number of rows in the column
 * @param nullCount      the number of nulls in the column
 * @param dataBuffer     the data buffer, or null when the column has none
 * @param validityBuffer the validity buffer, or null when the column has none
 * @param offsetBuffer   the offsets buffer, or null when the column has none
 * @param children       the child views, or null when the column has no children
 */
private ColumnView(DType type, long rows, int nullCount,
    BaseDeviceMemoryBuffer dataBuffer, BaseDeviceMemoryBuffer validityBuffer,
    BaseDeviceMemoryBuffer offsetBuffer, ColumnView[] children) {
  // The non-nested public constructor passes null children; streaming over a null array
  // would throw a NullPointerException before the native view could be created.
  // NOTE(review): a null handle array is forwarded for "no children" - confirm the native
  // view factory accepts null for types that have no children.
  this(ColumnVector.initViewHandle(type, (int) rows, nullCount, dataBuffer, validityBuffer,
      offsetBuffer,
      children == null
          ? null
          : Arrays.stream(children).mapToLong(c -> c.getNativeView()).toArray()));
}

/** Creates a ColumnVector from a column view handle
* @return a new ColumnVector
Expand Down Expand Up @@ -1296,6 +1354,86 @@ public ColumnVector castTo(DType type) {
return new ColumnVector(castTo(getNativeView(), type.typeId.getNativeId(), type.getScale()));
}

/**
 * Returns a view of this nested column in which the children at the given indices are
 * replaced by the given views. Children whose index is not listed are carried over unchanged.
 *
 * Every replacement must have the same number of rows as the child it replaces, otherwise
 * the parent (for example the offsets of a list) could refer to elements outside of the new
 * child. A mismatch is rejected with an IllegalArgumentException.
 *
 * Note: the returned ColumnView owns no data. It must not outlive this column or any of the
 * replacement views it points to.
 *
 * Ex: List<Int> list = col{{1,3}, {9,3,5}}
 *
 * validNewChild = col{8, 3, 9, 2, 0}
 * list.replaceChildrenWithViews(new int[]{0}, new ColumnView[]{validNewChild})
 *     => col{{8, 3}, {9, 2, 0}}
 *
 * invalidNewChild = col{3, 2}                  // fewer rows than the old child
 * list.replaceChildrenWithViews(new int[]{0}, new ColumnView[]{invalidNewChild})
 *     => throws IllegalArgumentException
 *
 * invalidNewChild = col{8, 3, 9, 2, 0, 0, 7}   // more rows than the old child
 * list.replaceChildrenWithViews(new int[]{0}, new ColumnView[]{invalidNewChild})
 *     => throws IllegalArgumentException
 *
 * @param indices the indices of the children to replace; must not contain duplicates. For a
 *                LIST column exactly one index is expected
 * @param views   the replacement views, where views[i] replaces the child at indices[i]. A
 *                null entry keeps the existing child at that index
 * @return a new ColumnView with the requested children replaced
 * @throws IllegalArgumentException if the two arrays differ in length, an index is repeated,
 *         an index does not refer to an existing child, or a replacement's row count differs
 *         from the row count of the child it replaces
 */
public ColumnView replaceChildrenWithViews(int[] indices,
    ColumnView[] views) {
  assert (type.isNestedType());
  // DType is an object type, so compare by value rather than by reference.
  if (type.equals(DType.LIST)) {
    assert (indices.length == 1);
  }
  // Checked explicitly (not asserted) so callers get the documented exception whether or not
  // assertions are enabled.
  if (indices.length != views.length) {
    throw new IllegalArgumentException("The indices size and children size should match");
  }

  Map<Integer, ColumnView> replacements = new HashMap<>();
  for (int i = 0; i < indices.length; i++) {
    if (replacements.containsKey(indices[i])) {
      throw new IllegalArgumentException("Duplicate mapping found for replacing child index");
    }
    replacements.put(indices[i], views[i]);
  }

  int numChildren = getNumChildren();
  List<ColumnView> newChildren = new ArrayList<>(numChildren);
  for (int i = 0; i < numChildren; i++) {
    // Removing as we go leaves only the indices that match no child in the map.
    ColumnView replacement = replacements.remove(i);
    // NOTE(review): the views returned by getChildColumnView are not closed here - confirm
    // who owns them.
    ColumnView current = getChildColumnView(i);
    if (replacement == null) {
      newChildren.add(current);
    } else {
      if (current.getRowCount() != replacement.getRowCount()) {
        throw new IllegalArgumentException("Child row count doesn't match the old child");
      }
      newChildren.add(replacement);
    }
  }
  if (!replacements.isEmpty()) {
    throw new IllegalArgumentException("One or more invalid child indices passed to be replaced");
  }
  return new ColumnView(type, getRowCount(), Optional.of(getNullCount()), getValid(),
      getOffsets(), newChildren.toArray(new ColumnView[0]));
}

/**
 * Takes this list column and returns a view of it in which the list's child column is
 * replaced by the given view. The replacement must have the same number of rows as the child
 * it replaces, otherwise the list's offsets can point to elements outside of the new child.
 *
 * Note: the returned ColumnView owns no data. It must not outlive this column or the child
 * view it points to.
 *
 * Ex: List<Int> list = col{{1,3}, {9,3,5}}
 *
 * validNewChild = col{8, 3, 9, 2, 0}
 * list.replaceListChild(validNewChild) => col{{8, 3}, {9, 2, 0}}
 *
 * invalidNewChild = col{3, 2}                  // fewer rows than the old child: not allowed
 * invalidNewChild = col{8, 3, 9, 2, 0, 0, 7}   // more rows than the old child: not allowed
 *
 * @param child the view to use as the list's child
 * @return a new ColumnView of this list with its child replaced
 * @throws IllegalArgumentException if {@link #replaceChildrenWithViews} rejects the replacement
 */
public ColumnView replaceListChild(ColumnView child) {
  // DType is an object type, so compare by value rather than by reference.
  assert (type.equals(DType.LIST));
  // replaceChildrenWithViews only accepts indices in [0, getNumChildren()). A list exposes a
  // single replaceable child, so that child is at index 0; index 1 would always be rejected
  // as an invalid child index.
  return replaceChildrenWithViews(new int[]{0}, new ColumnView[]{child});
}

/**
* Zero-copy cast between types with the same underlying representation.
*
Expand Down
91 changes: 0 additions & 91 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"


extern "C" {

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, jclass,
Expand Down Expand Up @@ -315,96 +314,6 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeEmptyCudfColumn(JNI
CATCH_STD(env, 0);
}

// Creates a numeric cudf column of the requested type, size and mask state and hands the raw
// column pointer back to Java as a jlong. The caller becomes responsible for the column.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeNumericCudfColumn(
    JNIEnv *env, jobject j_object, jint j_type, jint j_size, jint j_mask_state) {

  // Empty columns are not accepted here; 0 is the value handed back on a failed check.
  JNI_ARG_CHECK(env, (j_size != 0), "size is 0", 0);

  try {
    cudf::jni::auto_set_device(env);
    auto const type = cudf::data_type{static_cast<cudf::type_id>(j_type)};
    auto const rows = static_cast<cudf::size_type>(j_size);
    auto const mask = static_cast<cudf::mask_state>(j_mask_state);
    auto column = cudf::make_numeric_column(type, rows, mask);
    // release() gives up ownership so the column survives past this call.
    return reinterpret_cast<jlong>(column.release());
  }
  CATCH_STD(env, 0);
}

// Creates a timestamp cudf column of the requested type, size and mask state and hands the
// raw column pointer back to Java as a jlong. The caller becomes responsible for the column.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeTimestampCudfColumn(
    JNIEnv *env, jobject j_object, jint j_type, jint j_size, jint j_mask_state) {

  // NOTE(review): j_type and j_size are jint values, not references; these null checks
  // presumably reject a value of 0 - confirm that is the intent.
  JNI_NULL_CHECK(env, j_type, "type id is null", 0);
  JNI_NULL_CHECK(env, j_size, "size is null", 0);

  try {
    cudf::jni::auto_set_device(env);
    auto const type = cudf::data_type{static_cast<cudf::type_id>(j_type)};
    auto const rows = static_cast<cudf::size_type>(j_size);
    auto const mask = static_cast<cudf::mask_state>(j_mask_state);
    auto column = cudf::make_timestamp_column(type, rows, mask);
    // release() gives up ownership so the column survives past this call.
    return reinterpret_cast<jlong>(column.release());
  }
  CATCH_STD(env, 0);
}

// Builds a cudf strings column from character, offset and (optional) validity data that is
// addressed by raw host pointers passed in as jlong values. The data is copied host-to-device
// and the resulting column pointer is returned to Java as a jlong; 0 is the value given to
// the check/catch macros for the failure paths.
//
// j_char_data   - address of the character bytes
// j_offset_data - address of the offsets; entries [0, size] are read
// j_valid_data  - address of the validity mask, or 0 when there is none
// j_null_count  - number of nulls; forced to 0 when no validity mask is given
// size          - number of strings (rows)
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeStringCudfColumnHostSide(
JNIEnv *env, jobject j_object, jlong j_char_data, jlong j_offset_data, jlong j_valid_data,
jint j_null_count, jint size) {

JNI_ARG_CHECK(env, (size != 0), "size is 0", 0);
JNI_NULL_CHECK(env, j_char_data, "char data is null", 0);
JNI_NULL_CHECK(env, j_offset_data, "offset is null", 0);

try {
cudf::jni::auto_set_device(env);
cudf::size_type *host_offsets = reinterpret_cast<cudf::size_type *>(j_offset_data);
char *n_char_data = reinterpret_cast<char *>(j_char_data);
// The last offset entry is used as the total number of character bytes to copy.
cudf::size_type n_data_size = host_offsets[size];
cudf::bitmask_type *n_validity = reinterpret_cast<cudf::bitmask_type *>(j_valid_data);

// Without a validity mask the column is treated as having no nulls.
if (n_validity == nullptr) {
j_null_count = 0;
}

// Offsets child: size + 1 INT32 entries, filled by an async host-to-device copy.
std::unique_ptr<cudf::column> offsets = cudf::make_numeric_column(
cudf::data_type{cudf::type_id::INT32}, size + 1, cudf::mask_state::UNALLOCATED);
auto offsets_view = offsets->mutable_view();
JNI_CUDA_TRY(env, 0,
cudaMemcpyAsync(offsets_view.data<int32_t>(), host_offsets,
(size + 1) * sizeof(int32_t), cudaMemcpyHostToDevice));

// Character child: n_data_size INT8 entries, filled by an async host-to-device copy.
std::unique_ptr<cudf::column> data = cudf::make_numeric_column(
cudf::data_type{cudf::type_id::INT8}, n_data_size, cudf::mask_state::UNALLOCATED);
auto data_view = data->mutable_view();
JNI_CUDA_TRY(env, 0,
cudaMemcpyAsync(data_view.data<int8_t>(), n_char_data, n_data_size,
cudaMemcpyHostToDevice));

std::unique_ptr<cudf::column> column;
if (j_null_count == 0) {
// No nulls: build the column with an empty validity buffer.
column =
cudf::make_strings_column(size, std::move(offsets), std::move(data), j_null_count, {});
} else {
// Nulls present: copy (word_index(size) + 1) bitmask words of validity to the device.
cudf::size_type bytes = (cudf::word_index(size) + 1) * sizeof(cudf::bitmask_type);
rmm::device_buffer dev_validity(bytes);
JNI_CUDA_TRY(env, 0,
cudaMemcpyAsync(dev_validity.data(), n_validity, bytes, cudaMemcpyHostToDevice));

column = cudf::make_strings_column(size, std::move(offsets), std::move(data), j_null_count,
std::move(dev_validity));
}

// The copies above are asynchronous and were issued without an explicit stream; wait on
// stream 0 before handing the column back.
JNI_CUDA_TRY(env, 0, cudaStreamSynchronize(0));
return reinterpret_cast<jlong>(column.release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jint JNICALL Java_ai_rapids_cudf_ColumnVector_getNativeNullCountColumn(JNIEnv *env,
jobject j_object,
jlong handle) {
Expand Down
101 changes: 101 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3951,4 +3951,105 @@ void testMakeList() {
assertColumnsAreEqual(expected, created);
}
}

@Test
void testReplaceLeafNodeInList() {
  // Every decimal input column below is created with this same type.
  final DType decimalType = DType.create(DType.DTypeEnum.DECIMAL64, 3);
  try (ColumnVector ints1 = ColumnVector.fromInts(1, 2);
       ColumnVector ints2 = ColumnVector.fromInts(8, 3);
       ColumnVector ints3 = ColumnVector.fromInts(9, 8);
       ColumnVector ints4 = ColumnVector.fromInts(2, 6);
       ColumnVector expected = ColumnVector.makeList(ints1, ints2, ints3, ints4);
       ColumnVector dec1 = ColumnVector.decimalFromDoubles(decimalType,
           RoundingMode.HALF_UP, 770.892, 961.110);
       ColumnVector dec2 = ColumnVector.decimalFromDoubles(decimalType,
           RoundingMode.HALF_UP, 524.982, 479.946);
       ColumnVector dec3 = ColumnVector.decimalFromDoubles(decimalType,
           RoundingMode.HALF_UP, 346.997, 479.946);
       ColumnVector dec4 = ColumnVector.decimalFromDoubles(decimalType,
           RoundingMode.HALF_UP, 87.764, 414.239);
       ColumnVector decimalList = ColumnVector.makeList(dec1, dec2, dec3, dec4);
       // Same number of elements as the decimal child it replaces.
       ColumnVector intLeaf = ColumnVector.fromInts(1, 8, 9, 2, 2, 3, 8, 6);
       ColumnView replacedView = decimalList.replaceListChild(intLeaf);
       // Materialize the view so it can be compared against the expected column.
       ColumnVector replaced = replacedView.copyToColumnVector()) {
    assertColumnsAreEqual(expected, replaced);
  }
}

@Test
void testReplaceLeafNodeInListWithIllegal() {
  // Every decimal input column below is created with this same type.
  final DType decimalType = DType.create(DType.DTypeEnum.DECIMAL64, 3);
  assertThrows(IllegalArgumentException.class, () -> {
    try (ColumnVector dec1 = ColumnVector.decimalFromDoubles(decimalType,
             RoundingMode.HALF_UP, 770.892, 961.110);
         ColumnVector dec2 = ColumnVector.decimalFromDoubles(decimalType,
             RoundingMode.HALF_UP, 524.982, 479.946);
         ColumnVector dec3 = ColumnVector.decimalFromDoubles(decimalType,
             RoundingMode.HALF_UP, 346.997, 479.946);
         ColumnVector dec4 = ColumnVector.decimalFromDoubles(decimalType,
             RoundingMode.HALF_UP, 87.764, 414.239);
         ColumnVector decimalList = ColumnVector.makeList(dec1, dec2, dec3, dec4);
         // One element more than the child being replaced.
         ColumnVector oversizedLeaf = ColumnVector.fromInts(0, 1, 8, 9, 2, 2, 3, 8, 6);
         ColumnView replacedView = decimalList.replaceListChild(oversizedLeaf)) {
      // Nothing to do: creating the view above is expected to throw.
    }
  });
}

@Test
void testReplaceColumnInStruct() {
  try (ColumnVector expected = ColumnVector.fromStructs(
           new StructType(false, Arrays.asList(
               new BasicType(false, DType.INT32),
               new BasicType(false, DType.INT32),
               new BasicType(false, DType.INT32))),
           new HostColumnVector.StructData(1, 5, 3),
           new HostColumnVector.StructData(4, 9, 6));
       ColumnVector first = ColumnVector.fromInts(1, 4);
       ColumnVector second = ColumnVector.fromInts(2, 5);
       ColumnVector third = ColumnVector.fromInts(3, 6);
       ColumnVector struct = ColumnVector.makeStruct(first, second, third);
       ColumnVector replacement = ColumnVector.fromInts(5, 9);
       // Swap the middle child (index 1) for the replacement column.
       ColumnView replacedView = struct.replaceChildrenWithViews(new int[]{1},
           new ColumnVector[]{replacement});
       // Materialize the view so it can be compared against the expected column.
       ColumnVector replaced = replacedView.copyToColumnVector()) {
    assertColumnsAreEqual(expected, replaced);
  }
}

@Test
void testReplaceIllegalIndexColumnInStruct() {
  assertThrows(IllegalArgumentException.class, () -> {
    try (ColumnVector first = ColumnVector.fromInts(1, 4);
         ColumnVector second = ColumnVector.fromInts(2, 5);
         ColumnVector third = ColumnVector.fromInts(3, 6);
         ColumnVector struct = ColumnVector.makeStruct(first, second, third);
         ColumnVector replacement = ColumnVector.fromInts(5, 9);
         // Index 5 does not exist in a struct built from three children.
         ColumnView replacedView = struct.replaceChildrenWithViews(new int[]{5},
             new ColumnVector[]{replacement})) {
      // Nothing to do: creating the view above is expected to throw.
    }
  });
}

@Test
void testReplaceSameIndexColumnInStruct() {
  assertThrows(IllegalArgumentException.class, () -> {
    try (ColumnVector first = ColumnVector.fromInts(1, 4);
         ColumnVector second = ColumnVector.fromInts(2, 5);
         ColumnVector third = ColumnVector.fromInts(3, 6);
         ColumnVector struct = ColumnVector.makeStruct(first, second, third);
         ColumnVector replacement = ColumnVector.fromInts(5, 9);
         // The same child index is listed twice.
         ColumnView replacedView = struct.replaceChildrenWithViews(new int[]{1, 1},
             new ColumnVector[]{replacement, replacement})) {
      // Nothing to do: creating the view above is expected to throw.
    }
  });
}
}

0 comments on commit 9017f22

Please sign in to comment.