Skip to content
/ cudf Public
forked from rapidsai/cudf

Commit

Permalink
Add JNI for lists::concatenate_list_elements (rapidsai#13547)
Browse files Browse the repository at this point in the history
This implements JNI work for `lists::concatenate_list_elements` to expose the API to Java usage.

Authors:
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: rapidsai#13547
  • Loading branch information
ttnghia authored Jun 14, 2023
1 parent 02be87b commit 67efaf6
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 0 deletions.
27 changes: 27 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2491,6 +2491,31 @@ public final ColumnVector dropListDuplicatesWithKeysValues() {
return new ColumnVector(dropListDuplicatesWithKeysValues(getNativeView()));
}

/**
* Flatten each list of lists into a single list.
*
* The column must have rows that are lists of lists.
* Any row containing null list elements will result in a null output row.
*
* @return A new column vector containing the flattened result
*/
public ColumnVector flattenLists() {
return flattenLists(false);
}

/**
* Flatten each list of lists into a single list.
*
* The column must have rows that are lists of lists.
*
* @param ignoreNull Whether to ignore null list elements in the input column from the operation,
* or any row containing null list elements will result in a null output row
* @return A new column vector containing the flattened result
*/
public ColumnVector flattenLists(boolean ignoreNull) {
return new ColumnVector(flattenLists(getNativeView(), ignoreNull));
}

/////////////////////////////////////////////////////////////////////////////
// STRINGS
/////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -4467,6 +4492,8 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat

private static native long dropListDuplicatesWithKeysValues(long nativeHandle);

private static native long flattenLists(long inputHandle, boolean ignoreNull);

/**
* Native method for list lookup
* @param nativeView the column view handle of the list
Expand Down
15 changes: 15 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cudf/detail/null_mask.hpp>
#include <cudf/filling.hpp>
#include <cudf/hashing.hpp>
#include <cudf/lists/combine.hpp>
#include <cudf/lists/contains.hpp>
#include <cudf/lists/count_elements.hpp>
#include <cudf/lists/detail/concatenate.hpp>
Expand Down Expand Up @@ -504,6 +505,20 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_dropListDuplicatesWithKey
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_flattenLists(JNIEnv *env, jclass,
jlong input_handle,
jboolean ignore_null) {
JNI_NULL_CHECK(env, input_handle, "input_handle is null", 0);
try {
cudf::jni::auto_set_device(env);
auto const null_policy = ignore_null ? cudf::lists::concatenate_null_policy::IGNORE :
cudf::lists::concatenate_null_policy::NULLIFY_OUTPUT_ROW;
auto const input_cv = reinterpret_cast<cudf::column_view const *>(input_handle);
return release_as_jlong(cudf::lists::concatenate_list_elements(*input_cv, null_policy));
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listContains(JNIEnv *env, jclass,
jlong column_view,
jlong lookup_key) {
Expand Down
36 changes: 36 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2884,6 +2884,42 @@ void testListConcatByRowIgnoreNull() {
}
}

@Test
void testFlattenLists() {
HostColumnVector.ListType listType = new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32));
HostColumnVector.ListType listOfListsType = new HostColumnVector.ListType(true, listType);

// Input does not have nulls.
try (ColumnVector input = ColumnVector.fromLists(listOfListsType,
Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3), Arrays.asList(4, 5, 6)),
Arrays.asList(Arrays.asList(7, 8, 9), Arrays.asList(10, 11, 12, 13, 14, 15)));
ColumnVector result = input.flattenLists();
ColumnVector expected = ColumnVector.fromLists(listType,
Arrays.asList(1, 2, 3, 4, 5, 6),
Arrays.asList(7, 8, 9, 10, 11, 12, 13, 14, 15))) {
assertColumnsAreEqual(expected, result);
}

// Input has nulls.
try (ColumnVector input = ColumnVector.fromLists(listOfListsType,
Arrays.asList(null, Arrays.asList(3), Arrays.asList(4, 5, 6)),
Arrays.asList(Arrays.asList(null, 8, 9), Arrays.asList(10, 11, 12, 13, 14, null)))) {
try (ColumnVector result = input.flattenLists(false);
ColumnVector expected = ColumnVector.fromLists(listType,
null,
Arrays.asList(null, 8, 9, 10, 11, 12, 13, 14, null))) {
assertColumnsAreEqual(expected, result);
}
try (ColumnVector result = input.flattenLists(true);
ColumnVector expected = ColumnVector.fromLists(listType,
Arrays.asList(3, 4, 5, 6),
Arrays.asList(null, 8, 9, 10, 11, 12, 13, 14, null))) {
assertColumnsAreEqual(expected, result);
}
}
}

@Test
void testPrefixSum() {
try (ColumnVector v1 = ColumnVector.fromLongs(1, 2, 3, 5, 8, 10);
Expand Down

0 comments on commit 67efaf6

Please sign in to comment.