diff --git a/python/cudf/cudf/_lib/lists.pyx b/python/cudf/cudf/_lib/lists.pyx index ceae1b148aa..76f37c3b845 100644 --- a/python/cudf/cudf/_lib/lists.pyx +++ b/python/cudf/cudf/_lib/lists.pyx @@ -8,9 +8,6 @@ from libcpp.utility cimport move from cudf._lib.column cimport Column from cudf._lib.pylibcudf.libcudf.column.column cimport column -from cudf._lib.pylibcudf.libcudf.lists.count_elements cimport ( - count_elements as cpp_count_elements, -) from cudf._lib.pylibcudf.libcudf.lists.lists_column_view cimport ( lists_column_view, ) @@ -36,19 +33,10 @@ from cudf._lib.pylibcudf cimport Scalar @acquire_spill_lock() def count_elements(Column col): - - # shared_ptr required because lists_column_view has no default - # ctor - cdef shared_ptr[lists_column_view] list_view = ( - make_shared[lists_column_view](col.view()) + return Column.from_pylibcudf( + pylibcudf.lists.count_elements( + col.to_pylibcudf(mode="read")) ) - cdef unique_ptr[column] c_result - - with nogil: - c_result = move(cpp_count_elements(list_view.get()[0])) - - result = Column.from_unique_ptr(move(c_result)) - return result @acquire_spill_lock() diff --git a/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/count_elements.pxd b/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/count_elements.pxd index 38bdd4db0bb..ba57a839fbc 100644 --- a/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/count_elements.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/count_elements.pxd @@ -9,4 +9,4 @@ from cudf._lib.pylibcudf.libcudf.lists.lists_column_view cimport ( cdef extern from "cudf/lists/count_elements.hpp" namespace "cudf::lists" nogil: - cdef unique_ptr[column] count_elements(const lists_column_view) except + + cdef unique_ptr[column] count_elements(const lists_column_view&) except + diff --git a/python/cudf/cudf/_lib/pylibcudf/lists.pxd b/python/cudf/cudf/_lib/pylibcudf/lists.pxd index 38a479e4791..38eb575ee8d 100644 --- a/python/cudf/cudf/_lib/pylibcudf/lists.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/lists.pxd @@ -33,3 +33,5 @@ cpdef Column reverse(Column) cpdef Column segmented_gather(Column, Column) cpdef Column extract_list_element(Column, ColumnOrSizeType) + +cpdef Column count_elements(Column) diff --git a/python/cudf/cudf/_lib/pylibcudf/lists.pyx b/python/cudf/cudf/_lib/pylibcudf/lists.pyx index 19c961aa014..ea469642dd5 100644 --- a/python/cudf/cudf/_lib/pylibcudf/lists.pyx +++ b/python/cudf/cudf/_lib/pylibcudf/lists.pyx @@ -17,6 +17,9 @@ from cudf._lib.pylibcudf.libcudf.lists.combine cimport ( concatenate_null_policy, concatenate_rows as cpp_concatenate_rows, ) +from cudf._lib.pylibcudf.libcudf.lists.count_elements cimport ( + count_elements as cpp_count_elements, +) from cudf._lib.pylibcudf.libcudf.lists.extract cimport ( extract_list_element as cpp_extract_list_element, ) @@ -293,3 +296,27 @@ cpdef Column extract_list_element(Column input, ColumnOrSizeType index): index.view() if ColumnOrSizeType is Column else index, )) return Column.from_libcudf(move(c_result)) + + +cpdef Column count_elements(Column input): + """Count the number of rows in each + list element in the given lists column. + For details, see :cpp:func:`count_elements`. + + Parameters + ---------- + input : Column + The input column + + Returns + ------- + Column + A new Column of the lengths of each list element + """ + cdef ListColumnView list_view = input.list_view() + cdef unique_ptr[column] c_result + + with nogil: + c_result = move(cpp_count_elements(list_view.view())) + + return Column.from_libcudf(move(c_result)) diff --git a/python/cudf/cudf/pylibcudf_tests/test_lists.py b/python/cudf/cudf/pylibcudf_tests/test_lists.py index 07ecaed5012..7cfed884f90 100644 --- a/python/cudf/cudf/pylibcudf_tests/test_lists.py +++ b/python/cudf/cudf/pylibcudf_tests/test_lists.py @@ -181,3 +181,13 @@ def test_extract_list_element_column(test_data): expect = pa.array([0, None, None, 7]) assert_column_eq(expect, res) + + +def test_count_elements(test_data): + arr = pa.array(test_data[0][1]) + plc_column = plc.interop.from_arrow(arr) + res = plc.lists.count_elements(plc_column) + + expect = pa.array([1, 1, 0, 3], type=pa.int32()) + + assert_column_eq(expect, res)