From 67efaf62c7314fbda56d39ea5f60bb893c406084 Mon Sep 17 00:00:00 2001 From: Nghia Truong <7416935+ttnghia@users.noreply.github.com> Date: Wed, 14 Jun 2023 14:32:56 -0700 Subject: [PATCH] Add JNI for `lists::concatenate_list_elements` (#13547) This implements JNI work for `lists::concatenate_list_elements` to expose the API to Java usage. Authors: - Nghia Truong (https://github.com/ttnghia) Approvers: - Robert (Bobby) Evans (https://github.com/revans2) URL: https://github.com/rapidsai/cudf/pull/13547 --- .../main/java/ai/rapids/cudf/ColumnView.java | 27 ++++++++++++++ java/src/main/native/src/ColumnViewJni.cpp | 15 ++++++++ .../java/ai/rapids/cudf/ColumnVectorTest.java | 36 +++++++++++++++++++ 3 files changed, 78 insertions(+) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 85bc8d7715a..4bbd719ce0d 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -2491,6 +2491,31 @@ public final ColumnVector dropListDuplicatesWithKeysValues() { return new ColumnVector(dropListDuplicatesWithKeysValues(getNativeView())); } + /** + * Flatten each list of lists into a single list. + * + * The column must have rows that are lists of lists. + * Any row containing null list elements will result in a null output row. + * + * @return A new column vector containing the flattened result + */ + public ColumnVector flattenLists() { + return flattenLists(false); + } + + /** + * Flatten each list of lists into a single list. + * + * The column must have rows that are lists of lists. 
+ * + * @param ignoreNull If true, null list elements in the input column are ignored by the operation; + * otherwise, any row containing null list elements will result in a null output row + * @return A new column vector containing the flattened result + */ + public ColumnVector flattenLists(boolean ignoreNull) { + return new ColumnVector(flattenLists(getNativeView(), ignoreNull)); + } + ///////////////////////////////////////////////////////////////////////////// // STRINGS ///////////////////////////////////////////////////////////////////////////// @@ -4467,6 +4492,8 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long dropListDuplicatesWithKeysValues(long nativeHandle); + private static native long flattenLists(long inputHandle, boolean ignoreNull); + /** * Native method for list lookup * @param nativeView the column view handle of the list diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 673dd8ae42a..1cb51a22bf3 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -27,6 +27,7 @@ #include #include #include +#include <cudf/lists/combine.hpp> #include #include #include @@ -504,6 +505,20 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_dropListDuplicatesWithKey CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_flattenLists(JNIEnv *env, jclass, + jlong input_handle, + jboolean ignore_null) { + JNI_NULL_CHECK(env, input_handle, "input_handle is null", 0); + try { + cudf::jni::auto_set_device(env); + auto const null_policy = ignore_null ? 
cudf::lists::concatenate_null_policy::IGNORE : + cudf::lists::concatenate_null_policy::NULLIFY_OUTPUT_ROW; + auto const input_cv = reinterpret_cast<cudf::column_view const *>(input_handle); + return release_as_jlong(cudf::lists::concatenate_list_elements(*input_cv, null_policy)); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listContains(JNIEnv *env, jclass, jlong column_view, jlong lookup_key) { diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 31bfaa1e828..e1da4b6a1ea 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -2884,6 +2884,42 @@ void testListConcatByRowIgnoreNull() { } } + @Test + void testFlattenLists() { + HostColumnVector.ListType listType = new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)); + HostColumnVector.ListType listOfListsType = new HostColumnVector.ListType(true, listType); + + // Input does not have nulls. + try (ColumnVector input = ColumnVector.fromLists(listOfListsType, + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3), Arrays.asList(4, 5, 6)), + Arrays.asList(Arrays.asList(7, 8, 9), Arrays.asList(10, 11, 12, 13, 14, 15))); + ColumnVector result = input.flattenLists(); + ColumnVector expected = ColumnVector.fromLists(listType, + Arrays.asList(1, 2, 3, 4, 5, 6), + Arrays.asList(7, 8, 9, 10, 11, 12, 13, 14, 15))) { + assertColumnsAreEqual(expected, result); + } + + // Input has nulls. 
+ try (ColumnVector input = ColumnVector.fromLists(listOfListsType, + Arrays.asList(null, Arrays.asList(3), Arrays.asList(4, 5, 6)), + Arrays.asList(Arrays.asList(null, 8, 9), Arrays.asList(10, 11, 12, 13, 14, null)))) { + try (ColumnVector result = input.flattenLists(false); + ColumnVector expected = ColumnVector.fromLists(listType, + null, + Arrays.asList(null, 8, 9, 10, 11, 12, 13, 14, null))) { + assertColumnsAreEqual(expected, result); + } + try (ColumnVector result = input.flattenLists(true); + ColumnVector expected = ColumnVector.fromLists(listType, + Arrays.asList(3, 4, 5, 6), + Arrays.asList(null, 8, 9, 10, 11, 12, 13, 14, null))) { + assertColumnsAreEqual(expected, result); + } + } + } + @Test void testPrefixSum() { try (ColumnVector v1 = ColumnVector.fromLongs(1, 2, 3, 5, 8, 10);