diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index f91ee5535b1..cd826707de2 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -2260,6 +2260,25 @@ public final ColumnVector extractListElement(int index) { return new ColumnVector(extractListElement(getNativeView(), index)); } + /** + * For each list in this column pull out the entry at the corresponding index specified in + * the index column. If the entry goes off the end of the list a NULL is returned instead. + * + * The index column should have the same row count with the list column. + * + * @param indices a column of 0 based offsets into the list. Negative values go backwards from + * the end of the list. + * @return a new column of the values at those indexes. + */ + public final ColumnVector extractListElement(ColumnView indices) { + assert type.equals(DType.LIST) : "A column of type LIST is required for .extractListElement()"; + assert indices != null && DType.INT32.equals(indices.type) + : "indices should be non-null and integer type"; + assert indices.getRowCount() == rows + : "indices must have the same row count with list column"; + return new ColumnVector(extractListElementV(getNativeView(), indices.getNativeView())); + } + /** * Create a new LIST column by copying elements from the current LIST column ignoring duplicate, * producing a LIST column in which each list contain only unique elements. @@ -3752,6 +3771,8 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long extractListElement(long nativeView, int index); + private static native long extractListElementV(long nativeView, long indicesView); + private static native long dropListDuplicates(long nativeView); private static native long dropListDuplicatesWithKeysValues(long nativeHandle); diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 548844aa0d3..a69c7c29900 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -392,6 +392,21 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractListElement(JNIEnv CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractListElementV(JNIEnv *env, jclass, + jlong column_view, + jlong indices_view) { + JNI_NULL_CHECK(env, column_view, "column is null", 0); + JNI_NULL_CHECK(env, indices_view, "indices is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::column_view *indices = reinterpret_cast(indices_view); + cudf::column_view *cv = reinterpret_cast(column_view); + cudf::lists_column_view lcv(*cv); + return release_as_jlong(cudf::lists::extract_list_element(lcv, *indices)); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_dropListDuplicates(JNIEnv *env, jclass, jlong column_view) { JNI_NULL_CHECK(env, column_view, "column is null", 0); diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index b759c746735..0ba29840156 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4360,17 +4360,23 @@ void testsubstring() { @Test void testExtractListElements() { - try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); - ColumnVector expected = ColumnVector.fromStrings("Héllo", - "thésé", - null, - "", - "ARé", - "test"); - ColumnVector tmp = v.stringSplitRecord(" "); - ColumnVector result = tmp.extractListElement(0)) { - assertColumnsAreEqual(expected, result); - } + try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); + ColumnVector expected = ColumnVector.fromStrings("Héllo", "thésé", null, "", "ARé", "test"); + ColumnVector list = v.stringSplitRecord(" "); + ColumnVector result = list.extractListElement(0)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void testExtractListElementsV() { + try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); + ColumnVector indices = ColumnVector.fromInts(0, 2, 0, 0, 1, -1); + ColumnVector expected = ColumnVector.fromStrings("Héllo", null, null, "", "some", "strings"); + ColumnVector list = v.stringSplitRecord(" "); + ColumnVector result = list.extractListElement(indices)) { + assertColumnsAreEqual(expected, result); + } } @Test