diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 2f3f2bf80cf..e50a9e86ead 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -256,6 +256,15 @@ public final ColumnVector getByteCount() { return new ColumnVector(byteCount(getNativeView())); } + /** + * Get the number of elements for each list. Null lists will have a value of null. + * @return the number of elements in each list as an INT32 value. + */ + public final ColumnVector countElements() { + assert DType.LIST.equals(type) : "Only lists are supported"; + return new ColumnVector(countElements(getNativeView())); + } + /** * Returns a Boolean vector with the same number of rows as this instance, that has * TRUE for any entry that is not null, and FALSE for any null entry (as per the validity mask) @@ -2749,6 +2758,8 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long binaryOpVV(long lhs, long rhs, int op, int dtype, int scale); + private static native long countElements(long viewHandle); + private static native long byteCount(long viewHandle) throws CudfException; private static native long extractListElement(long nativeView, int index); diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index ac14e1605d7..73db5ee4df3 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -20,16 +20,17 @@ #include #include #include -#include -#include -#include #include #include #include +#include +#include +#include #include #include #include #include +#include #include #include #include @@ -430,6 +431,19 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_split(JNIEnv *env, j CATCH_STD(env, NULL); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_countElements(JNIEnv *env, jclass clazz, + jlong view_handle) { + 
JNI_NULL_CHECK(env, view_handle, "input column is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::column_view *n_column = reinterpret_cast<cudf::column_view *>(view_handle); + std::unique_ptr<cudf::column> result = + cudf::lists::count_elements(cudf::lists_column_view(*n_column)); + return reinterpret_cast<jlong>(result.release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_charLengths(JNIEnv *env, jclass clazz, jlong view_handle) { JNI_NULL_CHECK(env, view_handle, "input column is null", 0); diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index d224543e574..420e176efe2 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -1666,6 +1666,18 @@ void testAppendStrings() { } } + @Test + void testCountElements() { + DataType dt = new ListType(true, new BasicType(true, DType.INT32)); + try (ColumnVector cv = ColumnVector.fromLists(dt, Arrays.asList(1), + Arrays.asList(1, 2), null, Arrays.asList(null, null), + Arrays.asList(1, 2, 3), Arrays.asList(1, 2, 3, 4)); + ColumnVector lengths = cv.countElements(); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, 2, 3, 4)) { + TableTest.assertColumnsAreEqual(expected, lengths); + } + } + @Test void testStringLengths() { try (ColumnVector cv = ColumnVector.fromStrings("1", "12", null, "123", "1234");