Skip to content

Commit

Permalink
Add scan_aggregation and reduce_aggregation derived types. (#10357)
Browse files Browse the repository at this point in the history
This PR adds the `scan_aggregation` and `reduce_aggregation` derived types. With it, all concrete aggregation types are now derived from algorithmic specific subtypes.

Authors:
  - https://github.com/nvdbaranec

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)
  - Ram (Ramakrishna Prabhu) (https://github.com/rgsl888prabhu)
  - Bradley Dice (https://github.com/bdice)

URL: #10357
  • Loading branch information
nvdbaranec authored Mar 11, 2022
1 parent f29c8d9 commit b1ea304
Show file tree
Hide file tree
Showing 20 changed files with 1,103 additions and 764 deletions.
22 changes: 22 additions & 0 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,28 @@ class groupby_scan_aggregation : public virtual aggregation {
groupby_scan_aggregation() {}
};

/**
* @brief Derived class intended for reduction usage.
*/
class reduce_aggregation : public virtual aggregation {
public:
~reduce_aggregation() override = default;

protected:
reduce_aggregation() {}
};

/**
* @brief Derived class intended for scan usage.
*/
class scan_aggregation : public virtual aggregation {
public:
~scan_aggregation() override = default;

protected:
scan_aggregation() {}
};

/**
* @brief Derived class intended for segmented reduction usage.
*/
Expand Down
57 changes: 40 additions & 17 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class aggregation_finalizer { // Declares the interface for the finalizer
class sum_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public groupby_scan_aggregation,
public reduce_aggregation,
public scan_aggregation,
public segmented_reduce_aggregation {
public:
sum_aggregation() : aggregation(SUM) {}
Expand All @@ -167,7 +169,10 @@ class sum_aggregation final : public rolling_aggregation,
/**
* @brief Derived class for specifying a product aggregation
*/
class product_aggregation final : public groupby_aggregation, public segmented_reduce_aggregation {
class product_aggregation final : public groupby_aggregation,
public reduce_aggregation,
public scan_aggregation,
public segmented_reduce_aggregation {
public:
product_aggregation() : aggregation(PRODUCT) {}

Expand All @@ -189,6 +194,8 @@ class product_aggregation final : public groupby_aggregation, public segmented_r
class min_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public groupby_scan_aggregation,
public reduce_aggregation,
public scan_aggregation,
public segmented_reduce_aggregation {
public:
min_aggregation() : aggregation(MIN) {}
Expand All @@ -211,6 +218,8 @@ class min_aggregation final : public rolling_aggregation,
class max_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public groupby_scan_aggregation,
public reduce_aggregation,
public scan_aggregation,
public segmented_reduce_aggregation {
public:
max_aggregation() : aggregation(MAX) {}
Expand Down Expand Up @@ -251,7 +260,7 @@ class count_aggregation final : public rolling_aggregation,
/**
* @brief Derived class for specifying an any aggregation
*/
class any_aggregation final : public segmented_reduce_aggregation {
class any_aggregation final : public reduce_aggregation, public segmented_reduce_aggregation {
public:
any_aggregation() : aggregation(ANY) {}

Expand All @@ -270,7 +279,7 @@ class any_aggregation final : public segmented_reduce_aggregation {
/**
* @brief Derived class for specifying an all aggregation
*/
class all_aggregation final : public segmented_reduce_aggregation {
class all_aggregation final : public reduce_aggregation, public segmented_reduce_aggregation {
public:
all_aggregation() : aggregation(ALL) {}

Expand All @@ -289,7 +298,7 @@ class all_aggregation final : public segmented_reduce_aggregation {
/**
* @brief Derived class for specifying a sum_of_squares aggregation
*/
class sum_of_squares_aggregation final : public groupby_aggregation {
class sum_of_squares_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
sum_of_squares_aggregation() : aggregation(SUM_OF_SQUARES) {}

Expand All @@ -308,7 +317,9 @@ class sum_of_squares_aggregation final : public groupby_aggregation {
/**
* @brief Derived class for specifying a mean aggregation
*/
class mean_aggregation final : public rolling_aggregation, public groupby_aggregation {
class mean_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public reduce_aggregation {
public:
mean_aggregation() : aggregation(MEAN) {}

Expand Down Expand Up @@ -346,7 +357,9 @@ class m2_aggregation : public groupby_aggregation {
/**
* @brief Derived class for specifying a standard deviation/variance aggregation
*/
class std_var_aggregation : public rolling_aggregation, public groupby_aggregation {
class std_var_aggregation : public rolling_aggregation,
public groupby_aggregation,
public reduce_aggregation {
public:
size_type _ddof; ///< Delta degrees of freedom

Expand Down Expand Up @@ -418,7 +431,7 @@ class std_aggregation final : public std_var_aggregation {
/**
* @brief Derived class for specifying a median aggregation
*/
class median_aggregation final : public groupby_aggregation {
class median_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
median_aggregation() : aggregation(MEDIAN) {}

Expand All @@ -437,7 +450,7 @@ class median_aggregation final : public groupby_aggregation {
/**
* @brief Derived class for specifying a quantile aggregation
*/
class quantile_aggregation final : public groupby_aggregation {
class quantile_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
quantile_aggregation(std::vector<double> const& q, interpolation i)
: aggregation{QUANTILE}, _quantiles{q}, _interpolation{i}
Expand Down Expand Up @@ -524,7 +537,7 @@ class argmin_aggregation final : public rolling_aggregation, public groupby_aggr
/**
* @brief Derived class for specifying a nunique aggregation
*/
class nunique_aggregation final : public groupby_aggregation {
class nunique_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
nunique_aggregation(null_policy null_handling)
: aggregation{NUNIQUE}, _null_handling{null_handling}
Expand Down Expand Up @@ -563,7 +576,7 @@ class nunique_aggregation final : public groupby_aggregation {
/**
* @brief Derived class for specifying a nth element aggregation
*/
class nth_element_aggregation final : public groupby_aggregation {
class nth_element_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
nth_element_aggregation(size_type n, null_policy null_handling)
: aggregation{NTH_ELEMENT}, _n{n}, _null_handling{null_handling}
Expand Down Expand Up @@ -625,7 +638,9 @@ class row_number_aggregation final : public rolling_aggregation {
/**
* @brief Derived class for specifying a rank aggregation
*/
class rank_aggregation final : public rolling_aggregation, public groupby_scan_aggregation {
class rank_aggregation final : public rolling_aggregation,
public groupby_scan_aggregation,
public scan_aggregation {
public:
rank_aggregation() : aggregation{RANK} {}

Expand All @@ -644,7 +659,9 @@ class rank_aggregation final : public rolling_aggregation, public groupby_scan_a
/**
* @brief Derived class for specifying a dense rank aggregation
*/
class dense_rank_aggregation final : public rolling_aggregation, public groupby_scan_aggregation {
class dense_rank_aggregation final : public rolling_aggregation,
public groupby_scan_aggregation,
public scan_aggregation {
public:
dense_rank_aggregation() : aggregation{DENSE_RANK} {}

Expand All @@ -660,7 +677,9 @@ class dense_rank_aggregation final : public rolling_aggregation, public groupby_
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

class percent_rank_aggregation final : public rolling_aggregation, public groupby_scan_aggregation {
class percent_rank_aggregation final : public rolling_aggregation,
public groupby_scan_aggregation,
public scan_aggregation {
public:
percent_rank_aggregation() : aggregation{PERCENT_RANK} {}

Expand All @@ -679,7 +698,9 @@ class percent_rank_aggregation final : public rolling_aggregation, public groupb
/**
* @brief Derived aggregation class for specifying COLLECT_LIST aggregation
*/
class collect_list_aggregation final : public rolling_aggregation, public groupby_aggregation {
class collect_list_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public reduce_aggregation {
public:
explicit collect_list_aggregation(null_policy null_handling = null_policy::INCLUDE)
: aggregation{COLLECT_LIST}, _null_handling{null_handling}
Expand Down Expand Up @@ -718,7 +739,9 @@ class collect_list_aggregation final : public rolling_aggregation, public groupb
/**
* @brief Derived aggregation class for specifying COLLECT_SET aggregation
*/
class collect_set_aggregation final : public rolling_aggregation, public groupby_aggregation {
class collect_set_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public reduce_aggregation {
public:
explicit collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
null_equality nulls_equal = null_equality::EQUAL,
Expand Down Expand Up @@ -866,7 +889,7 @@ class udf_aggregation final : public rolling_aggregation {
/**
* @brief Derived aggregation class for specifying MERGE_LISTS aggregation
*/
class merge_lists_aggregation final : public groupby_aggregation {
class merge_lists_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
explicit merge_lists_aggregation() : aggregation{MERGE_LISTS} {}

Expand All @@ -885,7 +908,7 @@ class merge_lists_aggregation final : public groupby_aggregation {
/**
* @brief Derived aggregation class for specifying MERGE_SETS aggregation
*/
class merge_sets_aggregation final : public groupby_aggregation {
class merge_sets_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
explicit merge_sets_aggregation(null_equality nulls_equal, nan_equality nans_equal)
: aggregation{MERGE_SETS}, _nulls_equal(nulls_equal), _nans_equal(nans_equal)
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cudf/detail/scan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ namespace detail {
* @returns Column with scan results.
*/
std::unique_ptr<column> scan_exclusive(column_view const& input,
std::unique_ptr<aggregation> const& agg,
std::unique_ptr<scan_aggregation> const& agg,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);
Expand All @@ -73,7 +73,7 @@ std::unique_ptr<column> scan_exclusive(column_view const& input,
* @returns Column with scan results.
*/
std::unique_ptr<column> scan_inclusive(column_view const& input,
std::unique_ptr<aggregation> const& agg,
std::unique_ptr<scan_aggregation> const& agg,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cudf/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ enum class scan_type : bool { INCLUSIVE, EXCLUSIVE };
*/
std::unique_ptr<scalar> reduce(
column_view const& col,
std::unique_ptr<aggregation> const& agg,
std::unique_ptr<reduce_aggregation> const& agg,
data_type output_dtype,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down Expand Up @@ -142,7 +142,7 @@ std::unique_ptr<column> segmented_reduce(
*/
std::unique_ptr<column> scan(
const column_view& input,
std::unique_ptr<aggregation> const& agg,
std::unique_ptr<scan_aggregation> const& agg,
scan_type inclusive,
null_policy null_handling = null_policy::EXCLUDE,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
Expand Down
Loading

0 comments on commit b1ea304

Please sign in to comment.