Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Scatter struct_scalar #8630

Merged
merged 11 commits into from
Jul 9, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions cpp/include/cudf/detail/copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,10 @@ std::unique_ptr<table> sample(
*
* @param[in] stream CUDA stream used for device memory operations and kernel launches.
*/
std::unique_ptr<scalar> get_element(column_view const& input,
size_type index,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);
std::unique_ptr<scalar> get_element(
column_view const& input,
size_type index,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
} // namespace detail
} // namespace cudf
20 changes: 20 additions & 0 deletions cpp/include/cudf/detail/null_mask.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,26 @@ void set_null_mask(bitmask_type *bitmask,
bool valid,
rmm::cuda_stream_view stream = rmm::cuda_stream_default);

/**
* @copydoc cudf::count_set_bits
*
* @param[in] stream CUDA stream used for device memory operations and kernel launches.
*/
cudf::size_type count_set_bits(bitmask_type const *bitmask,
size_type start,
size_type stop,
rmm::cuda_stream_view stream);

/**
* @copydoc cudf::count_unset_bits
*
* @param[in] stream CUDA stream used for device memory operations and kernel launches.
*/
cudf::size_type count_unset_bits(bitmask_type const *bitmask,
size_type start,
size_type stop,
rmm::cuda_stream_view stream);

/**
* @copydoc cudf::segmented_count_set_bits
*
Expand Down
126 changes: 88 additions & 38 deletions cpp/src/copying/scatter.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cudf/detail/utilities/cuda.cuh>
#include <cudf/dictionary/detail/search.hpp>
#include <cudf/lists/list_view.cuh>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/detail/scatter.cuh>
#include <cudf/strings/string_view.cuh>
#include <cudf/structs/struct_view.hpp>
Expand Down Expand Up @@ -63,32 +64,32 @@ __global__ void marking_bitmask_kernel(mutable_column_device_view destination,
}

template <typename MapIterator>
void scatter_scalar_bitmask(std::vector<std::reference_wrapper<const scalar>> const& source,
MapIterator scatter_map,
size_type num_scatter_rows,
std::vector<std::unique_ptr<column>>& target,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
void scatter_scalar_bitmask_inplace(std::reference_wrapper<const scalar> const& source,
MapIterator scatter_map,
size_type num_scatter_rows,
column& target,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
constexpr size_type block_size = 256;
size_type const grid_size = grid_1d(num_scatter_rows, block_size).num_blocks;

for (size_t i = 0; i < target.size(); ++i) {
auto const source_is_valid = source[i].get().is_valid(stream);
if (target[i]->nullable() or not source_is_valid) {
if (not target[i]->nullable()) {
// Target must have a null mask if the source is not valid
auto mask = detail::create_null_mask(target[i]->size(), mask_state::ALL_VALID, stream, mr);
target[i]->set_null_mask(std::move(mask), 0);
}

auto target_view = mutable_column_device_view::create(target[i]->mutable_view(), stream);

auto bitmask_kernel = source_is_valid ? marking_bitmask_kernel<true, decltype(scatter_map)>
: marking_bitmask_kernel<false, decltype(scatter_map)>;
bitmask_kernel<<<grid_size, block_size, 0, stream.value()>>>(
*target_view, scatter_map, num_scatter_rows);
auto const source_is_valid = source.get().is_valid(stream);
if (target.nullable() or not source_is_valid) {
if (not target.nullable()) {
// Target must have a null mask if the source is not valid
auto mask = detail::create_null_mask(target.size(), mask_state::ALL_VALID, stream, mr);
target.set_null_mask(std::move(mask), 0);
}

auto target_view = mutable_column_device_view::create(target, stream);

auto bitmask_kernel = source_is_valid ? marking_bitmask_kernel<true, decltype(scatter_map)>
: marking_bitmask_kernel<false, decltype(scatter_map)>;
bitmask_kernel<<<grid_size, block_size, 0, stream.value()>>>(
*target_view, scatter_map, num_scatter_rows);

target.set_null_count(count_unset_bits(target.view().null_mask(), 0, target.size(), stream));
}
}

Expand All @@ -103,6 +104,7 @@ struct column_scalar_scatterer_impl {
{
CUDF_EXPECTS(source.get().type() == target.type(), "scalar and column types must match");

// make a copy of data and nullmask from source
isVoid marked this conversation as resolved.
Show resolved Hide resolved
auto result = std::make_unique<column>(target, stream, mr);
auto result_view = result->mutable_view();

Expand All @@ -117,6 +119,7 @@ struct column_scalar_scatterer_impl {
scatter_iter,
result_view.begin<Element>());

scatter_scalar_bitmask_inplace(source, scatter_iter, scatter_rows, *result, stream, mr);
return result;
}
};
Expand All @@ -136,7 +139,10 @@ struct column_scalar_scatterer_impl<string_view, MapIterator> {
auto const source_view = string_view(scalar_impl->data(), scalar_impl->size());
auto const begin = thrust::make_constant_iterator(source_view);
auto const end = begin + scatter_rows;
return strings::detail::scatter(begin, end, scatter_iter, target, stream, mr);
auto result = strings::detail::scatter(begin, end, scatter_iter, target, stream, mr);

scatter_scalar_bitmask_inplace(source, scatter_iter, scatter_rows, *result, stream, mr);
return result;
}
};

Expand All @@ -149,17 +155,11 @@ struct column_scalar_scatterer_impl<list_view, MapIterator> {
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr) const
{
return lists::detail::scatter(
source, scatter_iter, scatter_iter + scatter_rows, target, stream, mr);
}
};
auto result =
lists::detail::scatter(source, scatter_iter, scatter_iter + scatter_rows, target, stream, mr);

template <typename MapIterator>
struct column_scalar_scatterer_impl<struct_view, MapIterator> {
template <typename... Args>
std::unique_ptr<column> operator()(Args&&...) const
{
CUDF_FAIL("scatter scalar to struct_view not implemented");
scatter_scalar_bitmask_inplace(source, scatter_iter, scatter_rows, *result, stream, mr);
return result;
}
};

Expand Down Expand Up @@ -200,10 +200,13 @@ struct column_scalar_scatterer_impl<dictionary32, MapIterator> {
// use the keys from the matched column
std::unique_ptr<column> keys_column(std::move(dict_target->release().children.back()));
// create the output column
return make_dictionary_column(std::move(keys_column),
std::move(indices_column),
std::move(*(contents.null_mask.release())),
null_count);
auto result = make_dictionary_column(std::move(keys_column),
std::move(indices_column),
std::move(*(contents.null_mask.release())),
null_count);

scatter_scalar_bitmask_inplace(source, scatter_iter, scatter_rows, *result, stream, mr);
return result;
}
};

Expand All @@ -222,6 +225,55 @@ struct column_scalar_scatterer {
}
};

template <typename MapIterator>
struct column_scalar_scatterer_impl<struct_view, MapIterator> {
std::unique_ptr<column> operator()(std::reference_wrapper<const scalar> const& source,
MapIterator scatter_iter,
size_type scatter_rows,
column_view const& target,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr) const
{
// For each field of `source`, copy construct a scalar from the field
// and dispatch the correct scalar scatterer

auto typed_s = static_cast<struct_scalar const*>(&source.get());
size_type n_fields = typed_s->view().num_columns();
isVoid marked this conversation as resolved.
Show resolved Hide resolved
CUDF_EXPECTS(n_fields == target.num_children(), "Mismatched number of fields.");

auto scatter_functor = column_scalar_scatterer<decltype(scatter_iter)>{};
std::vector<std::unique_ptr<column>> fields;
std::transform(thrust::make_counting_iterator(0),
thrust::make_counting_iterator(n_fields),
std::back_inserter(fields),
[&](auto const& i) {
auto row_slr = get_element(typed_s->view().column(i), 0, stream);
isVoid marked this conversation as resolved.
Show resolved Hide resolved
return type_dispatcher<dispatch_storage_type>(row_slr->type(),
scatter_functor,
*row_slr,
scatter_iter,
scatter_rows,
target.child(i),
stream,
mr);
});

// Compute nullmask
rmm::device_buffer null_mask =
target.nullable() ? copy_bitmask(target, stream, mr)
: create_null_mask(target.size(), mask_state::UNALLOCATED, stream, mr);
column null_mask_stub(data_type{type_id::STRUCT}, target.size(), rmm::device_buffer{});
null_mask_stub.set_null_mask(std::move(null_mask), target.null_count());
scatter_scalar_bitmask_inplace(source, scatter_iter, scatter_rows, null_mask_stub, stream, mr);
size_type null_count = null_mask_stub.null_count();
auto contents = null_mask_stub.release();

// Nullmask pushdown inside factory method
return make_structs_column(
target.size(), std::move(fields), null_count, std::move(*contents.null_mask));
}
};

} // namespace

std::unique_ptr<table> scatter(table_view const& source,
Expand Down Expand Up @@ -305,8 +357,6 @@ std::unique_ptr<table> scatter(std::vector<std::reference_wrapper<const scalar>>
mr);
});

scatter_scalar_bitmask(source, scatter_iter, scatter_rows, result, stream, mr);

return std::make_unique<table>(std::move(result));
}

Expand Down
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ ConfigureTest(COPYING_TEST
copying/scatter_list_tests.cpp
copying/scatter_list_scalar_tests.cpp
copying/scatter_struct_tests.cpp
copying/scatter_struct_scalar_tests.cpp
copying/segmented_gather_list_tests.cpp
copying/shift_tests.cpp
copying/slice_tests.cpp
Expand Down
Loading