diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 4d9991d0dd9..53a02d83dd1 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -2160,6 +2160,10 @@ public final ColumnVector extractListElement(int index) { return new ColumnVector(extractListElement(getNativeView(), index)); } + /** Returns a new ColumnVector wrapping the result of the native drop-duplicates call on this column's native view; the native side treats the column as a lists column and calls cudf::lists::drop_list_duplicates. Per the accompanying test, element order within each resulting list is not guaranteed. */ public final ColumnVector dropListDuplicates() { + return new ColumnVector(dropListDuplicates(getNativeView())); + } + ///////////////////////////////////////////////////////////////////////////// // STRINGS ///////////////////////////////////////////////////////////////////////////// @@ -3489,6 +3493,8 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long extractListElement(long nativeView, int index); + /** Native counterpart of dropListDuplicates(): takes the native view handle and returns a handle that the caller wraps in a ColumnVector. */ private static native long dropListDuplicates(long nativeView); + /** * Native method for list lookup * @param nativeView the column view handle of the list diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 71fbc0fd384..adc0de12f25 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -395,6 +396,20 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractListElement(JNIEnv CATCH_STD(env, 0); } +/* JNI entry point for ColumnView.dropListDuplicates: null-checks the handle, views it as a cudf::lists_column_view, calls cudf::lists::drop_list_duplicates and returns the released result pointer as a jlong. 0 is the fallback value passed to JNI_NULL_CHECK and CATCH_STD. */ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_dropListDuplicates(JNIEnv *env, jclass, + jlong column_view) { + JNI_NULL_CHECK(env, column_view, "column is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::column_view const *cv = reinterpret_cast(column_view); + cudf::lists_column_view lcv(*cv); + + std::unique_ptr ret = cudf::lists::drop_list_duplicates(lcv); + return reinterpret_cast(ret.release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL 
Java_ai_rapids_cudf_ColumnView_listContains(JNIEnv *env, jclass, jlong column_view, jlong lookup_key) { diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 4856071e296..0643776a546 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4193,6 +4193,28 @@ void testExtractListElements() { } } + /** Checks that dropListDuplicates() collapses repeated values, including repeated nulls, within each list row and keeps a null list row as null; the result is sorted with listSortRows(false, false) before it is compared with the expected column. */ @Test + void testDropListDuplicates() { + List list1 = Arrays.asList(1, 2); + List list2 = Arrays.asList(3, 4, 5); + List list3 = Arrays.asList(null, 0, 6, 6, 0); + List dedupeList3 = Arrays.asList(0, 6, null); + List list4 = Arrays.asList(null, 6, 7, null, 7); + List dedupeList4 = Arrays.asList(6, 7, null); + List list5 = null; + + HostColumnVector.DataType listType = new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)); + try (ColumnVector v = ColumnVector.fromLists(listType, list1, list2, list3, list4, list5); + ColumnVector expected = ColumnVector.fromLists(listType, list1, list2, dedupeList3, dedupeList4, list5); + ColumnVector tmp = v.dropListDuplicates(); + // Note dropping duplicates does not have any ordering guarantee, so sort to make it all + // consistent + ColumnVector result = tmp.listSortRows(false, false)) { + assertColumnsAreEqual(expected, result); + } + } + @Test void testListContainsString() { List list1 = Arrays.asList("Héllo there", "thésé");