Skip to content

Commit

Permalink
Add dictionary support to libcudf groupby functions (#6585)
Browse files Browse the repository at this point in the history
Reference #5963 Add dictionary support to groupby.

- [x] argmax
- [x] argmin
- [x] collect
- [x] count
- [x] max
- [x] mean* 
- [x] median
- [x] min
- [x] nth element
- [x] nunique
- [x] quantile
- [x] std*
- [x] sum* 
- [x] var* 

* _not supported when compiled with CUDA 10.2, because nvcc/ptxas 10.2 segfaults while compiling this code_

Authors:
  - davidwendt <[email protected]>

Approvers:
  - Jake Hemstad
  - Karthikeyan

URL: #6585
  • Loading branch information
davidwendt authored Jan 5, 2021
1 parent 7bf0505 commit 6828e2c
Show file tree
Hide file tree
Showing 25 changed files with 666 additions and 142 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- PR #6275 Update to official libcu++ on Github
- PR #6838 Fix `columns` & `index` handling in dataframe constructor
- PR #6750 Remove **kwargs from string/categorical methods
- PR #6585 Add dictionary support to libcudf groupby functions
- PR #6909 Support reading byte array backed decimal columns from parquet files
- PR #6939 Use simplified `rmm::exec_policy`
- PR #6512 Refactor rolling.cu to reduce compile time
Expand Down
59 changes: 59 additions & 0 deletions cpp/include/cudf/detail/aggregation/aggregation.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/utilities/device_atomics.cuh>
#include <cudf/detail/utilities/release_assert.cuh>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/table/table_device_view.cuh>

#include <rmm/cuda_stream_view.hpp>
Expand Down Expand Up @@ -172,6 +173,64 @@ struct update_target_element<
}
};

/**
 * @brief Function object that accumulates a single dictionary key value into
 * an element of a target column.
 *
 * The dictionary element at `source_index` is used as a position into the
 * dictionary's keys column, and the key found there is atomically added to the
 * target element:
 *
 * `target[target_index] += d_dictionary.keys[d_dictionary.indices[source_index]]`
 *
 * The call operator is templated on the key type so it can be invoked through
 * `type_dispatcher` using the type of the keys column.
 */
struct update_target_from_dictionary {
  /**
   * @brief Overload for fixed-width key types that are not fixed-point.
   *
   * @tparam KeyType Element type of the dictionary's keys column
   * @param target Column receiving the accumulated value
   * @param target_index Index of the element in `target` to update
   * @param d_dictionary Dictionary column holding the source element
   * @param source_index Index of the element in `d_dictionary` to read
   */
  template <typename KeyType,
            std::enable_if_t<is_fixed_width<KeyType>() && !is_fixed_point<KeyType>()>* = nullptr>
  __device__ void operator()(mutable_column_device_view& target,
                             size_type target_index,
                             column_device_view& d_dictionary,
                             size_type source_index) const noexcept
  {
// This code will segfault in nvcc/ptxas 10.2 only, so the body is compiled out
// for that toolkit version and this overload becomes a no-op there.
// https://nvbugswb.nvidia.com/NvBugs5/SWBug.aspx?bugid=3186317
#if !((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ == 2))
    using SumType = target_type_t<KeyType, aggregation::SUM>;

    // The dictionary element is the position of the key within the keys column.
    auto const key_index =
      static_cast<cudf::size_type>(d_dictionary.element<dictionary32>(source_index));
    auto const keys_view = d_dictionary.child(cudf::dictionary_column_view::keys_column_index);
    auto const addend    = static_cast<SumType>(keys_view.element<KeyType>(key_index));

    atomicAdd(&target.element<SumType>(target_index), addend);
#endif
  }

  /**
   * @brief Overload for key types that are not fixed-width, or are fixed-point.
   *
   * Intentionally does nothing; the target element is left untouched.
   */
  template <typename KeyType,
            std::enable_if_t<!is_fixed_width<KeyType>() || is_fixed_point<KeyType>()>* = nullptr>
  __device__ void operator()(mutable_column_device_view& target,
                             size_type target_index,
                             column_device_view& d_dictionary,
                             size_type source_index) const noexcept
  {
  }
};

/**
 * @brief Specialization of `update_target_element` for a dictionary source
 * column and the SUM aggregation.
 *
 * A null source row is skipped when `source_has_nulls` is true. Otherwise the
 * update is dispatched on the type of the dictionary's keys column to
 * `update_target_from_dictionary`, after which a null target element is marked
 * valid when `target_has_nulls` is true.
 *
 * @tparam target_has_nulls Indicates presence of null elements in `target`
 * @tparam source_has_nulls Indicates presence of null elements in `source`.
 */
template <bool target_has_nulls, bool source_has_nulls>
struct update_target_element<dictionary32, aggregation::SUM, target_has_nulls, source_has_nulls> {
  __device__ void operator()(mutable_column_device_view target,
                             size_type target_index,
                             column_device_view source,
                             size_type source_index) const noexcept
  {
    // Null source rows contribute nothing to the sum.
    bool const skip_row = source_has_nulls and source.is_null(source_index);
    if (skip_row) { return; }

    // Dispatch on the key type; the dictionary column itself is handed to the
    // functor so that it can resolve the key from the stored index.
    auto const keys_type = source.child(cudf::dictionary_column_view::keys_column_index).type();
    type_dispatcher(
      keys_type, update_target_from_dictionary{}, target, target_index, source, source_index);

    bool const target_is_null = target_has_nulls and target.is_null(target_index);
    if (target_is_null) { target.set_valid(target_index); }
  }
};

template <typename Source, bool target_has_nulls, bool source_has_nulls>
struct update_target_element<
Source,
Expand Down
10 changes: 10 additions & 0 deletions cpp/include/cudf_test/column_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,16 @@ class dictionary_column_wrapper<std::string> : public detail::column_wrapper {
*/
operator dictionary_column_view() const { return cudf::dictionary_column_view{wrapped->view()}; }

/**
 * @brief Returns a view of the keys column of the wrapped dictionary column
 */
column_view keys() const
{
  auto const dictionary = cudf::dictionary_column_view{wrapped->view()};
  return dictionary.keys();
}

/**
 * @brief Returns a view of the indices column of the wrapped dictionary column
 */
column_view indices() const
{
  auto const dictionary = cudf::dictionary_column_view{wrapped->view()};
  return dictionary.indices();
}

/**
* @brief Default constructor initializes an empty dictionary column of strings
*/
Expand Down
45 changes: 34 additions & 11 deletions cpp/src/groupby/groupby.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <cudf/detail/groupby.hpp>
#include <cudf/detail/groupby/sort_helper.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/groupby.hpp>
#include <cudf/table/table.hpp>
#include <cudf/table/table_view.hpp>
Expand Down Expand Up @@ -103,17 +104,39 @@ auto empty_results(std::vector<aggregation_request> const& requests)
/**
 * @brief Verifies that each aggregation requested on a request's values is valid.
 *
 * For a dictionary values column the aggregation is validated against the type
 * of the dictionary's keys rather than against the dictionary type itself.
 * Validating the raw column type first would test the dictionary type itself
 * instead of the key type, so only the key-aware check is performed.
 *
 * When compiled with nvcc/ptxas 10.2, SUM, MEAN, STD and VARIANCE are
 * additionally rejected for dictionary values columns.
 *
 * A failed check is reported through `CUDF_EXPECTS`.
 *
 * @param requests The aggregation requests to validate
 */
void verify_valid_requests(std::vector<aggregation_request> const& requests)
{
  CUDF_EXPECTS(
    std::all_of(
      requests.begin(),
      requests.end(),
      [](auto const& request) {
        return std::all_of(
          request.aggregations.begin(), request.aggregations.end(), [&request](auto const& agg) {
            // Dictionary columns are aggregated by key value, so validate the key type.
            auto values_type = cudf::is_dictionary(request.values.type())
                                 ? cudf::dictionary_column_view(request.values).keys().type()
                                 : request.values.type();
            return cudf::detail::is_valid_aggregation(values_type, agg->kind);
          });
      }),
    "Invalid type/aggregation combination.");

// The aggregations listed in the lambda below will not work with a values column of type
// dictionary if this is compiled with nvcc/ptxas 10.2.
// https://nvbugswb.nvidia.com/NvBugs5/SWBug.aspx?bugid=3186317&cp=
#if (__CUDACC_VER_MAJOR__ == 10) and (__CUDACC_VER_MINOR__ == 2)
  CUDF_EXPECTS(
    std::all_of(
      requests.begin(),
      requests.end(),
      [](auto const& request) {
        return std::all_of(
          request.aggregations.begin(), request.aggregations.end(), [&request](auto const& agg) {
            return (!cudf::is_dictionary(request.values.type()) ||
                    !(agg->kind == aggregation::SUM or agg->kind == aggregation::MEAN or
                      agg->kind == aggregation::STD or agg->kind == aggregation::VARIANCE));
          });
      }),
    "dictionary type not supported for this aggregation");
#endif
}

} // namespace
Expand Down
24 changes: 18 additions & 6 deletions cpp/src/groupby/hash/groupby.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <cudf/detail/unary.hpp>
#include <cudf/detail/utilities/cuda.cuh>
#include <cudf/detail/utilities/hash_functions.cuh>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/groupby.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/table/row_operators.cuh>
Expand Down Expand Up @@ -104,6 +105,7 @@ template <typename Map>
class hash_compound_agg_finalizer final : public cudf::detail::aggregation_finalizer {
size_t col_idx;
column_view col;
data_type result_type;
cudf::detail::result_cache* sparse_results;
cudf::detail::result_cache* dense_results;
rmm::device_vector<size_type> const& gather_map;
Expand Down Expand Up @@ -135,6 +137,8 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final
stream(stream),
mr(mr)
{
result_type = cudf::is_dictionary(col.type()) ? cudf::dictionary_column_view(col).keys().type()
: col.type();
}

auto to_dense_agg_result(cudf::aggregation const& agg)
Expand Down Expand Up @@ -184,7 +188,7 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final
void visit(cudf::detail::min_aggregation const& agg) override
{
if (dense_results->has_result(col_idx, agg)) return;
if (col.type().id() == type_id::STRING)
if (result_type.id() == type_id::STRING)
dense_results->add_result(col_idx, agg, gather_argminmax(aggregation::ARGMIN));
else
dense_results->add_result(col_idx, agg, to_dense_agg_result(agg));
Expand All @@ -194,7 +198,7 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final
{
if (dense_results->has_result(col_idx, agg)) return;

if (col.type().id() == type_id::STRING)
if (result_type.id() == type_id::STRING)
dense_results->add_result(col_idx, agg, gather_argminmax(aggregation::ARGMAX));
else
dense_results->add_result(col_idx, agg, to_dense_agg_result(agg));
Expand All @@ -215,7 +219,7 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final
cudf::detail::binary_operation(sum_result,
count_result,
binary_operator::DIV,
cudf::detail::target_type(col.type(), aggregation::MEAN),
cudf::detail::target_type(result_type, aggregation::MEAN),
stream,
mr);
dense_results->add_result(col_idx, agg, std::move(result));
Expand All @@ -237,7 +241,7 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final
auto count_view = column_device_view::create(count_result);

auto var_result = make_fixed_width_column(
cudf::detail::target_type(col.type(), agg.kind), col.size(), mask_state::ALL_NULL, stream);
cudf::detail::target_type(result_type, agg.kind), col.size(), mask_state::ALL_NULL, stream);
auto var_result_view = mutable_column_device_view::create(var_result->mutable_view());
mutable_table_view var_table_view{{var_result->mutable_view()}};
cudf::detail::initialize_with_identity(var_table_view, {agg.kind}, stream);
Expand Down Expand Up @@ -285,11 +289,15 @@ flatten_single_pass_aggs(std::vector<aggregation_request> const& requests)
}
};

auto values_type = cudf::is_dictionary(request.values.type())
? cudf::dictionary_column_view(request.values).keys().type()
: request.values.type();
for (auto&& agg : agg_v) {
for (auto const& agg_s : agg->get_simple_aggregations(request.values.type()))
for (auto const& agg_s : agg->get_simple_aggregations(values_type))
insert_agg(i, request.values, agg_s);
}
}

return std::make_tuple(table_view(columns), std::move(agg_kinds), std::move(col_ids));
}

Expand Down Expand Up @@ -389,8 +397,12 @@ auto create_sparse_results_table(table_view const& flattened_values,
: (col.has_nulls() or agg == aggregation::VARIANCE or agg == aggregation::STD);
auto mask_flag = (nullable) ? mask_state::ALL_NULL : mask_state::UNALLOCATED;

auto col_type = cudf::is_dictionary(col.type())
? cudf::dictionary_column_view(col).keys().type()
: col.type();

return make_fixed_width_column(
cudf::detail::target_type(col.type(), agg), col.size(), mask_flag, stream);
cudf::detail::target_type(col_type, agg), col.size(), mask_flag, stream);
});

table sparse_table(std::move(sparse_columns));
Expand Down
17 changes: 14 additions & 3 deletions cpp/src/groupby/hash/multi_pass_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,16 @@ struct var_hash_functor {
}

template <typename Source>
__device__ std::enable_if_t<!is_supported<Source>()> operator()(size_type source_index,
__device__ std::enable_if_t<!is_supported<Source>()> operator()(column_device_view const& source,
size_type source_index,
size_type target_index) noexcept
{
release_assert(false and "Invalid source type for std, var aggregation combination.");
}

template <typename Source>
__device__ std::enable_if_t<is_supported<Source>()> operator()(size_type source_index,
__device__ std::enable_if_t<is_supported<Source>()> operator()(column_device_view const& source,
size_type source_index,
size_type target_index) noexcept
{
using Target = target_type_t<Source, aggregation::VARIANCE>;
Expand All @@ -92,7 +94,16 @@ struct var_hash_functor {
if (row_bitmask == nullptr or cudf::bit_is_set(row_bitmask, source_index)) {
auto result = map.find(source_index);
auto target_index = result->second;
type_dispatcher(source.type(), *this, source_index, target_index);

auto col = source;
auto source_type = source.type();
if (source_type.id() == type_id::DICTIONARY32) {
col = source.child(cudf::dictionary_column_view::keys_column_index);
source_type = col.type();
source_index = static_cast<size_type>(source.element<dictionary32>(source_index));
}

type_dispatcher(source_type, *this, col, source_index, target_index);
}
}
};
Expand Down
Loading

0 comments on commit 6828e2c

Please sign in to comment.