Skip to content

Commit

Permalink
Add JNI for extract_list_element with index column (#10341)
Browse files Browse the repository at this point in the history
This PR is to expose the `cudf::list::extract_list_element(column_view, column_view)` API to JNI, along with its tests.

closes #10340

Signed-off-by: Firestarman <[email protected]>

Authors:
  - Liangcai Li (https://github.com/firestarman)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #10341
  • Loading branch information
firestarman authored Feb 22, 2022
1 parent 58810af commit add6990
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
21 changes: 21 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2260,6 +2260,25 @@ public final ColumnVector extractListElement(int index) {
return new ColumnVector(extractListElement(getNativeView(), index));
}

/**
* For each list in this column pull out the entry at the corresponding index specified in
* the index column. If the entry goes off the end of the list a NULL is returned instead.
*
* The index column should have the same row count with the list column.
*
* @param indices a column of 0 based offsets into the list. Negative values go backwards from
* the end of the list.
* @return a new column of the values at those indexes.
*/
public final ColumnVector extractListElement(ColumnView indices) {
assert type.equals(DType.LIST) : "A column of type LIST is required for .extractListElement()";
assert indices != null && DType.INT32.equals(indices.type)
: "indices should be non-null and integer type";
assert indices.getRowCount() == rows
: "indices must have the same row count with list column";
return new ColumnVector(extractListElementV(getNativeView(), indices.getNativeView()));
}

/**
* Create a new LIST column by copying elements from the current LIST column ignoring duplicate,
* producing a LIST column in which each list contain only unique elements.
Expand Down Expand Up @@ -3752,6 +3771,8 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat

private static native long extractListElement(long nativeView, int index);

private static native long extractListElementV(long nativeView, long indicesView);

private static native long dropListDuplicates(long nativeView);

private static native long dropListDuplicatesWithKeysValues(long nativeHandle);
Expand Down
15 changes: 15 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,21 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractListElement(JNIEnv
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractListElementV(JNIEnv *env, jclass,
jlong column_view,
jlong indices_view) {
JNI_NULL_CHECK(env, column_view, "column is null", 0);
JNI_NULL_CHECK(env, indices_view, "indices is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::column_view *indices = reinterpret_cast<cudf::column_view *>(indices_view);
cudf::column_view *cv = reinterpret_cast<cudf::column_view *>(column_view);
cudf::lists_column_view lcv(*cv);
return release_as_jlong(cudf::lists::extract_list_element(lcv, *indices));
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_dropListDuplicates(JNIEnv *env, jclass,
jlong column_view) {
JNI_NULL_CHECK(env, column_view, "column is null", 0);
Expand Down
28 changes: 17 additions & 11 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4360,17 +4360,23 @@ void testsubstring() {

@Test
void testExtractListElements() {
try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings");
ColumnVector expected = ColumnVector.fromStrings("Héllo",
"thésé",
null,
"",
"ARé",
"test");
ColumnVector tmp = v.stringSplitRecord(" ");
ColumnVector result = tmp.extractListElement(0)) {
assertColumnsAreEqual(expected, result);
}
try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings");
ColumnVector expected = ColumnVector.fromStrings("Héllo", "thésé", null, "", "ARé", "test");
ColumnVector list = v.stringSplitRecord(" ");
ColumnVector result = list.extractListElement(0)) {
assertColumnsAreEqual(expected, result);
}
}

@Test
void testExtractListElementsV() {
try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings");
ColumnVector indices = ColumnVector.fromInts(0, 2, 0, 0, 1, -1);
ColumnVector expected = ColumnVector.fromStrings("Héllo", null, null, "", "some", "strings");
ColumnVector list = v.stringSplitRecord(" ");
ColumnVector result = list.extractListElement(indices)) {
assertColumnsAreEqual(expected, result);
}
}

@Test
Expand Down

0 comments on commit add6990

Please sign in to comment.