From ceeddd24a84d6b81ff215eb2486c0041ee2b76d2 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Fri, 19 Feb 2021 17:45:44 -0800 Subject: [PATCH 01/19] cast child columns for structs and lists --- .../java/ai/rapids/cudf/ColumnVector.java | 23 ++++ java/src/main/native/src/ColumnVectorJni.cpp | 129 ++++++++---------- 2 files changed, 83 insertions(+), 69 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 252f869a049..ca8b6cbf77b 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -27,6 +27,7 @@ import java.math.RoundingMode; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.function.Consumer; @@ -394,6 +395,24 @@ public static ColumnVector makeStruct(long rows, ColumnView... columns) { } } + /** + * Replace the leaf node in the list with the given column + */ + public static ColumnVector castLeafD64ToD32(ColumnView origList) { + assert(origList.type == DType.LIST); + return new ColumnVector(castLeafD64ToD32(origList.getNativeView())); + } + + /** + * Replace columns in the struct with the given columns + */ + public static ColumnVector replaceColumnsInStruct(ColumnView origStruct, int[] indices, + ColumnView[] views) { + assert(origStruct.type == DType.STRUCT); + return new ColumnVector(replaceColumnsInStruct(origStruct.getNativeView(), indices, + Arrays.stream(views).mapToLong( v -> v.getNativeView()).toArray())); + } + /** * Create a LIST column from the given columns. Each list in the returned column will have the * same number of entries in it as columns passed into this method. Be careful about the @@ -725,6 +744,10 @@ static void closeBuffers(AutoCloseable buffer) { private static native void setNativeNullCountColumn(long cudfColumnHandle, int nullCount) throws CudfException; + private static native long replaceColumnsInStruct(long cudfColumnHandle, + int[] indices, long[] viewHandles) throws CudfException; + + private static native long castLeafD64ToD32(long cudfColumnHandle) throws CudfException; /** * Create a cudf::column_view from a cudf::column. * @param cudfColumnHandle the pointer to the cudf::column diff --git a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp index a1e8517c646..1e4470e1bf0 100644 --- a/java/src/main/native/src/ColumnVectorJni.cpp +++ b/java/src/main/native/src/ColumnVectorJni.cpp @@ -14,6 +14,8 @@ * limitations under the License. */ +#include +#include #include #include #include @@ -27,9 +29,16 @@ #include #include #include +#include "cudf/null_mask.hpp" +#include "cudf/types.hpp" +#include "cudf/utilities/traits.hpp" +#include "cudf/unary.hpp" +#include "rmm/device_buffer.hpp" #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" +#include "jni.h" +#include "jni_utils.hpp" extern "C" { @@ -317,92 +326,74 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeEmptyCudfColumn(JNI CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeNumericCudfColumn( - JNIEnv *env, jobject j_object, jint j_type, jint j_size, jint j_mask_state) { +cudf::column* replace_column(cudf::column list_column) { + cudf::lists_column_view lcv(list_column); - JNI_ARG_CHECK(env, (j_size != 0), "size is 0", 0); + std::unique_ptr new_child; - try { - cudf::jni::auto_set_device(env); - cudf::type_id n_type = static_cast(j_type); - cudf::data_type n_data_type(n_type); - cudf::size_type n_size = static_cast(j_size); - cudf::mask_state n_mask_state = static_cast(j_mask_state); - std::unique_ptr column( - cudf::make_numeric_column(n_data_type, n_size, n_mask_state)); - return reinterpret_cast(column.release()); + if (lcv.child().type().id() != cudf::type_id::LIST) { + assert(lcv.child().type() == cudf::type_id::DECIMAL64); + cudf::data_type to_type = cudf::data_type(cudf::type_id::DECIMAL32, lcv.child().type().scale()); + auto u_d32_ptr = cudf::cast(lcv.child(), to_type); + new_child.reset(u_d32_ptr.release()); + } else { + new_child.reset(replace_column(list_column.child(cudf::lists_column_view::child_column_index))); } - CATCH_STD(env, 0); + + assert(new_child->size() == contents.children[lists_column_view::child_column_index].size()); + int32_t size = list_column.size(); + int32_t null_count = list_column.null_count(); + auto contents = list_column.release(); + + auto col = cudf::make_lists_column(size, std::move(contents.children[cudf::lists_column_view::offsets_column_index]), + std::move(new_child), null_count, std::move(*contents.null_mask.release())); + return col.release(); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeTimestampCudfColumn( - JNIEnv *env, jobject j_object, jint j_type, jint j_size, jint j_mask_state) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_castLeafD64ToD32( JNIEnv *env, jobject j_object, jlong j_handle) { - JNI_NULL_CHECK(env, j_type, "type id is null", 0); - JNI_NULL_CHECK(env, j_size, "size is null", 0); + JNI_NULL_CHECK(env, j_handle, "native handle is null", 0); try { - cudf::jni::auto_set_device(env); - cudf::type_id n_type = static_cast(j_type); - std::unique_ptr n_data_type(new cudf::data_type(n_type)); - cudf::size_type n_size = static_cast(j_size); - cudf::mask_state n_mask_state = static_cast(j_mask_state); - std::unique_ptr column( - cudf::make_timestamp_column(*n_data_type.get(), n_size, n_mask_state)); - return reinterpret_cast(column.release()); + cudf::column_view *n_list_col_view = reinterpret_cast(j_handle); + JNI_ARG_CHECK(env, n_list_col_view->type().id() == cudf::type_id::LIST, "Only list types are allowed", 0); + + auto copy_list = cudf::column(*n_list_col_view); + + return reinterpret_cast(replace_column(copy_list)); } CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeStringCudfColumnHostSide( - JNIEnv *env, jobject j_object, jlong j_char_data, jlong j_offset_data, jlong j_valid_data, - jint j_null_count, jint size) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_replaceColumnsInStruct( + JNIEnv *env, jobject j_object, jlong j_handle, jintArray j_indices, jlongArray j_children) { - JNI_ARG_CHECK(env, (size != 0), "size is 0", 0); - JNI_NULL_CHECK(env, j_char_data, "char data is null", 0); - JNI_NULL_CHECK(env, j_offset_data, "offset is null", 0); + JNI_NULL_CHECK(env, j_handle, "native handle is null", 0); + JNI_NULL_CHECK(env, j_indices, "child indices to replace can't be null", 0); + JNI_NULL_CHECK(env, j_children, "children to replace can't be null", 0); try { - cudf::jni::auto_set_device(env); - cudf::size_type *host_offsets = reinterpret_cast(j_offset_data); - char *n_char_data = reinterpret_cast(j_char_data); - cudf::size_type n_data_size = host_offsets[size]; - cudf::bitmask_type *n_validity = reinterpret_cast(j_valid_data); - - if (n_validity == nullptr) { - j_null_count = 0; - } - - std::unique_ptr offsets = cudf::make_numeric_column( - cudf::data_type{cudf::type_id::INT32}, size + 1, cudf::mask_state::UNALLOCATED); - auto offsets_view = offsets->mutable_view(); - JNI_CUDA_TRY(env, 0, - cudaMemcpyAsync(offsets_view.data(), host_offsets, - (size + 1) * sizeof(int32_t), cudaMemcpyHostToDevice)); - - std::unique_ptr data = cudf::make_numeric_column( - cudf::data_type{cudf::type_id::INT8}, n_data_size, cudf::mask_state::UNALLOCATED); - auto data_view = data->mutable_view(); - JNI_CUDA_TRY(env, 0, - cudaMemcpyAsync(data_view.data(), n_char_data, n_data_size, - cudaMemcpyHostToDevice)); - - std::unique_ptr column; - if (j_null_count == 0) { - column = - cudf::make_strings_column(size, std::move(offsets), std::move(data), j_null_count, {}); - } else { - cudf::size_type bytes = (cudf::word_index(size) + 1) * sizeof(cudf::bitmask_type); - rmm::device_buffer dev_validity(bytes); - JNI_CUDA_TRY(env, 0, - cudaMemcpyAsync(dev_validity.data(), n_validity, bytes, cudaMemcpyHostToDevice)); - - column = cudf::make_strings_column(size, std::move(offsets), std::move(data), j_null_count, - std::move(dev_validity)); + cudf::jni::native_jpointerArray children_to_replace(env, j_children); + cudf::jni::native_jintArray indices(env, j_indices); + JNI_ARG_CHECK(env, indices.size() == children_to_replace.size(), "The indices size and children size should match", 0); + + cudf::column_view *n_struct_col_view = reinterpret_cast(j_handle); + JNI_ARG_CHECK(env, n_struct_col_view->type().id() == cudf::type_id::STRUCT, "Only struct types are allowed", 0); + + std::vector> children; + children.reserve(n_struct_col_view->num_children()); + int j = 0; + for (int i = 0 ; i < n_struct_col_view->num_children() ; i++) { + if (i == indices[j]) { + children.emplace_back(std::make_unique(*(children_to_replace[j++]))); + } else { + children.emplace_back(std::make_unique(n_struct_col_view->child(i))); + } } - JNI_CUDA_TRY(env, 0, cudaStreamSynchronize(0)); - return reinterpret_cast(column.release()); + auto col = cudf::make_structs_column(n_struct_col_view->size(), std::move(children), + n_struct_col_view->null_count(), cudf::copy_bitmask(*n_struct_col_view)); + return reinterpret_cast(col.release()); } CATCH_STD(env, 0); } From 78e633a7a0d7be6976016ed905c4bd4915e7a36c Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Sun, 21 Feb 2021 11:11:00 -0800 Subject: [PATCH 02/19] don't copy column --- .../main/java/ai/rapids/cudf/ColumnVector.java | 15 +++++++++++---- java/src/main/native/src/ColumnVectorJni.cpp | 8 +++----- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index ca8b6cbf77b..bb10ec72c5b 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -396,11 +396,18 @@ public static ColumnVector makeStruct(long rows, ColumnView... columns) { } /** - * Replace the leaf node in the list with the given column + * Cast Leaf node in the List col from a 64-bit Decimal to 32-bit Decimal, iff it exists. + * + * Ex 1: + * replace(col( type: List>>)) => returns unchanged + * + * Ex 2: + * replace(col(type: List>) => col(type: List>) + * */ - public static ColumnVector castLeafD64ToD32(ColumnView origList) { - assert(origList.type == DType.LIST); - return new ColumnVector(castLeafD64ToD32(origList.getNativeView())); + public ColumnVector castLeafD64ToD32() { + assert(type == DType.LIST); + return new ColumnVector(castLeafD64ToD32(offHeap.columnHandle)); } /** diff --git a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp index 1e4470e1bf0..7b8ba02a3a8 100644 --- a/java/src/main/native/src/ColumnVectorJni.cpp +++ b/java/src/main/native/src/ColumnVectorJni.cpp @@ -355,12 +355,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_castLeafD64ToD32( JNIEn JNI_NULL_CHECK(env, j_handle, "native handle is null", 0); try { - cudf::column_view *n_list_col_view = reinterpret_cast(j_handle); - JNI_ARG_CHECK(env, n_list_col_view->type().id() == cudf::type_id::LIST, "Only list types are allowed", 0); + cudf::column *n_list_col = reinterpret_cast(j_handle); + JNI_ARG_CHECK(env, n_list_col->type().id() == cudf::type_id::LIST, "Only list types are allowed", 0); - auto copy_list = cudf::column(*n_list_col_view); - - return reinterpret_cast(replace_column(copy_list)); + return reinterpret_cast(replace_column(*n_list_col)); } CATCH_STD(env, 0); } From 895433ce71642f440f0716fcf57ef24b7fc1e862 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Sun, 21 Feb 2021 16:42:55 -0800 Subject: [PATCH 03/19] replace cols in list and struct --- .../java/ai/rapids/cudf/ColumnVector.java | 22 +--- .../main/java/ai/rapids/cudf/ColumnView.java | 13 +++ java/src/main/native/src/ColumnVectorJni.cpp | 107 ++++++------------ java/src/main/native/src/ColumnViewJni.cpp | 40 +++++++ .../java/ai/rapids/cudf/ColumnVectorTest.java | 37 ++++++ 5 files changed, 133 insertions(+), 86 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index bb10ec72c5b..02bc7b3669d 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -396,13 +396,16 @@ public static ColumnVector makeStruct(long rows, ColumnView... columns) { } /** - * Cast Leaf node in the List col from a 64-bit Decimal to 32-bit Decimal, iff it exists. + * This is a very specialized method that has only one job. It takes in a list and returns a + * new list if the Leaf node is a 64-bit Decimal, converting the leaf to a 32-bit Decimal. + * Note: this is not a true cast as it assumes that the 64-bit Decimal column is a 32-bit Decimal + * column that happens to be stored as a 64-bit Decimal. * * Ex 1: - * replace(col( type: List>>)) => returns unchanged + * replace(col( type: List>>)) => throws an assert error * * Ex 2: - * replace(col(type: List>) => col(type: List>) + * replace(col(type: List>) => col(type: List>) no rounding is done * */ public ColumnVector castLeafD64ToD32() { @@ -410,16 +413,6 @@ public ColumnVector castLeafD64ToD32() { return new ColumnVector(castLeafD64ToD32(offHeap.columnHandle)); } - /** - * Replace columns in the struct with the given columns - */ - public static ColumnVector replaceColumnsInStruct(ColumnView origStruct, int[] indices, - ColumnView[] views) { - assert(origStruct.type == DType.STRUCT); - return new ColumnVector(replaceColumnsInStruct(origStruct.getNativeView(), indices, - Arrays.stream(views).mapToLong( v -> v.getNativeView()).toArray())); - } - /** * Create a LIST column from the given columns. Each list in the returned column will have the * same number of entries in it as columns passed into this method. Be careful about the @@ -751,9 +744,6 @@ static void closeBuffers(AutoCloseable buffer) { private static native void setNativeNullCountColumn(long cudfColumnHandle, int nullCount) throws CudfException; - private static native long replaceColumnsInStruct(long cudfColumnHandle, - int[] indices, long[] viewHandles) throws CudfException; - private static native long castLeafD64ToD32(long cudfColumnHandle) throws CudfException; /** * Create a cudf::column_view from a cudf::column. diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 1dce52f7105..2fe41ef1745 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -19,6 +19,7 @@ package ai.rapids.cudf; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Optional; @@ -1296,6 +1297,16 @@ public ColumnVector castTo(DType type) { return new ColumnVector(castTo(getNativeView(), type.typeId.getNativeId(), type.getScale())); } + /** + * Replace columns in the struct with the given columns + */ + public ColumnVector replaceColumnsInStruct(int[] indices, + ColumnView[] views) { + assert(type == DType.STRUCT); + return new ColumnVector(replaceColumnsInStruct(getNativeView(), indices, + Arrays.stream(views).mapToLong(v -> v.getNativeView()).toArray())); + } + /** * Zero-copy cast between types with the same underlying representation. * @@ -2437,6 +2448,8 @@ static DeviceMemoryBufferView getOffsetsBuffer(long viewHandle) { */ private static native long timestampToStringTimestamp(long viewHandle, String format); + private static native long replaceColumnsInStruct(long cudfColumnHandle, + int[] indices, long[] viewHandles) throws CudfException; /** * Native method for locating the starting index of the first instance of a given substring * in each string in the column. 0 indexing, returns -1 if the substring is not found. Can be diff --git a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp index 7b8ba02a3a8..25a390c9c5e 100644 --- a/java/src/main/native/src/ColumnVectorJni.cpp +++ b/java/src/main/native/src/ColumnVectorJni.cpp @@ -43,6 +43,43 @@ extern "C" { +cudf::column* replace_column(cudf::column list_column) { + cudf::lists_column_view lcv(list_column); + + std::unique_ptr new_child; + + if (lcv.child().type().id() != cudf::type_id::LIST) { + assert(lcv.child().type() == cudf::type_id::DECIMAL64); + cudf::data_type to_type = cudf::data_type(cudf::type_id::DECIMAL32, lcv.child().type().scale()); + auto u_d32_ptr = cudf::cast(lcv.child(), to_type); + new_child.reset(u_d32_ptr.release()); + } else { + new_child.reset(replace_column(list_column.child(cudf::lists_column_view::child_column_index))); + } + + assert(new_child->size() == contents.children[lists_column_view::child_column_index].size()); + int32_t size = list_column.size(); + int32_t null_count = list_column.null_count(); + auto contents = list_column.release(); + + auto col = cudf::make_lists_column(size, std::move(contents.children[cudf::lists_column_view::offsets_column_index]), + std::move(new_child), null_count, std::move(*contents.null_mask.release())); + return col.release(); +} + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_castLeafD64ToD32( JNIEnv *env, jobject j_object, jlong j_handle) { + + JNI_NULL_CHECK(env, j_handle, "native handle is null", 0); + + try { + cudf::column *n_list_col = reinterpret_cast(j_handle); + JNI_ARG_CHECK(env, n_list_col->type().id() == cudf::type_id::LIST, "Only list types are allowed", 0); + + return reinterpret_cast(replace_column(*n_list_col)); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, jclass, jlong j_initial_val, jlong j_step, jint row_count) { @@ -326,76 +363,6 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeEmptyCudfColumn(JNI CATCH_STD(env, 0); } -cudf::column* replace_column(cudf::column list_column) { - cudf::lists_column_view lcv(list_column); - - std::unique_ptr new_child; - - if (lcv.child().type().id() != cudf::type_id::LIST) { - assert(lcv.child().type() == cudf::type_id::DECIMAL64); - cudf::data_type to_type = cudf::data_type(cudf::type_id::DECIMAL32, lcv.child().type().scale()); - auto u_d32_ptr = cudf::cast(lcv.child(), to_type); - new_child.reset(u_d32_ptr.release()); - } else { - new_child.reset(replace_column(list_column.child(cudf::lists_column_view::child_column_index))); - } - - assert(new_child->size() == contents.children[lists_column_view::child_column_index].size()); - int32_t size = list_column.size(); - int32_t null_count = list_column.null_count(); - auto contents = list_column.release(); - - auto col = cudf::make_lists_column(size, std::move(contents.children[cudf::lists_column_view::offsets_column_index]), - std::move(new_child), null_count, std::move(*contents.null_mask.release())); - return col.release(); -} - -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_castLeafD64ToD32( JNIEnv *env, jobject j_object, jlong j_handle) { - - JNI_NULL_CHECK(env, j_handle, "native handle is null", 0); - - try { - cudf::column *n_list_col = reinterpret_cast(j_handle); - JNI_ARG_CHECK(env, n_list_col->type().id() == cudf::type_id::LIST, "Only list types are allowed", 0); - - return reinterpret_cast(replace_column(*n_list_col)); - } - CATCH_STD(env, 0); -} - -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_replaceColumnsInStruct( - JNIEnv *env, jobject j_object, jlong j_handle, jintArray j_indices, jlongArray j_children) { - - JNI_NULL_CHECK(env, j_handle, "native handle is null", 0); - JNI_NULL_CHECK(env, j_indices, "child indices to replace can't be null", 0); - JNI_NULL_CHECK(env, j_children, "children to replace can't be null", 0); - - try { - cudf::jni::native_jpointerArray children_to_replace(env, j_children); - cudf::jni::native_jintArray indices(env, j_indices); - JNI_ARG_CHECK(env, indices.size() == children_to_replace.size(), "The indices size and children size should match", 0); - - cudf::column_view *n_struct_col_view = reinterpret_cast(j_handle); - JNI_ARG_CHECK(env, n_struct_col_view->type().id() == cudf::type_id::STRUCT, "Only struct types are allowed", 0); - - std::vector> children; - children.reserve(n_struct_col_view->num_children()); - int j = 0; - for (int i = 0 ; i < n_struct_col_view->num_children() ; i++) { - if (i == indices[j]) { - children.emplace_back(std::make_unique(*(children_to_replace[j++]))); - } else { - children.emplace_back(std::make_unique(n_struct_col_view->child(i))); - } - } - - auto col = cudf::make_structs_column(n_struct_col_view->size(), std::move(children), - n_struct_col_view->null_count(), cudf::copy_bitmask(*n_struct_col_view)); - return reinterpret_cast(col.release()); - } - CATCH_STD(env, 0); -} - JNIEXPORT jint JNICALL Java_ai_rapids_cudf_ColumnVector_getNativeNullCountColumn(JNIEnv *env, jobject j_object, jlong handle) { diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 82e71b04a2f..fb386682210 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -60,6 +60,7 @@ #include #include #include +#include "cudf/types.hpp" #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" @@ -1760,4 +1761,43 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_copyColumnViewToCV(JNIEnv } CATCH_STD(env, 0) } + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceColumnsInStruct( + JNIEnv *env, jobject j_object, jlong j_handle, jintArray j_indices, jlongArray j_children) { + + JNI_NULL_CHECK(env, j_handle, "native handle is null", 0); + JNI_NULL_CHECK(env, j_indices, "child indices to replace can't be null", 0); + JNI_NULL_CHECK(env, j_children, "children to replace can't be null", 0); + + try { + cudf::jni::native_jpointerArray children_to_replace(env, j_children); + cudf::jni::native_jintArray indices(env, j_indices); + JNI_ARG_CHECK(env, indices.size() == children_to_replace.size(), "The indices size and children size should match", 0); + + cudf::column_view *n_struct_col_view = reinterpret_cast(j_handle); + JNI_ARG_CHECK(env, n_struct_col_view->type().id() == cudf::type_id::STRUCT, "Only struct types are allowed", 0); + + std::map m; + for (int i = 0 ; i < indices.size() ; i++) { + m[indices[i]] = children_to_replace[i]; + } + + std::vector> children; + children.reserve(n_struct_col_view->num_children()); + int j = 0; + for (int i = 0 ; i < n_struct_col_view->num_children() ; i++) { + auto it = m.find(i); + if (it != m.end()) { + children.emplace_back(std::make_unique(*it->second)); + } else { + children.emplace_back(std::make_unique(n_struct_col_view->child(i))); + } + } + + auto col = cudf::make_structs_column(n_struct_col_view->size(), std::move(children), + n_struct_col_view->null_count(), cudf::copy_bitmask(*n_struct_col_view)); + return reinterpret_cast(col.release()); + } + CATCH_STD(env, 0); +} } // extern "C" diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index cb1f792b99e..ba5212da587 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -3936,4 +3936,41 @@ void testMakeList() { assertColumnsAreEqual(expected, created); } } + + @Test + void testCastLeafNodeInList() { + try ( + ColumnVector c1 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL32, 3), RoundingMode.HALF_UP, 770.892, 961.110); + ColumnVector c2 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL32, 3), RoundingMode.HALF_UP, 524.982, 479.946); + ColumnVector c3 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL32, 3), RoundingMode.HALF_UP, 346.997, 479.946); + ColumnVector c4 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL32, 3), RoundingMode.HALF_UP, 87.764, 414.239); + ColumnVector expected = ColumnVector.makeList(c1, c2, c3, c4); + ColumnVector child1 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 770.892, 961.110); + ColumnVector child2 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 524.982, 479.946); + ColumnVector child3 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 346.997, 479.946); + ColumnVector child4 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 87.764, 414.239); + ColumnVector created = ColumnVector.makeList(child1, child2, child3, child4); + ColumnVector replaced = created.castLeafD64ToD32()) { + assertColumnsAreEqual(expected, replaced); + } + } + + @Test + void testReplaceColumnInStruct() { + try (ColumnVector expected = ColumnVector.fromStructs(new StructType(false, + Arrays.asList( + new BasicType(false, DType.INT32), + new BasicType(false, DType.INT32), + new BasicType(false, DType.INT32))), + new HostColumnVector.StructData(1, 5, 3), + new HostColumnVector.StructData(4, 9, 6)); + ColumnVector child1 = ColumnVector.fromInts(1, 4); + ColumnVector child2 = ColumnVector.fromInts(2, 5); + ColumnVector child3 = ColumnVector.fromInts(3, 6); + ColumnVector created = ColumnVector.makeStruct(child1, child2, child3); + ColumnVector replaceWith = ColumnVector.fromInts(5, 9); + ColumnVector replaced = created.replaceColumnsInStruct(new int[]{1}, new ColumnVector[]{replaceWith})) { + assertColumnsAreEqual(expected, replaced); + } + } } From 160d06c6128904f045de5eb0f20b145d926de239 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Fri, 26 Feb 2021 16:38:39 -0800 Subject: [PATCH 04/19] updated method names --- java/src/main/java/ai/rapids/cudf/ColumnVector.java | 2 +- java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 02bc7b3669d..2b315d8bf04 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -408,7 +408,7 @@ public static ColumnVector makeStruct(long rows, ColumnView... columns) { * replace(col(type: List>) => col(type: List>) no rounding is done * */ - public ColumnVector castLeafD64ToD32() { + public ColumnVector castLeafDecimal64ToDecimal32() { assert(type == DType.LIST); return new ColumnVector(castLeafD64ToD32(offHeap.columnHandle)); } diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index ba5212da587..95e4eec740c 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -3950,7 +3950,7 @@ void testCastLeafNodeInList() { ColumnVector child3 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 346.997, 479.946); ColumnVector child4 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 87.764, 414.239); ColumnVector created = ColumnVector.makeList(child1, child2, child3, child4); - ColumnVector replaced = created.castLeafD64ToD32()) { + ColumnVector replaced = created.castLeafDecimal64ToDecimal32()) { assertColumnsAreEqual(expected, replaced); } } From c82534fd94bf904079e06d8167ebdaf451cabf88 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Mon, 1 Mar 2021 14:39:33 -0800 Subject: [PATCH 05/19] addressed review comments --- .../java/ai/rapids/cudf/ColumnVector.java | 18 ------------ .../main/java/ai/rapids/cudf/ColumnView.java | 18 +++++++++--- java/src/main/native/src/ColumnVectorJni.cpp | 14 +++++----- java/src/main/native/src/ColumnViewJni.cpp | 28 +++++++++++-------- .../java/ai/rapids/cudf/ColumnVectorTest.java | 23 +++++++++------ 5 files changed, 52 insertions(+), 49 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 2b315d8bf04..cc382065489 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -395,24 +395,6 @@ public static ColumnVector makeStruct(long rows, ColumnView... columns) { } } - /** - * This is a very specialized method that has only one job. It takes in a list and returns a - * new list if the Leaf node is a 64-bit Decimal, converting the leaf to a 32-bit Decimal. - * Note: this is not a true cast as it assumes that the 64-bit Decimal column is a 32-bit Decimal - * column that happens to be stored as a 64-bit Decimal. - * - * Ex 1: - * replace(col( type: List>>)) => throws an assert error - * - * Ex 2: - * replace(col(type: List>) => col(type: List>) no rounding is done - * - */ - public ColumnVector castLeafDecimal64ToDecimal32() { - assert(type == DType.LIST); - return new ColumnVector(castLeafD64ToD32(offHeap.columnHandle)); - } - /** * Create a LIST column from the given columns. Each list in the returned column will have the * same number of entries in it as columns passed into this method. Be careful about the diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 2fe41ef1745..c7a7545f336 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1298,15 +1298,25 @@ public ColumnVector castTo(DType type) { } /** - * Replace columns in the struct with the given columns + * This method takes in a nested type and replaces its children with the given views */ - public ColumnVector replaceColumnsInStruct(int[] indices, + public ColumnView replaceChildrenWithViews(int[] indices, ColumnView[] views) { assert(type == DType.STRUCT); - return new ColumnVector(replaceColumnsInStruct(getNativeView(), indices, + return new ColumnView(replaceChildrenWithViews(getNativeView(), indices, Arrays.stream(views).mapToLong(v -> v.getNativeView()).toArray())); } + /** + * This method takes in a list and returns a new list with the leaf node replaced with the given + * view. The number of rows in the new child should be the same as the original otherwise the + * resulting list will have bad data + */ + public ColumnView replaceListChild(ColumnView child) { + assert(type == DType.LIST); + return replaceChildrenWithViews(new int[]{1}, new ColumnView[]{child}); + } + /** * Zero-copy cast between types with the same underlying representation. * @@ -2448,7 +2458,7 @@ static DeviceMemoryBufferView getOffsetsBuffer(long viewHandle) { */ private static native long timestampToStringTimestamp(long viewHandle, String format); - private static native long replaceColumnsInStruct(long cudfColumnHandle, + private static native long replaceChildrenWithViews(long cudfColumnHandle, int[] indices, long[] viewHandles) throws CudfException; /** * Native method for locating the starting index of the first instance of a given substring diff --git a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp index 25a390c9c5e..78c01cc4278 100644 --- a/java/src/main/native/src/ColumnVectorJni.cpp +++ b/java/src/main/native/src/ColumnVectorJni.cpp @@ -43,8 +43,8 @@ extern "C" { -cudf::column* replace_column(cudf::column list_column) { - cudf::lists_column_view lcv(list_column); +cudf::column* replace_column(cudf::column* list_column) { + cudf::lists_column_view lcv(list_column->view()); std::unique_ptr new_child; @@ -54,13 +54,13 @@ cudf::column* replace_column(cudf::column list_column) { auto u_d32_ptr = cudf::cast(lcv.child(), to_type); new_child.reset(u_d32_ptr.release()); } else { - new_child.reset(replace_column(list_column.child(cudf::lists_column_view::child_column_index))); + new_child.reset(replace_column(&list_column->child(cudf::lists_column_view::child_column_index))); } assert(new_child->size() == contents.children[lists_column_view::child_column_index].size()); - int32_t size = list_column.size(); - int32_t null_count = list_column.null_count(); - auto contents = list_column.release(); + int32_t size = list_column->size(); + int32_t null_count = list_column->null_count(); + auto contents = list_column->release(); auto col = cudf::make_lists_column(size, std::move(contents.children[cudf::lists_column_view::offsets_column_index]), std::move(new_child), null_count, std::move(*contents.null_mask.release())); @@ -75,7 +75,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_castLeafD64ToD32( JNIEn cudf::column *n_list_col = reinterpret_cast(j_handle); JNI_ARG_CHECK(env, n_list_col->type().id() == cudf::type_id::LIST, "Only list types are allowed", 0); - return reinterpret_cast(replace_column(*n_list_col)); + return reinterpret_cast(replace_column(n_list_col)); } CATCH_STD(env, 0); } diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index fb386682210..5f2a1debd6a 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include @@ -64,6 +65,7 @@ #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" +#include "jni_utils.hpp" namespace { @@ -1762,7 +1764,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_copyColumnViewToCV(JNIEnv CATCH_STD(env, 0) } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceColumnsInStruct( +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceChildrenWithViews( JNIEnv *env, jobject j_object, jlong j_handle, jintArray j_indices, jlongArray j_children) { JNI_NULL_CHECK(env, j_handle, "native handle is null", 0); @@ -1774,30 +1776,34 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceColumnsInStruct( cudf::jni::native_jintArray indices(env, j_indices); JNI_ARG_CHECK(env, indices.size() == children_to_replace.size(), "The indices size and children size should match", 0); - cudf::column_view *n_struct_col_view = reinterpret_cast(j_handle); - JNI_ARG_CHECK(env, n_struct_col_view->type().id() == cudf::type_id::STRUCT, "Only struct types are allowed", 0); + cudf::column_view *n_col_view = reinterpret_cast(j_handle); + cudf::type_id id = n_col_view->type().id(); + JNI_ARG_CHECK(env, id == cudf::type_id::STRUCT || id == cudf::type_id::LIST, "Only nested types are allowed", 0); + if (id == cudf::type_id::LIST) { + JNI_ARG_CHECK(env, children_to_replace.size() == 1, "LIST can only have one child to replace", 0); + } std::map m; for (int i = 0 ; i < indices.size() ; i++) { m[indices[i]] = children_to_replace[i]; } - std::vector> children; - children.reserve(n_struct_col_view->num_children()); + std::vector children; + children.reserve(n_col_view->num_children()); int j = 0; - for (int i = 0 ; i < n_struct_col_view->num_children() ; i++) { + for (int i = 0 ; i < n_col_view->num_children() ; i++) { auto it = m.find(i); if (it != m.end()) { - children.emplace_back(std::make_unique(*it->second)); + children.emplace_back(*it->second); } else { - children.emplace_back(std::make_unique(n_struct_col_view->child(i))); + children.emplace_back(n_col_view->child(i)); } } - auto col = cudf::make_structs_column(n_struct_col_view->size(), std::move(children), - n_struct_col_view->null_count(), cudf::copy_bitmask(*n_struct_col_view)); - return reinterpret_cast(col.release()); + std::unique_ptr n_new_nested(new cudf::column_view(n_col_view->type(), n_col_view->size(), nullptr, n_col_view->null_mask(), n_col_view->null_count(), n_col_view->offset(), children)); + return reinterpret_cast(n_new_nested.release()); } CATCH_STD(env, 0); } + } // extern "C" diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 95e4eec740c..44ffffbdc6f 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -3938,20 +3938,23 @@ void testMakeList() { } @Test - void testCastLeafNodeInList() { + void testReplaceLeafNodeInList() { try ( - ColumnVector c1 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL32, 3), RoundingMode.HALF_UP, 770.892, 961.110); - ColumnVector c2 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL32, 3), RoundingMode.HALF_UP, 524.982, 479.946); - ColumnVector c3 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL32, 3), RoundingMode.HALF_UP, 346.997, 479.946); - ColumnVector c4 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL32, 3), RoundingMode.HALF_UP, 87.764, 414.239); + ColumnVector c1 = ColumnVector.fromInts(1, 2); + ColumnVector c2 = ColumnVector.fromInts(8, 3); + ColumnVector c3 = ColumnVector.fromInts(9, 8); + ColumnVector c4 = ColumnVector.fromInts(2, 6); ColumnVector expected = ColumnVector.makeList(c1, c2, c3, c4); ColumnVector child1 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 770.892, 961.110); ColumnVector child2 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 524.982, 479.946); ColumnVector child3 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 346.997, 479.946); ColumnVector child4 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 87.764, 414.239); ColumnVector created = ColumnVector.makeList(child1, child2, child3, child4); - ColumnVector replaced = created.castLeafDecimal64ToDecimal32()) { - assertColumnsAreEqual(expected, replaced); + ColumnVector newChild = ColumnVector.fromInts(1, 8, 9, 2, 2, 3, 8, 6); + ColumnView replacedView = created.replaceListChild(newChild)) { + try (ColumnVector replaced = replacedView.copyToColumnVector()) { + assertColumnsAreEqual(expected, replaced); + } } } @@ -3969,8 +3972,10 @@ void testReplaceColumnInStruct() { ColumnVector child3 = ColumnVector.fromInts(3, 6); ColumnVector created = ColumnVector.makeStruct(child1, child2, child3); ColumnVector replaceWith = ColumnVector.fromInts(5, 9); - ColumnVector replaced = created.replaceColumnsInStruct(new int[]{1}, new ColumnVector[]{replaceWith})) { - assertColumnsAreEqual(expected, replaced); + ColumnView replacedView = created.replaceChildrenWithViews(new int[]{1}, new ColumnVector[]{replaceWith})) { + try (ColumnVector replaced = replacedView.copyToColumnVector()) { + assertColumnsAreEqual(expected, replaced); + } } } } From 2836e8614ef34eaa868c60a33e3d2990e6af1d6a Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Mon, 1 Mar 2021 14:43:01 -0800 Subject: [PATCH 06/19] code cleanup --- .../java/ai/rapids/cudf/ColumnVector.java | 2 - .../main/java/ai/rapids/cudf/ColumnView.java | 2 +- java/src/main/native/src/ColumnVectorJni.cpp | 47 ------------------- 3 files changed, 1 insertion(+), 50 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index cc382065489..252f869a049 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -27,7 +27,6 @@ import java.math.RoundingMode; import java.nio.ByteBuffer; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.function.Consumer; @@ -726,7 +725,6 @@ static void closeBuffers(AutoCloseable buffer) { private static native void setNativeNullCountColumn(long cudfColumnHandle, int nullCount) throws CudfException; - private static native long castLeafD64ToD32(long cudfColumnHandle) throws CudfException; /** * Create a cudf::column_view from a cudf::column. * @param cudfColumnHandle the pointer to the cudf::column diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index c7a7545f336..be25b1a4e72 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1302,7 +1302,7 @@ public ColumnVector castTo(DType type) { */ public ColumnView replaceChildrenWithViews(int[] indices, ColumnView[] views) { - assert(type == DType.STRUCT); + assert(type == DType.STRUCT || type == DType.LIST); return new ColumnView(replaceChildrenWithViews(getNativeView(), indices, Arrays.stream(views).mapToLong(v -> v.getNativeView()).toArray())); } diff --git a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp index 78c01cc4278..0b5e6aab539 100644 --- a/java/src/main/native/src/ColumnVectorJni.cpp +++ b/java/src/main/native/src/ColumnVectorJni.cpp @@ -14,8 +14,6 @@ * limitations under the License. */ -#include -#include #include #include #include @@ -29,57 +27,12 @@ #include #include #include -#include "cudf/null_mask.hpp" -#include "cudf/types.hpp" -#include "cudf/utilities/traits.hpp" -#include "cudf/unary.hpp" -#include "rmm/device_buffer.hpp" #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" -#include "jni.h" -#include "jni_utils.hpp" - extern "C" { -cudf::column* replace_column(cudf::column* list_column) { - cudf::lists_column_view lcv(list_column->view()); - - std::unique_ptr new_child; - - if (lcv.child().type().id() != cudf::type_id::LIST) { - assert(lcv.child().type() == cudf::type_id::DECIMAL64); - cudf::data_type to_type = cudf::data_type(cudf::type_id::DECIMAL32, lcv.child().type().scale()); - auto u_d32_ptr = cudf::cast(lcv.child(), to_type); - new_child.reset(u_d32_ptr.release()); - } else { - new_child.reset(replace_column(&list_column->child(cudf::lists_column_view::child_column_index))); - } - - assert(new_child->size() == contents.children[lists_column_view::child_column_index].size()); - int32_t size = list_column->size(); - int32_t null_count = list_column->null_count(); - auto contents = list_column->release(); - - auto col = cudf::make_lists_column(size, std::move(contents.children[cudf::lists_column_view::offsets_column_index]), - std::move(new_child), null_count, std::move(*contents.null_mask.release())); - return col.release(); -} - -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_castLeafD64ToD32( JNIEnv *env, jobject j_object, jlong j_handle) { - - JNI_NULL_CHECK(env, j_handle, "native handle is null", 0); - - try { - cudf::column *n_list_col = reinterpret_cast(j_handle); - JNI_ARG_CHECK(env, n_list_col->type().id() == cudf::type_id::LIST, "Only list types are allowed", 0); - - return reinterpret_cast(replace_column(n_list_col)); - } - CATCH_STD(env, 0); -} - JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, jclass, jlong j_initial_val, jlong j_step, jint row_count) { From c8a35c6cca539ab1542060503cbe643c0e525e89 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Mon, 1 Mar 2021 14:49:25 -0800 Subject: [PATCH 07/19] reformatted code --- java/src/main/native/src/ColumnViewJni.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 5f2a1debd6a..f60d3f996d2 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -1800,7 +1800,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceChildrenWithViews( } } - std::unique_ptr n_new_nested(new cudf::column_view(n_col_view->type(), n_col_view->size(), nullptr, n_col_view->null_mask(), n_col_view->null_count(), n_col_view->offset(), children)); + std::unique_ptr n_new_nested(new cudf::column_view(n_col_view->type(), n_col_view->size(), + nullptr, n_col_view->null_mask(), n_col_view->null_count(), n_col_view->offset(), children)); return reinterpret_cast(n_new_nested.release()); } CATCH_STD(env, 0); From 4be9b22125355a4a65dda176a7c96658d3667031 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Mon, 1 Mar 2021 16:54:22 -0800 Subject: [PATCH 08/19] addressed review comments --- .../main/java/ai/rapids/cudf/ColumnView.java | 39 +++++++++++++++-- java/src/main/native/src/ColumnViewJni.cpp | 10 +++-- .../java/ai/rapids/cudf/ColumnVectorTest.java | 42 +++++++++++++++++++ 3 files changed, 84 insertions(+), 7 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index be25b1a4e72..17458b88201 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1299,18 +1299,51 @@ public ColumnVector castTo(DType type) { /** * This method takes in a nested type and replaces its children with the given views + * Note: Make sure the numbers of rows in the leaf node are the same as the child replacing it + * otherwise the list can point to elements outside of the column values. + * + * Ex: List list = col{{1,3}, {9,3,5}} + * + * validNewChild = col{8, 3, 9, 2, 0} + * + * list.replaceChildrenWithViews(1, validNewChild) => col{{8, 3}, {9, 2, 0}} + * + * invalidNewChild = col{3, 2} + * list.replaceChildrenWithViews(1, invalidNewChild) => col{{3, 2}, {invalid, invalid, invalid}} + * + * invalidNewChild = col{8, 3, 9, 2, 0, 0, 7} + * list.replaceChildrenWithViews(1, invalidNewChild) => col{{8, 3}, {9, 2, 0}} // undefined result */ public ColumnView replaceChildrenWithViews(int[] indices, ColumnView[] views) { - assert(type == DType.STRUCT || type == DType.LIST); + assert(type.isNestedType()); + assert(indices.length == views.length); + if (type == DType.LIST) { + assert(indices.length == 1); + } + return new ColumnView(replaceChildrenWithViews(getNativeView(), indices, Arrays.stream(views).mapToLong(v -> v.getNativeView()).toArray())); } /** * This method takes in a list and returns a new list with the leaf node replaced with the given - * view. The number of rows in the new child should be the same as the original otherwise the - * resulting list will have bad data + * view. Make sure the numbers of rows in the leaf node are the same as the child replacing it + * otherwise the list can point to elements outside of the column values. + * + * Ex: List list = col{{1,3}, {9,3,5}} + * + * validNewChild = col{8, 3, 9, 2, 0} + * + * list.replaceChildrenWithViews(1, validNewChild) => col{{8, 3}, {9, 2, 0}} + * + * invalidNewChild = col{3, 2} + * list.replaceChildrenWithViews(1, invalidNewChild) => + * col{{3, 2}, {invalid, invalid, invalid}} throws an exception + * + * invalidNewChild = col{8, 3, 9, 2, 0, 0, 7} + * list.replaceChildrenWithViews(1, invalidNewChild) => + * col{{8, 3}, {9, 2, 0}} throws an exception */ public ColumnView replaceListChild(ColumnView child) { assert(type == DType.LIST); diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index f60d3f996d2..0b88215cf5e 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -1778,13 +1778,11 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceChildrenWithViews( cudf::column_view *n_col_view = reinterpret_cast(j_handle); cudf::type_id id = n_col_view->type().id(); - JNI_ARG_CHECK(env, id == cudf::type_id::STRUCT || id == cudf::type_id::LIST, "Only nested types are allowed", 0); - if (id == cudf::type_id::LIST) { - JNI_ARG_CHECK(env, children_to_replace.size() == 1, "LIST can only have one child to replace", 0); - } std::map m; for (int i = 0 ; i < indices.size() ; i++) { + auto it = m.find(indices[i]); + JNI_ARG_CHECK(env, it == m.end(), "Duplicate mapping found for replacing child index", 0); m[indices[i]] = children_to_replace[i]; } @@ -1794,12 +1792,16 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceChildrenWithViews( for (int i = 0 ; i < n_col_view->num_children() ; i++) { auto it = m.find(i); if (it != m.end()) { + JNI_ARG_CHECK(env, (*it->second).size() == n_col_view->child(i).size(), "Child size don't match", 0); + m.erase(it); children.emplace_back(*it->second); } else { children.emplace_back(n_col_view->child(i)); } } + JNI_ARG_CHECK(env, m.empty(), "One or more invalid child indices passed to be replaced", 0); + std::unique_ptr n_new_nested(new cudf::column_view(n_col_view->type(), n_col_view->size(), nullptr, n_col_view->null_mask(), n_col_view->null_count(), n_col_view->offset(), children)); return reinterpret_cast(n_new_nested.release()); diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 44ffffbdc6f..fd2e510a8d6 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -3958,6 +3958,20 @@ void testReplaceLeafNodeInList() { } } + @Test + void testReplaceLeafNodeInListWithIllegal() { + assertThrows(IllegalArgumentException.class, () -> { + try (ColumnVector child1 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 770.892, 961.110); + ColumnVector child2 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 524.982, 479.946); + ColumnVector child3 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 346.997, 479.946); + ColumnVector child4 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 87.764, 414.239); + ColumnVector created = ColumnVector.makeList(child1, child2, child3, child4); + ColumnVector newChild = ColumnVector.fromInts(0, 1, 8, 9, 2, 2, 3, 8, 6); + ColumnView replacedView = created.replaceListChild(newChild)) { + } + }); + } + @Test void testReplaceColumnInStruct() { try (ColumnVector expected = ColumnVector.fromStructs(new StructType(false, @@ -3978,4 +3992,32 @@ void testReplaceColumnInStruct() { } } } + + @Test + void testReplaceIllegalIndexColumnInStruct() { + assertThrows(IllegalArgumentException.class, () -> { + try (ColumnVector child1 = ColumnVector.fromInts(1, 4); + ColumnVector child2 = ColumnVector.fromInts(2, 5); + ColumnVector child3 = ColumnVector.fromInts(3, 6); + ColumnVector created = ColumnVector.makeStruct(child1, child2, child3); + ColumnVector replaceWith = ColumnVector.fromInts(5, 9); + ColumnView replacedView = created.replaceChildrenWithViews(new int[]{5}, + new ColumnVector[]{replaceWith})) { + } + }); + } + + @Test + void testReplaceSameIndexColumnInStruct() { + assertThrows(IllegalArgumentException.class, () -> { + try (ColumnVector child1 = ColumnVector.fromInts(1, 4); + ColumnVector child2 = ColumnVector.fromInts(2, 5); + ColumnVector child3 = ColumnVector.fromInts(3, 6); + ColumnVector created = ColumnVector.makeStruct(child1, child2, child3); + ColumnVector replaceWith = ColumnVector.fromInts(5, 9); + ColumnView replacedView = created.replaceChildrenWithViews(new int[]{1, 1}, + new ColumnVector[]{replaceWith, replaceWith})) { + } + }); + } } From 5e79b2120e61dc6ff6f7f7edb4904651545e629d Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Tue, 2 Mar 2021 13:21:17 -0800 Subject: [PATCH 09/19] updated method doc --- java/src/main/java/ai/rapids/cudf/ColumnView.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 17458b88201..139465030af 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1302,6 +1302,9 @@ public ColumnVector castTo(DType type) { * Note: Make sure the numbers of rows in the leaf node are the same as the child replacing it * otherwise the list can point to elements outside of the column values. * + * Note: this method returns a ColumnView and it won't live past the ColumnVector that its + * pointing to. + * * Ex: List list = col{{1,3}, {9,3,5}} * * validNewChild = col{8, 3, 9, 2, 0} @@ -1331,6 +1334,9 @@ public ColumnView replaceChildrenWithViews(int[] indices, * view. Make sure the numbers of rows in the leaf node are the same as the child replacing it * otherwise the list can point to elements outside of the column values. * + * Note: this method returns a ColumnView and it won't live past the ColumnVector that its + * pointing to. + * * Ex: List list = col{{1,3}, {9,3,5}} * * validNewChild = col{8, 3, 9, 2, 0} From 0518c6ed3a4223382e3f156278e00ad216dbd7d3 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Tue, 2 Mar 2021 16:50:23 -0800 Subject: [PATCH 10/19] addressed review comments --- java/src/main/native/src/ColumnViewJni.cpp | 10 ++-- .../java/ai/rapids/cudf/ColumnVectorTest.java | 47 +++++++++++++------ 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 0b88215cf5e..9fb1ea4adb9 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -1786,24 +1786,24 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceChildrenWithViews( m[indices[i]] = children_to_replace[i]; } - std::vector children; - children.reserve(n_col_view->num_children()); + std::vector new_children; + new_children.reserve(n_col_view->num_children()); int j = 0; for (int i = 0 ; i < n_col_view->num_children() ; i++) { auto it = m.find(i); if (it != m.end()) { JNI_ARG_CHECK(env, (*it->second).size() == n_col_view->child(i).size(), "Child size don't match", 0); + new_children.emplace_back(*it->second); m.erase(it); - children.emplace_back(*it->second); } else { - children.emplace_back(n_col_view->child(i)); + new_children.emplace_back(n_col_view->child(i)); } } JNI_ARG_CHECK(env, m.empty(), "One or more invalid child indices passed to be replaced", 0); std::unique_ptr n_new_nested(new cudf::column_view(n_col_view->type(), n_col_view->size(), - nullptr, n_col_view->null_mask(), n_col_view->null_count(), n_col_view->offset(), children)); + nullptr, n_col_view->null_mask(), n_col_view->null_count(), n_col_view->offset(), new_children)); return reinterpret_cast(n_new_nested.release()); } CATCH_STD(env, 0); diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index fd2e510a8d6..74ff2c635a1 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -3945,13 +3945,21 @@ void testReplaceLeafNodeInList() { ColumnVector c3 = ColumnVector.fromInts(9, 8); ColumnVector c4 = ColumnVector.fromInts(2, 6); ColumnVector expected = ColumnVector.makeList(c1, c2, c3, c4); - ColumnVector child1 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 770.892, 961.110); - ColumnVector child2 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 524.982, 479.946); - ColumnVector child3 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 346.997, 479.946); - ColumnVector child4 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 87.764, 414.239); - ColumnVector created = ColumnVector.makeList(child1, child2, child3, child4); - ColumnVector newChild = ColumnVector.fromInts(1, 8, 9, 2, 2, 3, 8, 6); - ColumnView replacedView = created.replaceListChild(newChild)) { + ColumnVector child1 = + ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), + RoundingMode.HALF_UP, 770.892, 961.110); + ColumnVector child2 = + ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), + RoundingMode.HALF_UP, 524.982, 479.946); + ColumnVector child3 = + ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), + RoundingMode.HALF_UP, 346.997, 479.946); + ColumnVector child4 = + ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), + RoundingMode.HALF_UP, 87.764, 414.239); + ColumnVector created = ColumnVector.makeList(child1, child2, child3, child4); + ColumnVector newChild = ColumnVector.fromInts(1, 8, 9, 2, 2, 3, 8, 6); + ColumnView replacedView = created.replaceListChild(newChild)) { try (ColumnVector replaced = replacedView.copyToColumnVector()) { assertColumnsAreEqual(expected, replaced); } @@ -3961,13 +3969,21 @@ void testReplaceLeafNodeInList() { @Test void testReplaceLeafNodeInListWithIllegal() { assertThrows(IllegalArgumentException.class, () -> { - try (ColumnVector child1 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 770.892, 961.110); - ColumnVector child2 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 524.982, 479.946); - ColumnVector child3 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 346.997, 479.946); - ColumnVector child4 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 87.764, 414.239); - ColumnVector created = ColumnVector.makeList(child1, child2, child3, child4); - ColumnVector newChild = ColumnVector.fromInts(0, 1, 8, 9, 2, 2, 3, 8, 6); - ColumnView replacedView = created.replaceListChild(newChild)) { + try (ColumnVector child1 = + ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), + RoundingMode.HALF_UP, 770.892, 961.110); + ColumnVector child2 = + ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), + RoundingMode.HALF_UP, 524.982, 479.946); + ColumnVector child3 = + ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), + RoundingMode.HALF_UP, 346.997, 479.946); + ColumnVector child4 = + ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), + RoundingMode.HALF_UP, 87.764, 414.239); + ColumnVector created = ColumnVector.makeList(child1, child2, child3, child4); + ColumnVector newChild = ColumnVector.fromInts(0, 1, 8, 9, 2, 2, 3, 8, 6); + ColumnView replacedView = created.replaceListChild(newChild)) { } }); } @@ -3986,7 +4002,8 @@ void testReplaceColumnInStruct() { ColumnVector child3 = ColumnVector.fromInts(3, 6); ColumnVector created = ColumnVector.makeStruct(child1, child2, child3); ColumnVector replaceWith = ColumnVector.fromInts(5, 9); - ColumnView replacedView = created.replaceChildrenWithViews(new int[]{1}, new ColumnVector[]{replaceWith})) { + ColumnView replacedView = created.replaceChildrenWithViews(new int[]{1}, + new ColumnVector[]{replaceWith})) { try (ColumnVector replaced = replacedView.copyToColumnVector()) { assertColumnsAreEqual(expected, replaced); } From 1e17bdfa51e158740deb1e6be04c31a5b348e0cd Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Tue, 2 Mar 2021 16:51:56 -0800 Subject: [PATCH 11/19] Update java/src/main/java/ai/rapids/cudf/ColumnView.java Co-authored-by: MithunR --- java/src/main/java/ai/rapids/cudf/ColumnView.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 139465030af..a4fd8acc779 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1302,7 +1302,7 @@ public ColumnVector castTo(DType type) { * Note: Make sure the numbers of rows in the leaf node are the same as the child replacing it * otherwise the list can point to elements outside of the column values. * - * Note: this method returns a ColumnView and it won't live past the ColumnVector that its + * Note: this method returns a ColumnView that won't live past the ColumnVector that it's * pointing to. * * Ex: List list = col{{1,3}, {9,3,5}} From 84d463c93e3a9afec51df5f1a6fe8d87d6ecf440 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Tue, 2 Mar 2021 16:52:05 -0800 Subject: [PATCH 12/19] Update java/src/main/java/ai/rapids/cudf/ColumnView.java Co-authored-by: MithunR --- java/src/main/java/ai/rapids/cudf/ColumnView.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index a4fd8acc779..61ae4b6619c 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1334,7 +1334,7 @@ public ColumnView replaceChildrenWithViews(int[] indices, * view. Make sure the numbers of rows in the leaf node are the same as the child replacing it * otherwise the list can point to elements outside of the column values. * - * Note: this method returns a ColumnView and it won't live past the ColumnVector that its + * Note: this method returns a ColumnView that won't live past the ColumnVector that it's * pointing to. * * Ex: List list = col{{1,3}, {9,3,5}} From beb05a42b8cbf31d3c20506a88406e47561f0d1d Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Tue, 2 Mar 2021 16:52:16 -0800 Subject: [PATCH 13/19] Update java/src/main/native/src/ColumnViewJni.cpp Co-authored-by: MithunR --- java/src/main/native/src/ColumnViewJni.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 9fb1ea4adb9..506cf055bb9 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -61,7 +61,7 @@ #include #include #include -#include "cudf/types.hpp" +#include #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" From 6ed07cb45bc12717520802cc72b762cd3a5478f5 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Tue, 2 Mar 2021 16:55:35 -0800 Subject: [PATCH 14/19] fixed the return --- java/src/main/native/src/ColumnViewJni.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 506cf055bb9..08b06958ba2 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -1802,9 +1802,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceChildrenWithViews( JNI_ARG_CHECK(env, m.empty(), "One or more invalid child indices passed to be replaced", 0); - std::unique_ptr n_new_nested(new cudf::column_view(n_col_view->type(), n_col_view->size(), + return reinterpret_cast(new cudf::column_view(n_col_view->type(), n_col_view->size(), nullptr, n_col_view->null_mask(), n_col_view->null_count(), n_col_view->offset(), new_children)); - return reinterpret_cast(n_new_nested.release()); } CATCH_STD(env, 0); } From 6abc62941b37b5c113932377022219fe4450beb2 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Thu, 4 Mar 2021 10:32:37 -0800 Subject: [PATCH 15/19] d32 changes --- .../java/ai/rapids/cudf/ColumnVector.java | 10 +-- .../main/java/ai/rapids/cudf/ColumnView.java | 90 ++++++++++++++++--- java/src/main/native/src/ColumnViewJni.cpp | 44 --------- 3 files changed, 83 insertions(+), 61 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 252f869a049..1dc7ea6911f 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -167,16 +167,16 @@ private static long getColumnViewFromColumn(long nativePointer) { } } - - private static long initViewHandle(DType type, int rows, int nc, DeviceMemoryBuffer dataBuffer, - DeviceMemoryBuffer validityBuffer, - DeviceMemoryBuffer offsetBuffer, long[] childHandles) { + protected static long initViewHandle(DType type, int rows, int nc, + BaseDeviceMemoryBuffer dataBuffer, + BaseDeviceMemoryBuffer validityBuffer, + BaseDeviceMemoryBuffer offsetBuffer, long[] childHandles) { long cd = dataBuffer == null ? 0 : dataBuffer.address; long cdSize = dataBuffer == null ? 0 : dataBuffer.length; long od = offsetBuffer == null ? 0 : offsetBuffer.address; long vd = validityBuffer == null ? 0 : validityBuffer.address; return makeCudfColumnView(type.typeId.getNativeId(), type.getScale(), cd, cdSize, - od, vd, nc, rows, childHandles) ; + od, vd, nc, rows, childHandles); } static ColumnVector fromViewWithContiguousAllocation(long columnViewAddress, DeviceMemoryBuffer buffer) { diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 61ae4b6619c..6a0527f96e6 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -18,10 +18,10 @@ package ai.rapids.cudf; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Optional; +import javafx.util.Pair; + +import java.util.*; +import java.util.stream.IntStream; import static ai.rapids.cudf.HostColumnVector.OFFSET_SIZE; @@ -50,6 +50,53 @@ protected ColumnView(long address) { this.nullCount = ColumnView.getNativeNullCount(viewHandle); } + /** + * Create a new column vector based off of data already on the device. + * @param type the type of the vector + * @param rows the number of rows in this vector. + * @param nullCount the number of nulls in the dataset. + * @param validityBuffer an optional validity buffer. Must be provided if nullCount != 0. + * The ownership doesn't change on this buffer + * @param offsetBuffer a host buffer required for nested types including strings and string + * categories. The ownership doesn't change on this buffer + * @param children an array of ColumnView children + */ + public ColumnView(DType type, long rows, Optional nullCount, + BaseDeviceMemoryBuffer validityBuffer, + BaseDeviceMemoryBuffer offsetBuffer, ColumnView[] children) { + this(type, (int) rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(), + null, validityBuffer, offsetBuffer, children); + assert(type.isNestedType()); + assert (nullCount.isPresent() && nullCount.get() <= Integer.MAX_VALUE) + || !nullCount.isPresent(); + } + + /** + * Create a new column vector based off of data already on the device. + * @param type the type of the vector + * @param rows the number of rows in this vector. + * @param nullCount the number of nulls in the dataset. + * @param dataBuffer a host buffer required for nested types including strings and string + * categories. The ownership doesn't change on this buffer + * @param validityBuffer an optional validity buffer. Must be provided if nullCount != 0. + * The ownership doesn't change on this buffer + */ + public ColumnView(DType type, long rows, Optional nullCount, + BaseDeviceMemoryBuffer dataBuffer, + BaseDeviceMemoryBuffer validityBuffer) { + this(type, (int) rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(), + dataBuffer, validityBuffer, null, null); + assert (!type.isNestedType()); + assert (nullCount.isPresent() && nullCount.get() <= Integer.MAX_VALUE) + || !nullCount.isPresent(); + } + + private ColumnView(DType type, long rows, int nullCount, + BaseDeviceMemoryBuffer dataBuffer, BaseDeviceMemoryBuffer validityBuffer, + BaseDeviceMemoryBuffer offsetBuffer, ColumnView[] children) { + this(ColumnVector.initViewHandle(type, (int) rows, nullCount, dataBuffer, validityBuffer, + offsetBuffer, Arrays.stream(children).mapToLong(c -> c.getNativeView()).toArray())); + } /** Creates a ColumnVector from a column view handle * @return a new ColumnVector @@ -1319,14 +1366,35 @@ public ColumnVector castTo(DType type) { */ public ColumnView replaceChildrenWithViews(int[] indices, ColumnView[] views) { - assert(type.isNestedType()); - assert(indices.length == views.length); + assert (type.isNestedType()); + assert (indices.length == views.length); if (type == DType.LIST) { - assert(indices.length == 1); + assert (indices.length == 1); } - - return new ColumnView(replaceChildrenWithViews(getNativeView(), indices, - Arrays.stream(views).mapToLong(v -> v.getNativeView()).toArray())); + if (indices.length != views.length) { + throw new IllegalArgumentException("The indices size and children size should match"); + } + Map map = new HashMap<>(); + IntStream.range(0, indices.length).forEach(index -> { + if (map.containsKey(indices[index])) { + throw new IllegalArgumentException("Duplicate mapping found for replacing child index"); + } + map.put(indices[index], views[index]); + }); + List newChildren = new ArrayList<>(getNumChildren()); + IntStream.range(0, getNumChildren()).forEach(i -> { + ColumnView view = map.remove(i); + if (view == null) { + newChildren.add(getChildColumnView(i)); + } else { + newChildren.add(view); + } + }); + if (!map.isEmpty()) { + throw new IllegalArgumentException("One or more invalid child indices passed to be replaced"); + } + return new ColumnView(type, getRowCount(), Optional.of(getNullCount()), getValid(), + getOffsets(), views); } /** @@ -2497,8 +2565,6 @@ static DeviceMemoryBufferView getOffsetsBuffer(long viewHandle) { */ private static native long timestampToStringTimestamp(long viewHandle, String format); - private static native long replaceChildrenWithViews(long cudfColumnHandle, - int[] indices, long[] viewHandles) throws CudfException; /** * Native method for locating the starting index of the first instance of a given substring * in each string in the column. 0 indexing, returns -1 if the substring is not found. Can be diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 08b06958ba2..81889aa4da7 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -1764,48 +1764,4 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_copyColumnViewToCV(JNIEnv CATCH_STD(env, 0) } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceChildrenWithViews( - JNIEnv *env, jobject j_object, jlong j_handle, jintArray j_indices, jlongArray j_children) { - - JNI_NULL_CHECK(env, j_handle, "native handle is null", 0); - JNI_NULL_CHECK(env, j_indices, "child indices to replace can't be null", 0); - JNI_NULL_CHECK(env, j_children, "children to replace can't be null", 0); - - try { - cudf::jni::native_jpointerArray children_to_replace(env, j_children); - cudf::jni::native_jintArray indices(env, j_indices); - JNI_ARG_CHECK(env, indices.size() == children_to_replace.size(), "The indices size and children size should match", 0); - - cudf::column_view *n_col_view = reinterpret_cast(j_handle); - cudf::type_id id = n_col_view->type().id(); - - std::map m; - for (int i = 0 ; i < indices.size() ; i++) { - auto it = m.find(indices[i]); - JNI_ARG_CHECK(env, it == m.end(), "Duplicate mapping found for replacing child index", 0); - m[indices[i]] = children_to_replace[i]; - } - - std::vector new_children; - new_children.reserve(n_col_view->num_children()); - int j = 0; - for (int i = 0 ; i < n_col_view->num_children() ; i++) { - auto it = m.find(i); - if (it != m.end()) { - JNI_ARG_CHECK(env, (*it->second).size() == n_col_view->child(i).size(), "Child size don't match", 0); - new_children.emplace_back(*it->second); - m.erase(it); - } else { - new_children.emplace_back(n_col_view->child(i)); - } - } - - JNI_ARG_CHECK(env, m.empty(), "One or more invalid child indices passed to be replaced", 0); - - return reinterpret_cast(new cudf::column_view(n_col_view->type(), n_col_view->size(), - nullptr, n_col_view->null_mask(), n_col_view->null_count(), n_col_view->offset(), new_children)); - } - CATCH_STD(env, 0); -} - } // extern "C" From 2937b3153f9303bb062d5ddd87c29e4700203eb9 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Thu, 4 Mar 2021 13:23:49 -0800 Subject: [PATCH 16/19] code cleanup --- java/src/main/java/ai/rapids/cudf/ColumnView.java | 10 ++++++---- java/src/main/native/src/ColumnViewJni.cpp | 4 ---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 4aec834bb3f..9d37efca4e4 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -18,8 +18,6 @@ package ai.rapids.cudf; -import javafx.util.Pair; - import java.util.*; import java.util.stream.IntStream; @@ -51,7 +49,9 @@ protected ColumnView(long address) { } /** - * Create a new column vector based off of data already on the device. + * Create a new column view based off of data already on the device. Ref count on the buffers + * is not incremented and none of the underlying buffers are owned by this view. If ownership + * is needed, call {@link ColumnView#copyToColumnVector} * @param type the type of the vector * @param rows the number of rows in this vector. * @param nullCount the number of nulls in the dataset. @@ -72,7 +72,9 @@ public ColumnView(DType type, long rows, Optional nullCount, } /** - * Create a new column vector based off of data already on the device. + * Create a new column view based off of data already on the device. Ref count on the buffers + * is not incremented and none of the underlying buffers are owned by this view. If ownership + * is needed, call {@link ColumnView#copyToColumnVector} * @param type the type of the vector * @param rows the number of rows in this vector. * @param nullCount the number of nulls in the dataset. diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index e1690496961..a0613f9b73f 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -14,7 +14,6 @@ * limitations under the License. */ -#include #include #include @@ -61,11 +60,9 @@ #include #include #include -#include #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" -#include "jni_utils.hpp" namespace { @@ -1763,5 +1760,4 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_copyColumnViewToCV(JNIEnv } CATCH_STD(env, 0) } - } // extern "C" From ee866b25f6638f9df075567994f70eada93cc433 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Sat, 6 Mar 2021 22:12:52 -0800 Subject: [PATCH 17/19] addressed review comments --- .../main/java/ai/rapids/cudf/ColumnVector.java | 2 +- .../main/java/ai/rapids/cudf/ColumnView.java | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 0929c12df93..9f414661967 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -167,7 +167,7 @@ private static long getColumnViewFromColumn(long nativePointer) { } } - protected static long initViewHandle(DType type, int rows, int nc, + static long initViewHandle(DType type, int rows, int nc, BaseDeviceMemoryBuffer dataBuffer, BaseDeviceMemoryBuffer validityBuffer, BaseDeviceMemoryBuffer offsetBuffer, long[] childHandles) { diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 9d37efca4e4..099f36e65de 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -50,8 +50,12 @@ protected ColumnView(long address) { /** * Create a new column view based off of data already on the device. Ref count on the buffers - * is not incremented and none of the underlying buffers are owned by this view. If ownership - * is needed, call {@link ColumnView#copyToColumnVector} + * is not incremented and none of the underlying buffers are owned by this view. The returned + * ColumnView is only valid as long as the underlying buffers remain valid. If the buffers are + * closed before this ColumnView is closed, it will result in undefined behavior. + * + * If ownership is needed, call {@link ColumnView#copyToColumnVector} + * * @param type the type of the vector * @param rows the number of rows in this vector. * @param nullCount the number of nulls in the dataset. @@ -73,8 +77,12 @@ public ColumnView(DType type, long rows, Optional nullCount, /** * Create a new column view based off of data already on the device. Ref count on the buffers - * is not incremented and none of the underlying buffers are owned by this view. If ownership - * is needed, call {@link ColumnView#copyToColumnVector} + * is not incremented and none of the underlying buffers are owned by this view. The returned + * ColumnView is only valid as long as the underlying buffers remain valid. If the buffers are + * closed before this ColumnView is closed, it will result in undefined behavior. + * + * If ownership is needed, call {@link ColumnView#copyToColumnVector} + * * @param type the type of the vector * @param rows the number of rows in this vector. * @param nullCount the number of nulls in the dataset. @@ -1396,7 +1404,7 @@ public ColumnView replaceChildrenWithViews(int[] indices, throw new IllegalArgumentException("One or more invalid child indices passed to be replaced"); } return new ColumnView(type, getRowCount(), Optional.of(getNullCount()), getValid(), - getOffsets(), views); + getOffsets(), newChildren.stream().toArray(n -> new ColumnView[n])); } /** From b67060fb760de31785b32001d0ec597377f0e574 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Tue, 9 Mar 2021 13:09:27 -0800 Subject: [PATCH 18/19] fix build --- java/src/main/java/ai/rapids/cudf/ColumnView.java | 6 ++++-- java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index e0cc96263b3..2df89177d5f 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1394,9 +1394,11 @@ public ColumnView replaceChildrenWithViews(int[] indices, List newChildren = new ArrayList<>(getNumChildren()); IntStream.range(0, getNumChildren()).forEach(i -> { ColumnView view = map.remove(i); + ColumnView child = getChildColumnView(i); if (view == null) { - newChildren.add(getChildColumnView(i)); + newChildren.add(child); } else { + assert (child.getRowCount() == view.getRowCount()); newChildren.add(view); } }); @@ -1431,7 +1433,7 @@ public ColumnView replaceChildrenWithViews(int[] indices, */ public ColumnView replaceListChild(ColumnView child) { assert(type == DType.LIST); - return replaceChildrenWithViews(new int[]{1}, new ColumnView[]{child}); + return replaceChildrenWithViews(new int[]{0}, new ColumnView[]{child}); } /** diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 0675ece4863..554a1fd0629 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4005,7 +4005,7 @@ void testReplaceLeafNodeInList() { @Test void testReplaceLeafNodeInListWithIllegal() { - assertThrows(IllegalArgumentException.class, () -> { + assertThrows(AssertionError.class, () -> { try (ColumnVector child1 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 770.892, 961.110); From 3adc86b6becb5298e8b9b82a3f3360eabb9d359e Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Tue, 9 Mar 2021 14:04:26 -0800 Subject: [PATCH 19/19] addressed review comments --- java/src/main/java/ai/rapids/cudf/ColumnView.java | 4 +++- java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java | 9 ++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 2df89177d5f..f36896a3c96 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1398,7 +1398,9 @@ public ColumnView replaceChildrenWithViews(int[] indices, if (view == null) { newChildren.add(child); } else { - assert (child.getRowCount() == view.getRowCount()); + if (child.getRowCount() != view.getRowCount()) { + throw new IllegalArgumentException("Child row count doesn't match the old child"); + } newChildren.add(view); } }); diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 554a1fd0629..d224543e574 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4005,7 +4005,7 @@ void testReplaceLeafNodeInList() { @Test void testReplaceLeafNodeInListWithIllegal() { - assertThrows(AssertionError.class, () -> { + Exception e = assertThrows(IllegalArgumentException.class, () -> { try (ColumnVector child1 = ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3), RoundingMode.HALF_UP, 770.892, 961.110); @@ -4023,6 +4023,7 @@ void testReplaceLeafNodeInListWithIllegal() { ColumnView replacedView = created.replaceListChild(newChild)) { } }); + assertTrue(e.getMessage().contains("Child row count doesn't match the old child")); } @Test @@ -4049,7 +4050,7 @@ void testReplaceColumnInStruct() { @Test void testReplaceIllegalIndexColumnInStruct() { - assertThrows(IllegalArgumentException.class, () -> { + Exception e = assertThrows(IllegalArgumentException.class, () -> { try (ColumnVector child1 = ColumnVector.fromInts(1, 4); ColumnVector child2 = ColumnVector.fromInts(2, 5); ColumnVector child3 = ColumnVector.fromInts(3, 6); @@ -4059,11 +4060,12 @@ void testReplaceIllegalIndexColumnInStruct() { new ColumnVector[]{replaceWith})) { } }); + assertTrue(e.getMessage().contains("One or more invalid child indices passed to be replaced")); } @Test void testReplaceSameIndexColumnInStruct() { - assertThrows(IllegalArgumentException.class, () -> { + Exception e = assertThrows(IllegalArgumentException.class, () -> { try (ColumnVector child1 = ColumnVector.fromInts(1, 4); ColumnVector child2 = ColumnVector.fromInts(2, 5); ColumnVector child3 = ColumnVector.fromInts(3, 6); @@ -4073,5 +4075,6 @@ void testReplaceSameIndexColumnInStruct() { new ColumnVector[]{replaceWith, replaceWith})) { } }); + assertTrue(e.getMessage().contains("Duplicate mapping found for replacing child index")); } }