Skip to content

Commit

Permalink
some refinement
Browse files Browse the repository at this point in the history
Signed-off-by: sperlingxx <[email protected]>
  • Loading branch information
sperlingxx committed May 12, 2021
1 parent 5511580 commit 44fbfb8
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 18 deletions.
23 changes: 13 additions & 10 deletions cpp/src/groupby/sort/group_collect.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ namespace groupby {
namespace detail {
/**
* @brief Purge null entries in grouped values, and adjust group offsets.
*
* @param values Grouped values to be purged
* @param offsets Offsets of groups' starting points
* @param num_groups Number of groups
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory
* @return Pair of null-eliminated grouped values and their corresponding offsets
*/
std::pair<std::unique_ptr<column>, std::unique_ptr<column>> purge_null_entries(
column_view const &values,
Expand All @@ -50,17 +57,16 @@ std::pair<std::unique_ptr<column>, std::unique_ptr<column>> purge_null_entries(
auto null_purged_entries =
cudf::detail::copy_if(table_view{{values}}, not_null_pred, stream, mr)->release();

std::unique_ptr<column> &null_purged_values = null_purged_entries[0];
auto null_purged_values = std::move(null_purged_entries.front());

// Recalculate offsets after null entries are purged.
auto null_purged_sizes = make_numeric_column(
data_type{type_to_id<size_type>()}, num_groups, mask_state::UNALLOCATED, stream, mr);
rmm::device_uvector<size_type> null_purged_sizes(num_groups, stream);

thrust::transform(
rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(num_groups),
null_purged_sizes->mutable_view().template begin<size_type>(),
null_purged_sizes.begin(),
[d_offsets = offsets.template begin<size_type>(), not_null_pred] __device__(auto i) {
return thrust::count_if(thrust::seq,
thrust::make_counting_iterator<size_type>(d_offsets[i]),
Expand All @@ -69,10 +75,7 @@ std::pair<std::unique_ptr<column>, std::unique_ptr<column>> purge_null_entries(
});

auto null_purged_offsets = strings::detail::make_offsets_child_column(
null_purged_sizes->view().template begin<size_type>(),
null_purged_sizes->view().template end<size_type>(),
stream,
mr);
null_purged_sizes.cbegin(), null_purged_sizes.cend(), stream, mr);

return std::make_pair<std::unique_ptr<column>, std::unique_ptr<column>>(
std::move(null_purged_values), std::move(null_purged_offsets));
Expand All @@ -88,12 +91,12 @@ std::unique_ptr<column> group_collect(column_view const &values,
auto [child_column,
offsets_column] = [null_handling, num_groups, &values, &group_offsets, stream, mr] {
auto offsets_column = make_numeric_column(
data_type(type_to_id<size_type>()), num_groups + 1, mask_state::UNALLOCATED, stream, mr);
data_type(type_to_id<offset_type>()), num_groups + 1, mask_state::UNALLOCATED, stream, mr);

thrust::copy(rmm::exec_policy(stream),
group_offsets.begin(),
group_offsets.end(),
offsets_column->mutable_view().template begin<size_type>());
offsets_column->mutable_view().template begin<offset_type>());

// If column of grouped values contains null elements, and null_policy == EXCLUDE,
// those elements must be filtered out, and offsets recomputed.
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/groupby/sort/group_reductions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ std::unique_ptr<column> group_nth_element(column_view const& values,
* @param num_groups Number of groups
* @param null_handling Exclude nulls while counting if null_policy::EXCLUDE,
* Include nulls if null_policy::INCLUDE.
* @param mr Device memory resource used to allocate the returned column's device memory
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory
*/
std::unique_ptr<column> group_collect(column_view const& values,
cudf::device_span<size_type const> group_offsets,
Expand Down
11 changes: 4 additions & 7 deletions cpp/tests/groupby/collect_list_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,15 @@ TYPED_TEST(groupby_collect_list_test, CollectListsWithNullExclusion)
using K = int32_t;
using V = TypeParam;

using LCW = cudf::test::lists_column_wrapper<TypeParam, int32_t>;
using LCW = cudf::test::lists_column_wrapper<V, int32_t>;

fixed_width_column_wrapper<K, int32_t> keys{1, 1, 2, 2, 3, 3, 4, 4};
const bool validity_mask[8] = {true, false, false, true, true, true, false, false};
auto validity = cudf::detail::make_counting_transform_iterator(
0, [&validity_mask](auto i) { return validity_mask[i]; });
lists_column_wrapper<V, int32_t> values{
{{1, 2}, {3, 4}, {5, 6, 7}, LCW{}, {9, 10}, {11}, {20, 30, 40}, LCW{}}, validity};
const bool validity_mask[] = {true, false, false, true, true, true, false, false};
LCW values{{{1, 2}, {3, 4}, {5, 6, 7}, LCW{}, {9, 10}, {11}, {20, 30, 40}, LCW{}}, validity_mask};

fixed_width_column_wrapper<K, int32_t> expect_keys{1, 2, 3, 4};

lists_column_wrapper<V, int32_t> expect_vals{{{1, 2}}, {LCW{}}, {{9, 10}, {11}}, {}};
LCW expect_vals{{{1, 2}}, {LCW{}}, {{9, 10}, {11}}, {}};

auto agg = cudf::make_collect_list_aggregation(null_policy::EXCLUDE);
test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg));
Expand Down

0 comments on commit 44fbfb8

Please sign in to comment.