Skip to content

Commit

Permalink
Implement scatter for struct columns (#7752)
Browse files Browse the repository at this point in the history
This PR implements scattering for struct columns. It also fixes #7638 when calling `cudf::partitioning` on struct data due to wrongly scattering data inside the partitioning kernels.

Authors:
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)
  - Conor Hoekstra (https://github.com/codereport)
  - Keith Kraus (https://github.com/kkraus14)
  - Jake Hemstad (https://github.com/jrhemstad)

URL: #7752
  • Loading branch information
ttnghia authored Apr 1, 2021
1 parent 4f30fcd commit e7153bb
Show file tree
Hide file tree
Showing 5 changed files with 498 additions and 40 deletions.
37 changes: 23 additions & 14 deletions cpp/include/cudf/detail/gather.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -470,15 +470,20 @@ struct column_gatherer_impl<struct_view> {
mr);
});

gather_bitmask(
// Table view of struct column.
cudf::table_view{
std::vector<cudf::column_view>{structs_column.child_begin(), structs_column.child_end()}},
gather_map_begin,
output_struct_members,
nullify_out_of_bounds ? gather_bitmask_op::NULLIFY : gather_bitmask_op::DONT_CHECK,
stream,
mr);
auto const nullable = std::any_of(structs_column.child_begin(),
structs_column.child_end(),
[](auto const& col) { return col.nullable(); });
if (nullable) {
gather_bitmask(
// Table view of struct column.
cudf::table_view{
std::vector<cudf::column_view>{structs_column.child_begin(), structs_column.child_end()}},
gather_map_begin,
output_struct_members,
nullify_out_of_bounds ? gather_bitmask_op::NULLIFY : gather_bitmask_op::DONT_CHECK,
stream,
mr);
}

return cudf::make_structs_column(
gather_map_size,
Expand Down Expand Up @@ -656,11 +661,15 @@ std::unique_ptr<table> gather(
mr));
}

gather_bitmask_op const op = bounds_policy == out_of_bounds_policy::NULLIFY
? gather_bitmask_op::NULLIFY
: gather_bitmask_op::DONT_CHECK;

gather_bitmask(source_table, gather_map_begin, destination_columns, op, stream, mr);
auto const nullable = bounds_policy == out_of_bounds_policy::NULLIFY ||
std::any_of(source_table.begin(), source_table.end(), [](auto const& col) {
return col.nullable();
});
if (nullable) {
auto const op = bounds_policy == out_of_bounds_policy::NULLIFY ? gather_bitmask_op::NULLIFY
: gather_bitmask_op::DONT_CHECK;
gather_bitmask(source_table, gather_map_begin, destination_columns, op, stream, mr);
}

return std::make_unique<table>(std::move(destination_columns));
}
Expand Down
156 changes: 141 additions & 15 deletions cpp/include/cudf/detail/scatter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/dictionary/dictionary_factories.hpp>
#include <cudf/lists/detail/scatter.cuh>
#include <cudf/null_mask.hpp>
#include <cudf/strings/detail/scatter.cuh>
#include <cudf/strings/string_view.cuh>
#include <cudf/utilities/traits.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/uninitialized_fill.h>

namespace cudf {
namespace detail {

Expand All @@ -42,10 +45,9 @@ namespace detail {
* function using the PASSTHROUGH op since the resulting map may contain index
* values outside the target's range.
*
* First, the gather-map is initialized with invalid entries.
* The gather_rows is used since it should always be outside the target size.
*
* Then, the `output[scatter_map[i]] = i`.
* First, the gather-map is initialized with an invalid index.
* The value `numeric_limits::lowest()` is used since it should always be outside the target size.
* Then, `output[scatter_map[i]] = i` for each `i`.
*
* @tparam MapIterator Iterator type of the input scatter map.
* @param scatter_map_begin Beginning of scatter map.
Expand All @@ -62,11 +64,16 @@ auto scatter_to_gather(MapIterator scatter_map_begin,
{
using MapValueType = typename thrust::iterator_traits<MapIterator>::value_type;

// The gather_map is initialized with gather_rows value to identify pass-through entries
// when calling the gather_bitmask() which applies a pass-through whenever it finds a
// The gather_map is initialized with `numeric_limits::lowest()` value to identify pass-through
// entries when calling the gather_bitmask() which applies a pass-through whenever it finds a
// value outside the range of the target column.
// We'll use the gather_rows value for this since it should always be outside the valid range.
auto gather_map = rmm::device_vector<size_type>(gather_rows, gather_rows);
// We'll use the `numeric_limits::lowest()` value for this since it should always be outside the
// valid range.
auto gather_map = rmm::device_uvector<size_type>(gather_rows, stream);
thrust::uninitialized_fill(rmm::exec_policy(stream),
gather_map.begin(),
gather_map.end(),
std::numeric_limits<size_type>::lowest());

// Convert scatter map to a gather map
thrust::scatter(
Expand All @@ -79,6 +86,39 @@ auto scatter_to_gather(MapIterator scatter_map_begin,
return gather_map;
}

/**
* @brief Create a complement map of `scatter_to_gather` map
*
* The purpose of this map is to create an identity-mapping for the rows that are not
* touched by the `scatter_map`.
*
* The output result of this mapping is firstly initialized as an identity-mapping
* (i.e., `output[i] = i`). Then, for each value `idx` from `scatter_map`, the value `output[idx]`
* is set to `numeric_limits::lowest()`, which is an invalid, out-of-bound index to identify the
* pass-through entries when calling the `gather_bitmask()` function.
*
*/
template <typename MapIterator>
auto scatter_to_gather_complement(MapIterator scatter_map_begin,
MapIterator scatter_map_end,
size_type gather_rows,
rmm::cuda_stream_view stream)
{
auto gather_map = rmm::device_uvector<size_type>(gather_rows, stream);
thrust::sequence(rmm::exec_policy(stream), gather_map.begin(), gather_map.end(), 0);

auto const out_of_bounds_begin =
thrust::make_constant_iterator(std::numeric_limits<size_type>::lowest());
auto const out_of_bounds_end =
out_of_bounds_begin + thrust::distance(scatter_map_begin, scatter_map_end);
thrust::scatter(rmm::exec_policy(stream),
out_of_bounds_begin,
out_of_bounds_end,
scatter_map_begin,
gather_map.begin());
return gather_map;
}

template <typename Element, typename Enable = void>
struct column_scatterer_impl {
template <typename... Args>
Expand Down Expand Up @@ -218,6 +258,89 @@ struct column_scatterer {
}
};

template <>
struct column_scatterer_impl<struct_view> {
template <typename MapItRoot>
std::unique_ptr<column> operator()(column_view const& source,
MapItRoot scatter_map_begin,
MapItRoot scatter_map_end,
column_view const& target,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr) const
{
CUDF_EXPECTS(source.num_children() == target.num_children(),
"Scatter source and target are not of the same type.");

auto const scatter_map_size = std::distance(scatter_map_begin, scatter_map_end);
if (scatter_map_size == 0) { return std::make_unique<column>(target, stream, mr); }

structs_column_view const structs_src(source);
structs_column_view const structs_target(target);
std::vector<std::unique_ptr<column>> output_struct_members(structs_src.num_children());

std::transform(structs_src.child_begin(),
structs_src.child_end(),
structs_target.child_begin(),
output_struct_members.begin(),
[&scatter_map_begin, &scatter_map_end, stream, mr](auto const& source_col,
auto const& target_col) {
return type_dispatcher<dispatch_storage_type>(source_col.type(),
column_scatterer{},
source_col,
scatter_map_begin,
scatter_map_end,
target_col,
stream,
mr);
});

// We still need to call `gather_bitmask` even when the source's children are not nullable,
// as if the target's children have null_masks, those null_masks need to be updated after
// being scattered onto
auto const child_nullable = std::any_of(structs_src.child_begin(),
structs_src.child_end(),
[](auto const& col) { return col.nullable(); }) or
std::any_of(structs_target.child_begin(),
structs_target.child_end(),
[](auto const& col) { return col.nullable(); });
if (child_nullable) {
auto const gather_map =
scatter_to_gather(scatter_map_begin, scatter_map_end, source.size(), stream);
gather_bitmask(cudf::table_view{std::vector<cudf::column_view>{structs_src.child_begin(),
structs_src.child_end()}},
gather_map.begin(),
output_struct_members,
gather_bitmask_op::PASSTHROUGH,
stream,
mr);
}

// Need to put the result column in a vector to call `gather_bitmask`
std::vector<std::unique_ptr<column>> result;
result.emplace_back(cudf::make_structs_column(source.size(),
std::move(output_struct_members),
0,
rmm::device_buffer{0, stream, mr},
stream,
mr));

// Only gather bitmask from the target column for the rows that have not been scattered onto
// The bitmask from the source column will be gathered at the top level `scatter()` call
if (target.nullable()) {
auto const gather_map =
scatter_to_gather_complement(scatter_map_begin, scatter_map_end, target.size(), stream);
gather_bitmask(table_view{std::vector<cudf::column_view>{target}},
gather_map.begin(),
result,
gather_bitmask_op::PASSTHROUGH,
stream,
mr);
}

return std::move(result.front());
}
};

/**
* @brief Scatters the rows of the source table into a copy of the target table
* according to a scatter map.
Expand Down Expand Up @@ -282,10 +405,8 @@ std::unique_ptr<table> scatter(
// Transform negative indices to index + target size
auto updated_scatter_map_begin =
thrust::make_transform_iterator(scatter_map_begin, index_converter<MapType>{target.num_rows()});

auto updated_scatter_map_end =
thrust::make_transform_iterator(scatter_map_end, index_converter<MapType>{target.num_rows()});

auto result = std::vector<std::unique_ptr<column>>(target.num_columns());

std::transform(source.begin(),
Expand All @@ -303,11 +424,16 @@ std::unique_ptr<table> scatter(
mr);
});

auto gather_map = scatter_to_gather(
updated_scatter_map_begin, updated_scatter_map_end, target.num_rows(), stream);

gather_bitmask(source, gather_map.begin(), result, gather_bitmask_op::PASSTHROUGH, stream, mr);

// We still need to call `gather_bitmask` even when the source columns are not nullable,
// as if the target has null_mask, that null_mask needs to be updated after scattering
auto const nullable =
std::any_of(source.begin(), source.end(), [](auto const& col) { return col.nullable(); }) or
std::any_of(target.begin(), target.end(), [](auto const& col) { return col.nullable(); });
if (nullable) {
auto const gather_map = scatter_to_gather(
updated_scatter_map_begin, updated_scatter_map_end, target.num_rows(), stream);
gather_bitmask(source, gather_map.begin(), result, gather_bitmask_op::PASSTHROUGH, stream, mr);
}
return std::make_unique<table>(std::move(result));
}
} // namespace detail
Expand Down
23 changes: 12 additions & 11 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -201,24 +201,25 @@ ConfigureTest(SORT_TEST
###################################################################################################
# - copying tests ---------------------------------------------------------------------------------
ConfigureTest(COPYING_TEST
copying/utility_tests.cpp
copying/concatenate_tests.cu
copying/copy_range_tests.cpp
copying/copy_tests.cu
copying/detail_gather_tests.cu
copying/gather_struct_tests.cu
copying/gather_tests.cu
copying/gather_str_tests.cu
copying/gather_list_tests.cu
copying/segmented_gather_list_tests.cpp
copying/gather_struct_tests.cu
copying/detail_gather_tests.cu
copying/get_value_tests.cpp
copying/pack_tests.cu
copying/sample_tests.cpp
copying/scatter_tests.cpp
copying/scatter_list_tests.cu
copying/copy_range_tests.cpp
copying/scatter_struct_tests.cu
copying/segmented_gather_list_tests.cpp
copying/shift_tests.cpp
copying/slice_tests.cpp
copying/split_tests.cpp
copying/copy_tests.cu
copying/shift_tests.cpp
copying/get_value_tests.cpp
copying/sample_tests.cpp
copying/concatenate_tests.cu
copying/pack_tests.cu)
copying/utility_tests.cpp)

###################################################################################################
# - utilities tests -------------------------------------------------------------------------------
Expand Down
Loading

0 comments on commit e7153bb

Please sign in to comment.