Skip to content

Commit

Permalink
Add support for scatter() on lists-of-struct columns (#6817)
Browse files Browse the repository at this point in the history
Addendum to #6768, to support scatter on columns of type list<struct>. The hard part was the tests.
  • Loading branch information
mythrocks authored Nov 26, 2020
1 parent f3b0e06 commit c34d9bf
Show file tree
Hide file tree
Showing 3 changed files with 575 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
- PR #6768 Add support for scatter() on list columns
- PR #6796 Add create_metadata_file in dask_cudf
- PR #6765 Cupy fallback for __array_function__ and __array_ufunc__ for cudf.Series
- PR #6817 Add support for scatter() on lists-of-struct columns
- PR #6805 Implement `cudf::detail::copy_if` for `decimal32` and `decimal64`

## Improvements

Expand Down
97 changes: 96 additions & 1 deletion cpp/include/cudf/lists/detail/scatter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,8 @@ struct list_child_constructor {
template <typename T>
struct is_supported_child_type {
static const bool value = cudf::is_fixed_width<T>() || std::is_same<T, string_view>::value ||
std::is_same<T, list_view>::value;
std::is_same<T, list_view>::value ||
std::is_same<T, struct_view>::value;
};

public:
Expand Down Expand Up @@ -617,6 +618,100 @@ struct list_child_constructor {
stream.value(),
mr);
}

/**
 * @brief (Recursively) constructs child columns that are structs.
 *
 * For every member column of the struct child, the member is temporarily wrapped as the child
 * of a list column that reuses the parent list's size, offsets, null mask and null count. The
 * member's element type is then dispatched back into `list_child_constructor`, which builds
 * the corresponding member of the result. The constructed members are finally assembled into
 * one struct column, with a null mask computed by `construct_child_nullmask` when either the
 * source or the target struct child is nullable.
 *
 * @param list_vector Device vector of `unbound_list_view`s; forwarded unchanged to the
 *                    per-member dispatch and to `construct_child_nullmask`.
 * @param list_offsets Offsets column; used to compute the number of child rows and forwarded
 *                     unchanged to the per-member dispatch.
 * @param source_lists_column_view The source list<struct> column.
 * @param target_lists_column_view The target list<struct> column.
 * @param stream CUDA stream passed to every column construction and copy issued here.
 * @param mr Device memory resource passed to every column construction and copy issued here.
 * @return A struct column with `get_num_child_rows(list_offsets, stream)` rows.
 */
template <typename T>
std::enable_if_t<std::is_same<T, struct_view>::value, std::unique_ptr<column>> operator()(
  rmm::device_uvector<unbound_list_view> const& list_vector,
  cudf::column_view const& list_offsets,
  cudf::lists_column_view const& source_lists_column_view,
  cudf::lists_column_view const& target_lists_column_view,
  rmm::cuda_stream_view stream,
  rmm::mr::device_memory_resource* mr) const
{
  auto const source_column_device_view =
    column_device_view::create(source_lists_column_view.parent(), stream);
  auto const target_column_device_view =
    column_device_view::create(target_lists_column_view.parent(), stream);
  auto const source_lists = cudf::detail::lists_column_device_view(*source_column_device_view);
  auto const target_lists = cudf::detail::lists_column_device_view(*target_column_device_view);

  auto const source_structs = source_lists_column_view.child();
  auto const target_structs = target_lists_column_view.child();

  auto const num_child_rows = get_num_child_rows(list_offsets, stream);

  auto const num_struct_members =
    std::distance(source_structs.child_begin(), source_structs.child_end());
  std::vector<std::unique_ptr<column>> child_columns;
  child_columns.reserve(num_struct_members);

  // Wraps one struct member as the child of a list column that shares the parent's offsets.
  // The null mask is taken by rvalue reference and moved into the new column: callers hand in
  // a freshly created bitmask, so copying it again would only add a second device allocation
  // and a copy that is given neither `stream` nor `mr`.
  auto project_member_as_list = [stream, mr](column_view const& structs_member,
                                             cudf::size_type const& structs_list_num_rows,
                                             column_view const& structs_list_offsets,
                                             rmm::device_buffer&& structs_list_nullmask,
                                             cudf::size_type const& structs_list_null_count) {
    return cudf::make_lists_column(structs_list_num_rows,
                                   std::make_unique<column>(structs_list_offsets, stream, mr),
                                   std::make_unique<column>(structs_member, stream, mr),
                                   structs_list_null_count,
                                   std::move(structs_list_nullmask),
                                   stream,
                                   mr);
  };

  // Lazily yields, for member index i, the source list column whose child is member i.
  auto const iter_source_member_as_list = thrust::make_transform_iterator(
    thrust::make_counting_iterator<cudf::size_type>(0), [&](auto child_idx) {
      return project_member_as_list(
        source_structs.child(child_idx),
        source_lists_column_view.size(),
        source_lists_column_view.offsets(),
        cudf::detail::copy_bitmask(source_lists_column_view.parent(), stream, mr),
        source_lists_column_view.null_count());
    });

  // Same projection for the target list column.
  auto const iter_target_member_as_list = thrust::make_transform_iterator(
    thrust::make_counting_iterator<cudf::size_type>(0), [&](auto child_idx) {
      return project_member_as_list(
        target_structs.child(child_idx),
        target_lists_column_view.size(),
        target_lists_column_view.offsets(),
        cudf::detail::copy_bitmask(target_lists_column_view.parent(), stream, mr),
        target_lists_column_view.null_count());
    });

  // Recurse per member: dispatch on the member's element type and collect the results.
  std::transform(
    iter_source_member_as_list,
    iter_source_member_as_list + num_struct_members,
    iter_target_member_as_list,
    std::back_inserter(child_columns),
    [&](auto source_struct_member_as_list, auto target_struct_member_as_list) {
      return cudf::type_dispatcher(
        source_struct_member_as_list->child(cudf::lists_column_view::child_column_index).type(),
        list_child_constructor{},
        list_vector,
        list_offsets,
        cudf::lists_column_view(source_struct_member_as_list->view()),
        cudf::lists_column_view(target_struct_member_as_list->view()),
        stream,
        mr);
    });

  // A struct-level null mask is only needed when either input struct child is nullable;
  // otherwise an empty mask with a null count of zero is used.
  auto child_null_mask =
    source_lists_column_view.child().nullable() || target_lists_column_view.child().nullable()
      ? construct_child_nullmask(
          list_vector, list_offsets, source_lists, target_lists, num_child_rows, stream, mr)
      : std::make_pair(rmm::device_buffer{}, 0);

  return cudf::make_structs_column(num_child_rows,
                                   std::move(child_columns),
                                   child_null_mask.second,
                                   std::move(child_null_mask.first),
                                   stream.value(),
                                   mr);
}
};

/**
Expand Down
Loading

0 comments on commit c34d9bf

Please sign in to comment.