diff --git a/cpp/include/cudf/detail/reduction_functions.hpp b/cpp/include/cudf/detail/reduction_functions.hpp index 0fd57eba8d1..d8f23e8d7cb 100644 --- a/cpp/include/cudf/detail/reduction_functions.hpp +++ b/cpp/include/cudf/detail/reduction_functions.hpp @@ -19,6 +19,7 @@ #include #include +#include "cudf/lists/lists_column_view.hpp" #include namespace cudf { @@ -278,7 +279,7 @@ std::unique_ptr collect_list( * @return merged list as scalar */ std::unique_ptr merge_lists( - column_view const& col, + lists_column_view const& col, rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -312,7 +313,7 @@ std::unique_ptr collect_set( * @return collected list with unique elements as scalar */ std::unique_ptr merge_sets( - column_view const& col, + lists_column_view const& col, null_equality nulls_equal, nan_equality nans_equal, rmm::cuda_stream_view stream = rmm::cuda_stream_default, diff --git a/cpp/src/reductions/collect_ops.cu b/cpp/src/reductions/collect_ops.cu index 07d3dbe04eb..c9bd06a1171 100644 --- a/cpp/src/reductions/collect_ops.cu +++ b/cpp/src/reductions/collect_ops.cu @@ -57,13 +57,11 @@ std::unique_ptr collect_list(column_view const& col, } } -std::unique_ptr merge_lists(column_view const& col, +std::unique_ptr merge_lists(lists_column_view const& col, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - CUDF_EXPECTS(col.type().id() == type_id::LIST, - "input column of merge_lists must be a list column"); - auto flatten_col = lists_column_view(col).get_sliced_child(stream); + auto flatten_col = col.get_sliced_child(stream); return make_list_scalar(flatten_col, stream, mr); } @@ -79,15 +77,13 @@ std::unique_ptr collect_set(column_view const& col, return drop_duplicates(*ls, nulls_equal, nans_equal, stream, mr); } -std::unique_ptr merge_sets(column_view const& col, +std::unique_ptr merge_sets(lists_column_view const& col, null_equality nulls_equal, nan_equality nans_equal, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - CUDF_EXPECTS(col.type().id() == type_id::LIST, - "input column of merge_lists must be a list column"); - auto flatten_col = lists_column_view(col).get_sliced_child(stream); + auto flatten_col = col.get_sliced_child(stream); auto scalar = std::make_unique(flatten_col, true, stream, mr); return drop_duplicates(*scalar, nulls_equal, nans_equal, stream, mr); }