Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support exclude null_policy for collect list/set in groupby #8044

Merged
merged 9 commits into from
May 13, 2021
20 changes: 12 additions & 8 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,15 @@ void aggregate_result_functor::operator()<aggregation::COLLECT_LIST>(aggregation
auto null_handling =
dynamic_cast<cudf::detail::collect_list_aggregation const&>(agg)._null_handling;
agg.do_hash();
CUDF_EXPECTS(null_handling == null_policy::INCLUDE,
"null exclusion is not supported on groupby COLLECT_LIST aggregation.");

if (cache.has_result(col_idx, agg)) return;

auto result = detail::group_collect(
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr);
auto result = detail::group_collect(get_grouped_values(),
helper.group_offsets(stream),
helper.num_groups(stream),
null_handling,
stream,
mr);

cache.add_result(col_idx, agg, std::move(result));
};
Expand All @@ -385,13 +387,15 @@ void aggregate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation
{
auto const null_handling =
dynamic_cast<cudf::detail::collect_set_aggregation const&>(agg)._null_handling;
CUDF_EXPECTS(null_handling == null_policy::INCLUDE,
"null exclusion is not supported on groupby COLLECT_SET aggregation.");

if (cache.has_result(col_idx, agg)) { return; }

auto const collect_result = detail::group_collect(
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr);
auto const collect_result = detail::group_collect(get_grouped_values(),
helper.group_offsets(stream),
helper.num_groups(stream),
null_handling,
stream,
mr);
auto const nulls_equal =
dynamic_cast<cudf::detail::collect_set_aggregation const&>(agg)._nulls_equal;
auto const nans_equal =
Expand Down
82 changes: 75 additions & 7 deletions cpp/src/groupby/sort/group_collect.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,30 +17,98 @@
#include <cudf/column/column_factories.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/copy_if.cuh>
#include <cudf/detail/gather.cuh>
#include <cudf/strings/detail/utilities.cuh>
#include <cudf/types.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <memory>

namespace cudf {
namespace groupby {
namespace detail {
/**
* @brief Purge null entries in grouped values, and adjust group offsets.
*/
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved
std::pair<std::unique_ptr<column>, std::unique_ptr<column>> purge_null_entries(
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
column_view const &values,
column_view const &offsets,
size_type num_groups,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource *mr)
{
auto values_device_view = column_device_view::create(values, stream);

auto not_null_pred = [d_value = *values_device_view] __device__(auto i) {
return d_value.is_valid_nocheck(i);
};

// Purge null entries in grouped values.
auto null_purged_entries =
cudf::detail::copy_if(table_view{{values}}, not_null_pred, stream, mr)->release();

std::unique_ptr<column> &null_purged_values = null_purged_entries[0];
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved

// Recalculate offsets after null entries are purged.
auto null_purged_sizes = make_numeric_column(
data_type{type_to_id<size_type>()}, num_groups, mask_state::UNALLOCATED, stream, mr);
mythrocks marked this conversation as resolved.
Show resolved Hide resolved

thrust::transform(
rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(num_groups),
null_purged_sizes->mutable_view().template begin<size_type>(),
[d_offsets = offsets.template begin<size_type>(), not_null_pred] __device__(auto i) {
return thrust::count_if(thrust::seq,
thrust::make_counting_iterator<size_type>(d_offsets[i]),
thrust::make_counting_iterator<size_type>(d_offsets[i + 1]),
not_null_pred);
});

auto null_purged_offsets = strings::detail::make_offsets_child_column(
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
null_purged_sizes->view().template begin<size_type>(),
null_purged_sizes->view().template end<size_type>(),
stream,
mr);

return std::make_pair<std::unique_ptr<column>, std::unique_ptr<column>>(
std::move(null_purged_values), std::move(null_purged_offsets));
}

std::unique_ptr<column> group_collect(column_view const &values,
cudf::device_span<size_type const> group_offsets,
size_type num_groups,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource *mr)
{
rmm::device_buffer offsets_data(
group_offsets.data(), group_offsets.size() * sizeof(cudf::size_type), stream, mr);
auto [child_column,
offsets_column] = [null_handling, num_groups, &values, &group_offsets, stream, mr] {
auto offsets_column = make_numeric_column(
data_type(type_to_id<size_type>()), num_groups + 1, mask_state::UNALLOCATED, stream, mr);
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved

thrust::copy(rmm::exec_policy(stream),
group_offsets.begin(),
group_offsets.end(),
offsets_column->mutable_view().template begin<size_type>());

auto offsets = std::make_unique<cudf::column>(
cudf::data_type(cudf::type_to_id<cudf::size_type>()), num_groups + 1, std::move(offsets_data));
// If column of grouped values contains null elements, and null_policy == EXCLUDE,
// those elements must be filtered out, and offsets recomputed.
if (null_handling == null_policy::EXCLUDE && values.has_nulls()) {
return cudf::groupby::detail::purge_null_entries(
values, offsets_column->view(), num_groups, stream, mr);
} else {
return std::make_pair(std::make_unique<cudf::column>(values, stream, mr),
std::move(offsets_column));
}
}();

return make_lists_column(num_groups,
std::move(offsets),
std::make_unique<cudf::column>(values, stream, mr),
std::move(offsets_column),
std::move(child_column),
0,
rmm::device_buffer{0, stream, mr},
stream,
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/groupby/sort/group_reductions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,15 @@ std::unique_ptr<column> group_nth_element(column_view const& values,
* @param values Grouped values to collect
* @param group_offsets Offsets of groups' starting points within @p values
* @param num_groups Number of groups
* @param null_handling Exclude nulls while counting if null_policy::EXCLUDE,
* Include nulls if null_policy::INCLUDE.
* @param mr Device memory resource used to allocate the returned column's device memory
* @param stream CUDA stream used for device memory operations and kernel launches.
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved
*/
std::unique_ptr<column> group_collect(column_view const& values,
cudf::device_span<size_type const> group_offsets,
size_type num_groups,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);

Expand Down
59 changes: 40 additions & 19 deletions cpp/tests/groupby/collect_list_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,24 @@ TYPED_TEST(groupby_collect_list_test, CollectWithNulls)
test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg));
}

TYPED_TEST(groupby_collect_list_test, CollectWithNullExclusion)
{
using K = int32_t;
using V = TypeParam;

fixed_width_column_wrapper<K, int32_t> keys{1, 1, 1, 2, 2, 3, 3, 4, 4};

fixed_width_column_wrapper<V, int32_t> values{
{1, 2, 3, 4, 5, 6, 7, 8, 9}, {false, true, false, true, false, false, false, true, true}};

fixed_width_column_wrapper<K, int32_t> expect_keys{1, 2, 3, 4};

lists_column_wrapper<V, int32_t> expect_vals{{2}, {4}, {}, {8, 9}};

auto agg = cudf::make_collect_list_aggregation(null_policy::EXCLUDE);
test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg));
}

TYPED_TEST(groupby_collect_list_test, CollectLists)
{
using K = int32_t;
Expand All @@ -87,6 +105,28 @@ TYPED_TEST(groupby_collect_list_test, CollectLists)
test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg));
}

TYPED_TEST(groupby_collect_list_test, CollectListsWithNullExclusion)
{
using K = int32_t;
using V = TypeParam;

using LCW = cudf::test::lists_column_wrapper<TypeParam, int32_t>;

fixed_width_column_wrapper<K, int32_t> keys{1, 1, 2, 2, 3, 3, 4, 4};
const bool validity_mask[8] = {true, false, false, true, true, true, false, false};
auto validity = cudf::detail::make_counting_transform_iterator(
0, [&validity_mask](auto i) { return validity_mask[i]; });
lists_column_wrapper<V, int32_t> values{
{{1, 2}, {3, 4}, {5, 6, 7}, LCW{}, {9, 10}, {11}, {20, 30, 40}, LCW{}}, validity};
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved

fixed_width_column_wrapper<K, int32_t> expect_keys{1, 2, 3, 4};

lists_column_wrapper<V, int32_t> expect_vals{{{1, 2}}, {LCW{}}, {{9, 10}, {11}}, {}};
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved

auto agg = cudf::make_collect_list_aggregation(null_policy::EXCLUDE);
test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg));
}

TYPED_TEST(groupby_collect_list_test, dictionary)
{
using K = int32_t;
Expand All @@ -109,24 +149,5 @@ TYPED_TEST(groupby_collect_list_test, dictionary)
keys, vals, expect_keys, expect_vals->view(), cudf::make_collect_list_aggregation());
}

TYPED_TEST(groupby_collect_list_test, CollectFailsWithNullExclusion)
{
using K = int32_t;
using V = TypeParam;

fixed_width_column_wrapper<K, int32_t> keys{1, 1, 2, 2, 3, 3};
groupby::groupby gby{table_view{{keys}}};

fixed_width_column_wrapper<V, int32_t> values{{1, 2, 3, 4, 5, 6},
{true, false, true, false, true, false}};

std::vector<groupby::aggregation_request> agg_requests(1);
agg_requests[0].values = values;
agg_requests[0].aggregations.push_back(cudf::make_collect_list_aggregation(null_policy::EXCLUDE));

CUDF_EXPECT_THROW_MESSAGE(gby.aggregate(agg_requests),
"null exclusion is not supported on groupby COLLECT_LIST aggregation.");
}

} // namespace test
} // namespace cudf
59 changes: 34 additions & 25 deletions cpp/tests/groupby/collect_set_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,19 @@ namespace test {
#define LCL_V cudf::test::lists_column_wrapper<TypeParam, int32_t>
#define LCL_S cudf::test::lists_column_wrapper<cudf::string_view>
#define VALIDITY std::initializer_list<bool>
#define COLLECT_SET cudf::make_collect_set_aggregation()
#define COLLECT_SET_NULL_UNEQUAL \
cudf::make_collect_set_aggregation(null_policy::INCLUDE, null_equality::UNEQUAL)

struct CollectSetTest : public cudf::test::BaseFixture {
static auto collect_set() { return cudf::make_collect_set_aggregation(); }

static auto collect_set_null_unequal()
{
return cudf::make_collect_set_aggregation(null_policy::INCLUDE, null_equality::UNEQUAL);
}

static auto collect_set_null_exclude()
{
return cudf::make_collect_set_aggregation(null_policy::EXCLUDE);
}
};

template <typename V>
Expand All @@ -47,17 +55,6 @@ using FixedWidthTypesNotBool = cudf::test::Concat<cudf::test::IntegralTypesNotBo
cudf::test::TimestampTypes>;
TYPED_TEST_CASE(CollectSetTypedTest, FixedWidthTypesNotBool);

TYPED_TEST(CollectSetTypedTest, ExceptionTests)
{
std::vector<groupby::aggregation_request> agg_requests(1);
agg_requests[0].values = COL_V{{1, 2, 3, 4, 5, 6}, {true, false, true, false, true, false}};
agg_requests[0].aggregations.push_back(cudf::make_collect_list_aggregation(null_policy::EXCLUDE));

// groupby cannot exclude nulls
groupby::groupby gby{table_view{{COL_K{1, 1, 2, 2, 3, 3}}}};
EXPECT_THROW(gby.aggregate(agg_requests), cudf::logic_error);
}

TYPED_TEST(CollectSetTypedTest, TrivialInput)
{
// Empty input
Expand All @@ -70,7 +67,7 @@ TYPED_TEST(CollectSetTypedTest, TrivialInput)
COL_V vals{10};
COL_K keys_expected{1};
LCL_V vals_expected{LCL_V{10}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}

// Non-repeated keys
Expand All @@ -79,7 +76,7 @@ TYPED_TEST(CollectSetTypedTest, TrivialInput)
COL_V vals{20, 10};
COL_K keys_expected{1, 2};
LCL_V vals_expected{LCL_V{10}, LCL_V{20}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}
}

Expand All @@ -91,7 +88,7 @@ TYPED_TEST(CollectSetTypedTest, TypicalInput)
COL_V vals{10, 11, 10, 10, 20, 21, 21, 20, 30, 33, 32, 31};
COL_K keys_expected{1, 2, 3};
LCL_V vals_expected{{10, 11}, {20, 21}, {30, 31, 32, 33}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}

// Expect the result keys to be sorted by sort-based groupby
Expand All @@ -100,7 +97,7 @@ TYPED_TEST(CollectSetTypedTest, TypicalInput)
COL_V vals{40, 10, 20, 40, 30, 30, 20, 11};
COL_K keys_expected{1, 2, 3, 4};
LCL_V vals_expected{{10, 11}, {20}, {30}, {40}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}
}

Expand All @@ -114,14 +111,14 @@ TYPED_TEST(CollectSetTypedTest, SlicedColumnsInput)
auto const vals = cudf::slice(vals_original, {0, 4})[0]; // { 10, 11, 10, 10 }
auto const keys_expected = COL_K{1};
auto const vals_expected = LCL_V{{10, 11}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}
{
auto const keys = cudf::slice(keys_original, {2, 10})[0]; // { 1, 1, 2, 2, 2, 2, 3, 3 }
auto const vals = cudf::slice(vals_original, {2, 10})[0]; // { 10, 10, 20, 21, 21, 20, 30, 33 }
auto const keys_expected = COL_K{1, 2, 3};
auto const vals_expected = LCL_V{{10}, {20, 21}, {30, 33}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}
}

Expand All @@ -147,7 +144,7 @@ TEST_F(CollectSetTest, StringInput)
LCL_S vals_expected{{"String 1, first", "String 1, second"},
{"String 2, first", "String 2, second"},
{"String 3, first", "String 3, second"}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}

TYPED_TEST(CollectSetTypedTest, CollectWithNulls)
Expand All @@ -167,13 +164,19 @@ TYPED_TEST(CollectSetTypedTest, CollectWithNulls)
LCL_V vals_expected{{{10, null}, VALIDITY{true, false}},
{{20, null}, VALIDITY{true, false}},
{{30, 31}, VALIDITY{true, true}}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());

// All nulls per key are kept (nulls are put at the end of each list)
vals_expected = LCL_V{{{10, null, null}, VALIDITY{true, false, false}},
{{20, null, null, null}, VALIDITY{true, false, false, false}},
{{30, 31}, VALIDITY{true, true}}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET_NULL_UNEQUAL);
test_single_agg(
keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_unequal());

// All nulls per key are excluded
vals_expected = LCL_V{{10}, {20}, {30, 31}};
test_single_agg(
keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_exclude());
}

// Expect the result keys to be sorted by sort-based groupby
Expand All @@ -188,14 +191,20 @@ TYPED_TEST(CollectSetTypedTest, CollectWithNulls)
{{20, 21}, VALIDITY{true, true}},
{{null}, VALIDITY{false}},
{{40}, VALIDITY{true}}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());

// All nulls per key are kept (nulls are put at the end of each list)
vals_expected = LCL_V{{{10, null}, VALIDITY{true, false}},
{{20, 21}, VALIDITY{true, true}},
{{null, null, null, null}, VALIDITY{false, false, false, false}},
{{40}, VALIDITY{true}}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET_NULL_UNEQUAL);
test_single_agg(
keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_unequal());

// All nulls per key are excluded
vals_expected = LCL_V{{10}, {20, 21}, {}, {40}};
test_single_agg(
keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_exclude());
}
}

Expand Down