Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
Signed-off-by: sperlingxx <[email protected]>
  • Loading branch information
sperlingxx committed Mar 8, 2022
1 parent 4547e9a commit aea1704
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
5 changes: 3 additions & 2 deletions cpp/include/cudf/detail/reduction_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <cudf/column/column_view.hpp>
#include <cudf/scalar/scalar.hpp>

#include "cudf/lists/lists_column_view.hpp"
#include <rmm/cuda_stream_view.hpp>

namespace cudf {
Expand Down Expand Up @@ -278,7 +279,7 @@ std::unique_ptr<scalar> collect_list(
* @return merged list as scalar
*/
std::unique_ptr<scalar> merge_lists(
column_view const& col,
lists_column_view const& col,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down Expand Up @@ -312,7 +313,7 @@ std::unique_ptr<scalar> collect_set(
* @return collected list with unique elements as scalar
*/
std::unique_ptr<scalar> merge_sets(
column_view const& col,
lists_column_view const& col,
null_equality nulls_equal,
nan_equality nans_equal,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
Expand Down
12 changes: 4 additions & 8 deletions cpp/src/reductions/collect_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,11 @@ std::unique_ptr<scalar> collect_list(column_view const& col,
}
}

std::unique_ptr<scalar> merge_lists(column_view const& col,
std::unique_ptr<scalar> merge_lists(lists_column_view const& col,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_EXPECTS(col.type().id() == type_id::LIST,
"input column of merge_lists must be a list column");
auto flatten_col = lists_column_view(col).get_sliced_child(stream);
auto flatten_col = col.get_sliced_child(stream);
return make_list_scalar(flatten_col, stream, mr);
}

Expand All @@ -79,15 +77,13 @@ std::unique_ptr<scalar> collect_set(column_view const& col,
return drop_duplicates(*ls, nulls_equal, nans_equal, stream, mr);
}

std::unique_ptr<scalar> merge_sets(column_view const& col,
std::unique_ptr<scalar> merge_sets(lists_column_view const& col,
null_equality nulls_equal,
nan_equality nans_equal,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_EXPECTS(col.type().id() == type_id::LIST,
"input column of merge_lists must be a list column");
auto flatten_col = lists_column_view(col).get_sliced_child(stream);
auto flatten_col = col.get_sliced_child(stream);
auto scalar = std::make_unique<list_scalar>(flatten_col, true, stream, mr);
return drop_duplicates(*scalar, nulls_equal, nans_equal, stream, mr);
}
Expand Down

0 comments on commit aea1704

Please sign in to comment.