Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support exclude null_policy for collect list/set in groupby #8044

Merged
merged 9 commits into from
May 13, 2021
20 changes: 12 additions & 8 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,15 @@ void aggregate_result_functor::operator()<aggregation::COLLECT_LIST>(aggregation
auto null_handling =
dynamic_cast<cudf::detail::collect_list_aggregation const&>(agg)._null_handling;
agg.do_hash();
CUDF_EXPECTS(null_handling == null_policy::INCLUDE,
"null exclusion is not supported on groupby COLLECT_LIST aggregation.");

if (cache.has_result(col_idx, agg)) return;

auto result = detail::group_collect(
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr);
auto result = detail::group_collect(get_grouped_values(),
helper.group_offsets(stream),
helper.num_groups(stream),
null_handling,
stream,
mr);

cache.add_result(col_idx, agg, std::move(result));
};
Expand All @@ -385,13 +387,15 @@ void aggregate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation
{
auto const null_handling =
dynamic_cast<cudf::detail::collect_set_aggregation const&>(agg)._null_handling;
CUDF_EXPECTS(null_handling == null_policy::INCLUDE,
"null exclusion is not supported on groupby COLLECT_SET aggregation.");

if (cache.has_result(col_idx, agg)) { return; }

auto const collect_result = detail::group_collect(
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr);
auto const collect_result = detail::group_collect(get_grouped_values(),
helper.group_offsets(stream),
helper.num_groups(stream),
null_handling,
stream,
mr);
auto const nulls_equal =
dynamic_cast<cudf::detail::collect_set_aggregation const&>(agg)._nulls_equal;
auto const nans_equal =
Expand Down
85 changes: 78 additions & 7 deletions cpp/src/groupby/sort/group_collect.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,30 +17,101 @@
#include <cudf/column/column_factories.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/copy_if.cuh>
#include <cudf/detail/gather.cuh>
#include <cudf/strings/detail/utilities.cuh>
#include <cudf/types.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <memory>

namespace cudf {
namespace groupby {
namespace detail {
/**
* @brief Purge null entries in grouped values, and adjust group offsets.
*
* @param values Grouped values to be purged
* @param offsets Offsets of groups' starting points
* @param num_groups Number of groups
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory
* @return Pair of null eliminated grouped values corresponding offsets
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved
*/
std::pair<std::unique_ptr<column>, std::unique_ptr<column>> purge_null_entries(
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
column_view const &values,
column_view const &offsets,
size_type num_groups,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource *mr)
{
auto values_device_view = column_device_view::create(values, stream);

auto not_null_pred = [d_value = *values_device_view] __device__(auto i) {
return d_value.is_valid_nocheck(i);
};

// Purge null entries in grouped values.
auto null_purged_entries =
cudf::detail::copy_if(table_view{{values}}, not_null_pred, stream, mr)->release();

auto null_purged_values = std::move(null_purged_entries.front());

// Recalculate offsets after null entries are purged.
rmm::device_uvector<size_type> null_purged_sizes(num_groups, stream);

thrust::transform(
rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(num_groups),
null_purged_sizes.begin(),
[d_offsets = offsets.template begin<size_type>(), not_null_pred] __device__(auto i) {
return thrust::count_if(thrust::seq,
thrust::make_counting_iterator<size_type>(d_offsets[i]),
thrust::make_counting_iterator<size_type>(d_offsets[i + 1]),
not_null_pred);
});

auto null_purged_offsets = strings::detail::make_offsets_child_column(
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
null_purged_sizes.cbegin(), null_purged_sizes.cend(), stream, mr);

return std::make_pair<std::unique_ptr<column>, std::unique_ptr<column>>(
std::move(null_purged_values), std::move(null_purged_offsets));
}

std::unique_ptr<column> group_collect(column_view const &values,
cudf::device_span<size_type const> group_offsets,
size_type num_groups,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource *mr)
{
rmm::device_buffer offsets_data(
group_offsets.data(), group_offsets.size() * sizeof(cudf::size_type), stream, mr);
auto [child_column,
offsets_column] = [null_handling, num_groups, &values, &group_offsets, stream, mr] {
auto offsets_column = make_numeric_column(
data_type(type_to_id<offset_type>()), num_groups + 1, mask_state::UNALLOCATED, stream, mr);

thrust::copy(rmm::exec_policy(stream),
group_offsets.begin(),
group_offsets.end(),
offsets_column->mutable_view().template begin<offset_type>());

auto offsets = std::make_unique<cudf::column>(
cudf::data_type(cudf::type_to_id<cudf::size_type>()), num_groups + 1, std::move(offsets_data));
// If column of grouped values contains null elements, and null_policy == EXCLUDE,
// those elements must be filtered out, and offsets recomputed.
if (null_handling == null_policy::EXCLUDE && values.has_nulls()) {
return cudf::groupby::detail::purge_null_entries(
values, offsets_column->view(), num_groups, stream, mr);
} else {
return std::make_pair(std::make_unique<cudf::column>(values, stream, mr),
std::move(offsets_column));
}
}();

return make_lists_column(num_groups,
std::move(offsets),
std::make_unique<cudf::column>(values, stream, mr),
std::move(offsets_column),
std::move(child_column),
0,
rmm::device_buffer{0, stream, mr},
stream,
Expand Down
5 changes: 4 additions & 1 deletion cpp/src/groupby/sort/group_reductions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,15 @@ std::unique_ptr<column> group_nth_element(column_view const& values,
* @param values Grouped values to collect
* @param group_offsets Offsets of groups' starting points within @p values
* @param num_groups Number of groups
* @param mr Device memory resource used to allocate the returned column's device memory
* @param null_handling Exclude nulls while counting if null_policy::EXCLUDE,
* Include nulls if null_policy::INCLUDE.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory
*/
std::unique_ptr<column> group_collect(column_view const& values,
cudf::device_span<size_type const> group_offsets,
size_type num_groups,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);

Expand Down
56 changes: 37 additions & 19 deletions cpp/tests/groupby/collect_list_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,24 @@ TYPED_TEST(groupby_collect_list_test, CollectWithNulls)
test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg));
}

TYPED_TEST(groupby_collect_list_test, CollectWithNullExclusion)
{
using K = int32_t;
using V = TypeParam;

fixed_width_column_wrapper<K, int32_t> keys{1, 1, 1, 2, 2, 3, 3, 4, 4};

fixed_width_column_wrapper<V, int32_t> values{
{1, 2, 3, 4, 5, 6, 7, 8, 9}, {false, true, false, true, false, false, false, true, true}};

fixed_width_column_wrapper<K, int32_t> expect_keys{1, 2, 3, 4};

lists_column_wrapper<V, int32_t> expect_vals{{2}, {4}, {}, {8, 9}};

auto agg = cudf::make_collect_list_aggregation(null_policy::EXCLUDE);
test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg));
}

TYPED_TEST(groupby_collect_list_test, CollectLists)
{
using K = int32_t;
Expand All @@ -87,6 +105,25 @@ TYPED_TEST(groupby_collect_list_test, CollectLists)
test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg));
}

TYPED_TEST(groupby_collect_list_test, CollectListsWithNullExclusion)
{
using K = int32_t;
using V = TypeParam;

using LCW = cudf::test::lists_column_wrapper<V, int32_t>;

fixed_width_column_wrapper<K, int32_t> keys{1, 1, 2, 2, 3, 3, 4, 4};
const bool validity_mask[] = {true, false, false, true, true, true, false, false};
LCW values{{{1, 2}, {3, 4}, {5, 6, 7}, LCW{}, {9, 10}, {11}, {20, 30, 40}, LCW{}}, validity_mask};

fixed_width_column_wrapper<K, int32_t> expect_keys{1, 2, 3, 4};

LCW expect_vals{{{1, 2}}, {LCW{}}, {{9, 10}, {11}}, {}};

auto agg = cudf::make_collect_list_aggregation(null_policy::EXCLUDE);
test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg));
}

TYPED_TEST(groupby_collect_list_test, dictionary)
{
using K = int32_t;
Expand All @@ -109,24 +146,5 @@ TYPED_TEST(groupby_collect_list_test, dictionary)
keys, vals, expect_keys, expect_vals->view(), cudf::make_collect_list_aggregation());
}

TYPED_TEST(groupby_collect_list_test, CollectFailsWithNullExclusion)
{
using K = int32_t;
using V = TypeParam;

fixed_width_column_wrapper<K, int32_t> keys{1, 1, 2, 2, 3, 3};
groupby::groupby gby{table_view{{keys}}};

fixed_width_column_wrapper<V, int32_t> values{{1, 2, 3, 4, 5, 6},
{true, false, true, false, true, false}};

std::vector<groupby::aggregation_request> agg_requests(1);
agg_requests[0].values = values;
agg_requests[0].aggregations.push_back(cudf::make_collect_list_aggregation(null_policy::EXCLUDE));

CUDF_EXPECT_THROW_MESSAGE(gby.aggregate(agg_requests),
"null exclusion is not supported on groupby COLLECT_LIST aggregation.");
}

} // namespace test
} // namespace cudf
59 changes: 34 additions & 25 deletions cpp/tests/groupby/collect_set_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,19 @@ namespace test {
#define LCL_V cudf::test::lists_column_wrapper<TypeParam, int32_t>
#define LCL_S cudf::test::lists_column_wrapper<cudf::string_view>
#define VALIDITY std::initializer_list<bool>
#define COLLECT_SET cudf::make_collect_set_aggregation()
#define COLLECT_SET_NULL_UNEQUAL \
cudf::make_collect_set_aggregation(null_policy::INCLUDE, null_equality::UNEQUAL)

struct CollectSetTest : public cudf::test::BaseFixture {
static auto collect_set() { return cudf::make_collect_set_aggregation(); }

static auto collect_set_null_unequal()
{
return cudf::make_collect_set_aggregation(null_policy::INCLUDE, null_equality::UNEQUAL);
}

static auto collect_set_null_exclude()
{
return cudf::make_collect_set_aggregation(null_policy::EXCLUDE);
}
};

template <typename V>
Expand All @@ -47,17 +55,6 @@ using FixedWidthTypesNotBool = cudf::test::Concat<cudf::test::IntegralTypesNotBo
cudf::test::TimestampTypes>;
TYPED_TEST_CASE(CollectSetTypedTest, FixedWidthTypesNotBool);

TYPED_TEST(CollectSetTypedTest, ExceptionTests)
{
std::vector<groupby::aggregation_request> agg_requests(1);
agg_requests[0].values = COL_V{{1, 2, 3, 4, 5, 6}, {true, false, true, false, true, false}};
agg_requests[0].aggregations.push_back(cudf::make_collect_list_aggregation(null_policy::EXCLUDE));

// groupby cannot exclude nulls
groupby::groupby gby{table_view{{COL_K{1, 1, 2, 2, 3, 3}}}};
EXPECT_THROW(gby.aggregate(agg_requests), cudf::logic_error);
}

TYPED_TEST(CollectSetTypedTest, TrivialInput)
{
// Empty input
Expand All @@ -70,7 +67,7 @@ TYPED_TEST(CollectSetTypedTest, TrivialInput)
COL_V vals{10};
COL_K keys_expected{1};
LCL_V vals_expected{LCL_V{10}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}

// Non-repeated keys
Expand All @@ -79,7 +76,7 @@ TYPED_TEST(CollectSetTypedTest, TrivialInput)
COL_V vals{20, 10};
COL_K keys_expected{1, 2};
LCL_V vals_expected{LCL_V{10}, LCL_V{20}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}
}

Expand All @@ -91,7 +88,7 @@ TYPED_TEST(CollectSetTypedTest, TypicalInput)
COL_V vals{10, 11, 10, 10, 20, 21, 21, 20, 30, 33, 32, 31};
COL_K keys_expected{1, 2, 3};
LCL_V vals_expected{{10, 11}, {20, 21}, {30, 31, 32, 33}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}

// Expect the result keys to be sorted by sort-based groupby
Expand All @@ -100,7 +97,7 @@ TYPED_TEST(CollectSetTypedTest, TypicalInput)
COL_V vals{40, 10, 20, 40, 30, 30, 20, 11};
COL_K keys_expected{1, 2, 3, 4};
LCL_V vals_expected{{10, 11}, {20}, {30}, {40}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}
}

Expand All @@ -114,14 +111,14 @@ TYPED_TEST(CollectSetTypedTest, SlicedColumnsInput)
auto const vals = cudf::slice(vals_original, {0, 4})[0]; // { 10, 11, 10, 10 }
auto const keys_expected = COL_K{1};
auto const vals_expected = LCL_V{{10, 11}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}
{
auto const keys = cudf::slice(keys_original, {2, 10})[0]; // { 1, 1, 2, 2, 2, 2, 3, 3 }
auto const vals = cudf::slice(vals_original, {2, 10})[0]; // { 10, 10, 20, 21, 21, 20, 30, 33 }
auto const keys_expected = COL_K{1, 2, 3};
auto const vals_expected = LCL_V{{10}, {20, 21}, {30, 33}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}
}

Expand All @@ -147,7 +144,7 @@ TEST_F(CollectSetTest, StringInput)
LCL_S vals_expected{{"String 1, first", "String 1, second"},
{"String 2, first", "String 2, second"},
{"String 3, first", "String 3, second"}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}

TYPED_TEST(CollectSetTypedTest, CollectWithNulls)
Expand All @@ -167,13 +164,19 @@ TYPED_TEST(CollectSetTypedTest, CollectWithNulls)
LCL_V vals_expected{{{10, null}, VALIDITY{true, false}},
{{20, null}, VALIDITY{true, false}},
{{30, 31}, VALIDITY{true, true}}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());

// All nulls per key are kept (nulls are put at the end of each list)
vals_expected = LCL_V{{{10, null, null}, VALIDITY{true, false, false}},
{{20, null, null, null}, VALIDITY{true, false, false, false}},
{{30, 31}, VALIDITY{true, true}}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET_NULL_UNEQUAL);
test_single_agg(
keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_unequal());

// All nulls per key are excluded
vals_expected = LCL_V{{10}, {20}, {30, 31}};
test_single_agg(
keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_exclude());
}

// Expect the result keys to be sorted by sort-based groupby
Expand All @@ -188,14 +191,20 @@ TYPED_TEST(CollectSetTypedTest, CollectWithNulls)
{{20, 21}, VALIDITY{true, true}},
{{null}, VALIDITY{false}},
{{40}, VALIDITY{true}}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET);
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());

// All nulls per key are kept (nulls are put at the end of each list)
vals_expected = LCL_V{{{10, null}, VALIDITY{true, false}},
{{20, 21}, VALIDITY{true, true}},
{{null, null, null, null}, VALIDITY{false, false, false, false}},
{{40}, VALIDITY{true}}};
test_single_agg(keys, vals, keys_expected, vals_expected, COLLECT_SET_NULL_UNEQUAL);
test_single_agg(
keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_unequal());

// All nulls per key are excluded
vals_expected = LCL_V{{10}, {20, 21}, {}, {40}};
test_single_agg(
keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_exclude());
}
}

Expand Down