Skip to content

Commit

Permalink
Support exclude null_policy for collect list/set in groupby (#8044)
Browse files Browse the repository at this point in the history
This pull request adds support for `null_policy::EXCLUDE` to the `collect_list`/`collect_set` aggregations in the groupby context, as requested in #7777.

Authors:
  - Alfred Xu (https://github.com/sperlingxx)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - David Wendt (https://github.com/davidwendt)
  - MithunR (https://github.com/mythrocks)

URL: #8044
  • Loading branch information
sperlingxx authored May 13, 2021
1 parent 6b92e25 commit f8d7de4
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 60 deletions.
20 changes: 12 additions & 8 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,15 @@ void aggregate_result_functor::operator()<aggregation::COLLECT_LIST>(aggregation
auto null_handling =
dynamic_cast<cudf::detail::collect_list_aggregation const&>(agg)._null_handling;
agg.do_hash();
CUDF_EXPECTS(null_handling == null_policy::INCLUDE,
"null exclusion is not supported on groupby COLLECT_LIST aggregation.");

if (cache.has_result(col_idx, agg)) return;

auto result = detail::group_collect(
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr);
auto result = detail::group_collect(get_grouped_values(),
helper.group_offsets(stream),
helper.num_groups(stream),
null_handling,
stream,
mr);

cache.add_result(col_idx, agg, std::move(result));
};
Expand All @@ -385,13 +387,15 @@ void aggregate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation
{
auto const null_handling =
dynamic_cast<cudf::detail::collect_set_aggregation const&>(agg)._null_handling;
CUDF_EXPECTS(null_handling == null_policy::INCLUDE,
"null exclusion is not supported on groupby COLLECT_SET aggregation.");

if (cache.has_result(col_idx, agg)) { return; }

auto const collect_result = detail::group_collect(
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr);
auto const collect_result = detail::group_collect(get_grouped_values(),
helper.group_offsets(stream),
helper.num_groups(stream),
null_handling,
stream,
mr);
auto const nulls_equal =
dynamic_cast<cudf::detail::collect_set_aggregation const&>(agg)._nulls_equal;
auto const nans_equal =
Expand Down
85 changes: 78 additions & 7 deletions cpp/src/groupby/sort/group_collect.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,30 +17,101 @@
#include <cudf/column/column_factories.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/copy_if.cuh>
#include <cudf/detail/gather.cuh>
#include <cudf/strings/detail/utilities.cuh>
#include <cudf/types.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <memory>

namespace cudf {
namespace groupby {
namespace detail {
/**
 * @brief Purge null entries in grouped values, and adjust group offsets.
 *
 * Rows of @p values whose validity bit is unset are dropped, and a new offsets column is built
 * from the number of valid rows that remain in each group.
 *
 * @param values Grouped values to be purged
 * @param offsets Offsets of groups' starting points; entries `[0, num_groups]` are read
 * @param num_groups Number of groups
 * @param stream CUDA stream used for device memory operations and kernel launches.
 * @param mr Device memory resource used to allocate the returned column's device memory
 * @return Pair of null-eliminated grouped values and corresponding offsets
 */
std::pair<std::unique_ptr<column>, std::unique_ptr<column>> purge_null_entries(
  column_view const &values,
  column_view const &offsets,
  size_type num_groups,
  rmm::cuda_stream_view stream,
  rmm::mr::device_memory_resource *mr)
{
  auto values_device_view = column_device_view::create(values, stream);

  // `is_valid` is used instead of `is_valid_nocheck` so that this function does not depend on
  // the caller having verified that `values` owns a null mask: for a column without a mask
  // every row is reported as valid and nothing is purged.
  auto not_null_pred = [d_value = *values_device_view] __device__(auto i) {
    return d_value.is_valid(i);
  };

  // Purge null entries in grouped values.
  auto null_purged_entries =
    cudf::detail::copy_if(table_view{{values}}, not_null_pred, stream, mr)->release();

  auto null_purged_values = std::move(null_purged_entries.front());

  // Recalculate offsets after null entries are purged: first count the valid rows of each group.
  rmm::device_uvector<size_type> null_purged_sizes(num_groups, stream);

  thrust::transform(
    rmm::exec_policy(stream),
    thrust::make_counting_iterator<size_type>(0),
    thrust::make_counting_iterator<size_type>(num_groups),
    null_purged_sizes.begin(),
    [d_offsets = offsets.begin<size_type>(), not_null_pred] __device__(auto i) {
      // Group `i` spans rows [d_offsets[i], d_offsets[i + 1]); each functor invocation walks
      // its own group sequentially.
      return thrust::count_if(thrust::seq,
                              thrust::make_counting_iterator<size_type>(d_offsets[i]),
                              thrust::make_counting_iterator<size_type>(d_offsets[i + 1]),
                              not_null_pred);
    });

  // The per-group sizes are handed to the offsets factory to build the new offsets column.
  auto null_purged_offsets = strings::detail::make_offsets_child_column(
    null_purged_sizes.cbegin(), null_purged_sizes.cend(), stream, mr);

  // Let make_pair deduce its types; explicit template arguments defeat its purpose and only
  // compile when every argument is an rvalue.
  return std::make_pair(std::move(null_purged_values), std::move(null_purged_offsets));
}

std::unique_ptr<column> group_collect(column_view const &values,
cudf::device_span<size_type const> group_offsets,
size_type num_groups,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource *mr)
{
rmm::device_buffer offsets_data(
group_offsets.data(), group_offsets.size() * sizeof(cudf::size_type), stream, mr);
auto [child_column,
offsets_column] = [null_handling, num_groups, &values, &group_offsets, stream, mr] {
auto offsets_column = make_numeric_column(
data_type(type_to_id<offset_type>()), num_groups + 1, mask_state::UNALLOCATED, stream, mr);

thrust::copy(rmm::exec_policy(stream),
group_offsets.begin(),
group_offsets.end(),
offsets_column->mutable_view().template begin<offset_type>());

auto offsets = std::make_unique<cudf::column>(
cudf::data_type(cudf::type_to_id<cudf::size_type>()), num_groups + 1, std::move(offsets_data));
// If column of grouped values contains null elements, and null_policy == EXCLUDE,
// those elements must be filtered out, and offsets recomputed.
if (null_handling == null_policy::EXCLUDE && values.has_nulls()) {
return cudf::groupby::detail::purge_null_entries(
values, offsets_column->view(), num_groups, stream, mr);
} else {
return std::make_pair(std::make_unique<cudf::column>(values, stream, mr),
std::move(offsets_column));
}
}();

return make_lists_column(num_groups,
std::move(offsets),
std::make_unique<cudf::column>(values, stream, mr),
std::move(offsets_column),
std::move(child_column),
0,
rmm::device_buffer{0, stream, mr},
stream,
Expand Down
5 changes: 4 additions & 1 deletion cpp/src/groupby/sort/group_reductions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,15 @@ std::unique_ptr<column> group_nth_element(column_view const& values,
* @param values Grouped values to collect
* @param group_offsets Offsets of groups' starting points within @p values
* @param num_groups Number of groups
* @param mr Device memory resource used to allocate the returned column's device memory
* @param null_handling Exclude nulls from the collected lists if null_policy::EXCLUDE,
* include nulls if null_policy::INCLUDE.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory
*/
std::unique_ptr<column> group_collect(column_view const& values,
cudf::device_span<size_type const> group_offsets,
size_type num_groups,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);

Expand Down
56 changes: 37 additions & 19 deletions cpp/tests/groupby/collect_list_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,24 @@ TYPED_TEST(groupby_collect_list_test, CollectWithNulls)
test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg));
}

// collect_list with null_policy::EXCLUDE on a fixed-width column: null rows must not appear in
// the collected lists, and a group made up solely of nulls produces an empty list.
TYPED_TEST(groupby_collect_list_test, CollectWithNullExclusion)
{
  using KeyT   = int32_t;
  using ValueT = TypeParam;

  // Valid rows per key: key 1 -> {2}, key 2 -> {4}, key 3 -> none, key 4 -> {8, 9}.
  fixed_width_column_wrapper<ValueT, int32_t> values{
    {1, 2, 3, 4, 5, 6, 7, 8, 9}, {false, true, false, true, false, false, false, true, true}};
  fixed_width_column_wrapper<KeyT, int32_t> keys{1, 1, 1, 2, 2, 3, 3, 4, 4};

  lists_column_wrapper<ValueT, int32_t> expect_vals{{2}, {4}, {}, {8, 9}};
  fixed_width_column_wrapper<KeyT, int32_t> expect_keys{1, 2, 3, 4};

  test_single_agg(keys,
                  values,
                  expect_keys,
                  expect_vals,
                  cudf::make_collect_list_aggregation(null_policy::EXCLUDE));
}

TYPED_TEST(groupby_collect_list_test, CollectLists)
{
using K = int32_t;
Expand All @@ -87,6 +105,25 @@ TYPED_TEST(groupby_collect_list_test, CollectLists)
test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg));
}

// collect_list with null_policy::EXCLUDE on a lists column: null list rows are dropped, valid
// (possibly empty) list rows are kept, and a group with only null rows collects to an empty list.
TYPED_TEST(groupby_collect_list_test, CollectListsWithNullExclusion)
{
  using KeyT   = int32_t;
  using ValueT = TypeParam;
  using Lists  = cudf::test::lists_column_wrapper<ValueT, int32_t>;

  // One flag per input row; `false` marks the row as null.
  bool const row_validity[] = {true, false, false, true, true, true, false, false};

  fixed_width_column_wrapper<KeyT, int32_t> keys{1, 1, 2, 2, 3, 3, 4, 4};
  Lists values{{{1, 2}, {3, 4}, {5, 6, 7}, Lists{}, {9, 10}, {11}, {20, 30, 40}, Lists{}},
               row_validity};

  // Key 2 keeps its valid empty list; key 4 has no valid rows at all.
  Lists expect_vals{{{1, 2}}, {Lists{}}, {{9, 10}, {11}}, {}};
  fixed_width_column_wrapper<KeyT, int32_t> expect_keys{1, 2, 3, 4};

  test_single_agg(keys,
                  values,
                  expect_keys,
                  expect_vals,
                  cudf::make_collect_list_aggregation(null_policy::EXCLUDE));
}

TYPED_TEST(groupby_collect_list_test, dictionary)
{
using K = int32_t;
Expand All @@ -109,24 +146,5 @@ TYPED_TEST(groupby_collect_list_test, dictionary)
keys, vals, expect_keys, expect_vals->view(), cudf::make_collect_list_aggregation());
}

// Checks that a groupby collect_list request built with null_policy::EXCLUDE is rejected:
// gby.aggregate is expected to throw with the exact message asserted at the end of the test.
TYPED_TEST(groupby_collect_list_test, CollectFailsWithNullExclusion)
{
using K = int32_t;
using V = TypeParam;

fixed_width_column_wrapper<K, int32_t> keys{1, 1, 2, 2, 3, 3};
groupby::groupby gby{table_view{{keys}}};

// Every second value row is null.
fixed_width_column_wrapper<V, int32_t> values{{1, 2, 3, 4, 5, 6},
{true, false, true, false, true, false}};

// A single request carrying one collect_list aggregation that asks for null exclusion.
std::vector<groupby::aggregation_request> agg_requests(1);
agg_requests[0].values = values;
agg_requests[0].aggregations.push_back(cudf::make_collect_list_aggregation(null_policy::EXCLUDE));

CUDF_EXPECT_THROW_MESSAGE(gby.aggregate(agg_requests),
"null exclusion is not supported on groupby COLLECT_LIST aggregation.");
}

} // namespace test
} // namespace cudf
59 changes: 34 additions & 25 deletions cpp/tests/groupby/collect_set_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,19 @@ namespace test {
#define LCL_V cudf::test::lists_column_wrapper<TypeParam, int32_t>
#define LCL_S cudf::test::lists_column_wrapper<cudf::string_view>
#define VALIDITY std::initializer_list<bool>
#define COLLECT_SET cudf::make_collect_set_aggregation()
#define COLLECT_SET_NULL_UNEQUAL \
cudf::make_collect_set_aggregation(null_policy::INCLUDE, null_equality::UNEQUAL)

// Fixture for the collect_set groupby tests. Besides deriving from the common base fixture it
// offers factories for the collect_set aggregation variants used by the tests.
struct CollectSetTest : cudf::test::BaseFixture {
  // collect_set that drops null elements from the collected result.
  static auto collect_set_null_exclude()
  {
    auto constexpr nulls = null_policy::EXCLUDE;
    return cudf::make_collect_set_aggregation(nulls);
  }

  // collect_set that keeps null elements and compares them as unequal to each other.
  static auto collect_set_null_unequal()
  {
    auto constexpr nulls    = null_policy::INCLUDE;
    auto constexpr equality = null_equality::UNEQUAL;
    return cudf::make_collect_set_aggregation(nulls, equality);
  }

  // collect_set built with the factory's default arguments.
  static auto collect_set() { return cudf::make_collect_set_aggregation(); }
};

template <typename V>
Expand All @@ -47,17 +55,6 @@ using FixedWidthTypesNotBool = cudf::test::Concat<cudf::test::IntegralTypesNotBo
cudf::test::TimestampTypes>;
TYPED_TEST_CASE(CollectSetTypedTest, FixedWidthTypesNotBool);

// Expects gby.aggregate to throw cudf::logic_error when the requested aggregation was built
// with null_policy::EXCLUDE.
// NOTE(review): the request is created with make_collect_list_aggregation although this is the
// collect_set suite — confirm whether make_collect_set_aggregation was intended.
TYPED_TEST(CollectSetTypedTest, ExceptionTests)
{
// One request whose values column has every second row null.
std::vector<groupby::aggregation_request> agg_requests(1);
agg_requests[0].values = COL_V{{1, 2, 3, 4, 5, 6}, {true, false, true, false, true, false}};
agg_requests[0].aggregations.push_back(cudf::make_collect_list_aggregation(null_policy::EXCLUDE));

// groupby cannot exclude nulls
groupby::groupby gby{table_view{{COL_K{1, 1, 2, 2, 3, 3}}}};
EXPECT_THROW(gby.aggregate(agg_requests), cudf::logic_error);
}

TYPED_TEST(CollectSetTypedTest, TrivialInput)
{
// Empty input
Expand All @@ -70,7 +67,7 @@ TYPED_TEST(CollectSetTypedTest, TrivialInput)
COL_V vals{10};
COL_K keys_expected{1};
LCL_V vals_expected{LCL_V{10}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}

// Non-repeated keys
Expand All @@ -79,7 +76,7 @@ TYPED_TEST(CollectSetTypedTest, TrivialInput)
COL_V vals{20, 10};
COL_K keys_expected{1, 2};
LCL_V vals_expected{LCL_V{10}, LCL_V{20}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}
}

Expand All @@ -91,7 +88,7 @@ TYPED_TEST(CollectSetTypedTest, TypicalInput)
COL_V vals{10, 11, 10, 10, 20, 21, 21, 20, 30, 33, 32, 31};
COL_K keys_expected{1, 2, 3};
LCL_V vals_expected{{10, 11}, {20, 21}, {30, 31, 32, 33}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}

// Expect the result keys to be sorted by sort-based groupby
Expand All @@ -100,7 +97,7 @@ TYPED_TEST(CollectSetTypedTest, TypicalInput)
COL_V vals{40, 10, 20, 40, 30, 30, 20, 11};
COL_K keys_expected{1, 2, 3, 4};
LCL_V vals_expected{{10, 11}, {20}, {30}, {40}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}
}

Expand All @@ -114,14 +111,14 @@ TYPED_TEST(CollectSetTypedTest, SlicedColumnsInput)
auto const vals = cudf::slice(vals_original, {0, 4})[0]; // { 10, 11, 10, 10 }
auto const keys_expected = COL_K{1};
auto const vals_expected = LCL_V{{10, 11}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}
{
auto const keys = cudf::slice(keys_original, {2, 10})[0]; // { 1, 1, 2, 2, 2, 2, 3, 3 }
auto const vals = cudf::slice(vals_original, {2, 10})[0]; // { 10, 10, 20, 21, 21, 20, 30, 33 }
auto const keys_expected = COL_K{1, 2, 3};
auto const vals_expected = LCL_V{{10}, {20, 21}, {30, 33}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}
}

Expand All @@ -147,7 +144,7 @@ TEST_F(CollectSetTest, StringInput)
LCL_S vals_expected{{"String 1, first", "String 1, second"},
{"String 2, first", "String 2, second"},
{"String 3, first", "String 3, second"}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}

TYPED_TEST(CollectSetTypedTest, CollectWithNulls)
Expand All @@ -167,13 +164,19 @@ TYPED_TEST(CollectSetTypedTest, CollectWithNulls)
LCL_V vals_expected{{{10, null}, VALIDITY{true, false}},
{{20, null}, VALIDITY{true, false}},
{{30, 31}, VALIDITY{true, true}}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());

// All nulls per key are kept (nulls are put at the end of each list)
vals_expected = LCL_V{{{10, null, null}, VALIDITY{true, false, false}},
{{20, null, null, null}, VALIDITY{true, false, false, false}},
{{30, 31}, VALIDITY{true, true}}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET_NULL_UNEQUAL);
test_single_agg(
keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_unequal());

// All nulls per key are excluded
vals_expected = LCL_V{{10}, {20}, {30, 31}};
test_single_agg(
keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_exclude());
}

// Expect the result keys to be sorted by sort-based groupby
Expand All @@ -188,14 +191,20 @@ TYPED_TEST(CollectSetTypedTest, CollectWithNulls)
{{20, 21}, VALIDITY{true, true}},
{{null}, VALIDITY{false}},
{{40}, VALIDITY{true}}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());

// All nulls per key are kept (nulls are put at the end of each list)
vals_expected = LCL_V{{{10, null}, VALIDITY{true, false}},
{{20, 21}, VALIDITY{true, true}},
{{null, null, null, null}, VALIDITY{false, false, false, false}},
{{40}, VALIDITY{true}}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET_NULL_UNEQUAL);
test_single_agg(
keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_unequal());

// All nulls per key are excluded
vals_expected = LCL_V{{10}, {20, 21}, {}, {40}};
test_single_agg(
keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_exclude());
}
}

Expand Down

0 comments on commit f8d7de4

Please sign in to comment.