Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scan_aggregation and reduce_aggregation derived types. #10357

Merged
merged 10 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,28 @@ class groupby_scan_aggregation : public virtual aggregation {
groupby_scan_aggregation() {}
};

/**
* @brief Derived class intended for reduction usage.
*/
class reduce_aggregation : public virtual aggregation {
public:
~reduce_aggregation() override = default;

protected:
reduce_aggregation() {}
nvdbaranec marked this conversation as resolved.
Show resolved Hide resolved
};

/**
* @brief Derived class intended for scan usage.
*/
class scan_aggregation : public virtual aggregation {
public:
~scan_aggregation() override = default;

protected:
scan_aggregation() {}
};

/**
* @brief Derived class intended for segmented reduction usage.
*/
Expand Down
57 changes: 40 additions & 17 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class aggregation_finalizer { // Declares the interface for the finalizer
class sum_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public groupby_scan_aggregation,
public reduce_aggregation,
public scan_aggregation,
public segmented_reduce_aggregation {
public:
sum_aggregation() : aggregation(SUM) {}
Expand All @@ -167,7 +169,10 @@ class sum_aggregation final : public rolling_aggregation,
/**
* @brief Derived class for specifying a product aggregation
*/
class product_aggregation final : public groupby_aggregation, public segmented_reduce_aggregation {
class product_aggregation final : public groupby_aggregation,
nvdbaranec marked this conversation as resolved.
Show resolved Hide resolved
public reduce_aggregation,
public scan_aggregation,
public segmented_reduce_aggregation {
public:
product_aggregation() : aggregation(PRODUCT) {}

Expand All @@ -189,6 +194,8 @@ class product_aggregation final : public groupby_aggregation, public segmented_r
class min_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public groupby_scan_aggregation,
public reduce_aggregation,
public scan_aggregation,
public segmented_reduce_aggregation {
public:
min_aggregation() : aggregation(MIN) {}
Expand All @@ -211,6 +218,8 @@ class min_aggregation final : public rolling_aggregation,
class max_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public groupby_scan_aggregation,
public reduce_aggregation,
public scan_aggregation,
public segmented_reduce_aggregation {
public:
max_aggregation() : aggregation(MAX) {}
Expand Down Expand Up @@ -251,7 +260,7 @@ class count_aggregation final : public rolling_aggregation,
/**
* @brief Derived class for specifying an any aggregation
*/
class any_aggregation final : public segmented_reduce_aggregation {
class any_aggregation final : public reduce_aggregation, public segmented_reduce_aggregation {
public:
any_aggregation() : aggregation(ANY) {}

Expand All @@ -270,7 +279,7 @@ class any_aggregation final : public segmented_reduce_aggregation {
/**
* @brief Derived class for specifying an all aggregation
*/
class all_aggregation final : public segmented_reduce_aggregation {
class all_aggregation final : public reduce_aggregation, public segmented_reduce_aggregation {
public:
all_aggregation() : aggregation(ALL) {}

Expand All @@ -289,7 +298,7 @@ class all_aggregation final : public segmented_reduce_aggregation {
/**
* @brief Derived class for specifying a sum_of_squares aggregation
*/
class sum_of_squares_aggregation final : public groupby_aggregation {
class sum_of_squares_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
sum_of_squares_aggregation() : aggregation(SUM_OF_SQUARES) {}

Expand All @@ -308,7 +317,9 @@ class sum_of_squares_aggregation final : public groupby_aggregation {
/**
* @brief Derived class for specifying a mean aggregation
*/
class mean_aggregation final : public rolling_aggregation, public groupby_aggregation {
class mean_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public reduce_aggregation {
public:
mean_aggregation() : aggregation(MEAN) {}

Expand Down Expand Up @@ -346,7 +357,9 @@ class m2_aggregation : public groupby_aggregation {
/**
* @brief Derived class for specifying a standard deviation/variance aggregation
*/
class std_var_aggregation : public rolling_aggregation, public groupby_aggregation {
class std_var_aggregation : public rolling_aggregation,
public groupby_aggregation,
public reduce_aggregation {
public:
size_type _ddof; ///< Delta degrees of freedom

Expand Down Expand Up @@ -418,7 +431,7 @@ class std_aggregation final : public std_var_aggregation {
/**
* @brief Derived class for specifying a median aggregation
*/
class median_aggregation final : public groupby_aggregation {
class median_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
median_aggregation() : aggregation(MEDIAN) {}

Expand All @@ -437,7 +450,7 @@ class median_aggregation final : public groupby_aggregation {
/**
* @brief Derived class for specifying a quantile aggregation
*/
class quantile_aggregation final : public groupby_aggregation {
class quantile_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
quantile_aggregation(std::vector<double> const& q, interpolation i)
: aggregation{QUANTILE}, _quantiles{q}, _interpolation{i}
Expand Down Expand Up @@ -524,7 +537,7 @@ class argmin_aggregation final : public rolling_aggregation, public groupby_aggr
/**
* @brief Derived class for specifying a nunique aggregation
*/
class nunique_aggregation final : public groupby_aggregation {
class nunique_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
nunique_aggregation(null_policy null_handling)
: aggregation{NUNIQUE}, _null_handling{null_handling}
Expand Down Expand Up @@ -563,7 +576,7 @@ class nunique_aggregation final : public groupby_aggregation {
/**
* @brief Derived class for specifying a nth element aggregation
*/
class nth_element_aggregation final : public groupby_aggregation {
class nth_element_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
nth_element_aggregation(size_type n, null_policy null_handling)
: aggregation{NTH_ELEMENT}, _n{n}, _null_handling{null_handling}
Expand Down Expand Up @@ -625,7 +638,9 @@ class row_number_aggregation final : public rolling_aggregation {
/**
* @brief Derived class for specifying a rank aggregation
*/
class rank_aggregation final : public rolling_aggregation, public groupby_scan_aggregation {
class rank_aggregation final : public rolling_aggregation,
public groupby_scan_aggregation,
public scan_aggregation {
public:
rank_aggregation() : aggregation{RANK} {}

Expand All @@ -644,7 +659,9 @@ class rank_aggregation final : public rolling_aggregation, public groupby_scan_a
/**
* @brief Derived class for specifying a dense rank aggregation
*/
class dense_rank_aggregation final : public rolling_aggregation, public groupby_scan_aggregation {
class dense_rank_aggregation final : public rolling_aggregation,
public groupby_scan_aggregation,
public scan_aggregation {
public:
dense_rank_aggregation() : aggregation{DENSE_RANK} {}

Expand All @@ -660,7 +677,9 @@ class dense_rank_aggregation final : public rolling_aggregation, public groupby_
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

class percent_rank_aggregation final : public rolling_aggregation, public groupby_scan_aggregation {
class percent_rank_aggregation final : public rolling_aggregation,
public groupby_scan_aggregation,
public scan_aggregation {
public:
percent_rank_aggregation() : aggregation{PERCENT_RANK} {}

Expand All @@ -679,7 +698,9 @@ class percent_rank_aggregation final : public rolling_aggregation, public groupb
/**
* @brief Derived aggregation class for specifying COLLECT_LIST aggregation
*/
class collect_list_aggregation final : public rolling_aggregation, public groupby_aggregation {
class collect_list_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public reduce_aggregation {
public:
explicit collect_list_aggregation(null_policy null_handling = null_policy::INCLUDE)
: aggregation{COLLECT_LIST}, _null_handling{null_handling}
Expand Down Expand Up @@ -718,7 +739,9 @@ class collect_list_aggregation final : public rolling_aggregation, public groupb
/**
* @brief Derived aggregation class for specifying COLLECT_SET aggregation
*/
class collect_set_aggregation final : public rolling_aggregation, public groupby_aggregation {
class collect_set_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public reduce_aggregation {
public:
explicit collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
null_equality nulls_equal = null_equality::EQUAL,
Expand Down Expand Up @@ -866,7 +889,7 @@ class udf_aggregation final : public rolling_aggregation {
/**
* @brief Derived aggregation class for specifying MERGE_LISTS aggregation
*/
class merge_lists_aggregation final : public groupby_aggregation {
class merge_lists_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
explicit merge_lists_aggregation() : aggregation{MERGE_LISTS} {}

Expand All @@ -885,7 +908,7 @@ class merge_lists_aggregation final : public groupby_aggregation {
/**
* @brief Derived aggregation class for specifying MERGE_SETS aggregation
*/
class merge_sets_aggregation final : public groupby_aggregation {
class merge_sets_aggregation final : public groupby_aggregation, public reduce_aggregation {
public:
explicit merge_sets_aggregation(null_equality nulls_equal, nan_equality nans_equal)
: aggregation{MERGE_SETS}, _nulls_equal(nulls_equal), _nans_equal(nans_equal)
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cudf/detail/scan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ namespace detail {
* @returns Column with scan results.
*/
std::unique_ptr<column> scan_exclusive(column_view const& input,
std::unique_ptr<aggregation> const& agg,
std::unique_ptr<scan_aggregation> const& agg,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);
Expand All @@ -73,7 +73,7 @@ std::unique_ptr<column> scan_exclusive(column_view const& input,
* @returns Column with scan results.
*/
std::unique_ptr<column> scan_inclusive(column_view const& input,
std::unique_ptr<aggregation> const& agg,
std::unique_ptr<scan_aggregation> const& agg,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cudf/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ enum class scan_type : bool { INCLUSIVE, EXCLUSIVE };
*/
std::unique_ptr<scalar> reduce(
column_view const& col,
std::unique_ptr<aggregation> const& agg,
std::unique_ptr<reduce_aggregation> const& agg,
data_type output_dtype,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down Expand Up @@ -142,7 +142,7 @@ std::unique_ptr<column> segmented_reduce(
*/
std::unique_ptr<column> scan(
const column_view& input,
std::unique_ptr<aggregation> const& agg,
std::unique_ptr<scan_aggregation> const& agg,
scan_type inclusive,
null_policy null_handling = null_policy::EXCLUDE,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
Expand Down
Loading