Skip to content

Commit

Permalink
Add groupby_aggregation and groupby_scan_aggregation classes and forc…
Browse files Browse the repository at this point in the history
…e their usage. (#8906)

Followup to #8052
Partially addresses #7106

Adds the `groupby_aggregation` class and forces usage of that type when calling `groupby::aggregate()`

Adds the `groupby_scan_aggregation` class and forces usage of that type when calling `groupby::scan()`

Authors:
  - https://github.com/nvdbaranec

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)
  - Jake Hemstad (https://github.com/jrhemstad)
  - David Wendt (https://github.com/davidwendt)
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Nghia Truong (https://github.com/ttnghia)
  - Devavret Makkar (https://github.com/devavret)

URL: #8906
  • Loading branch information
nvdbaranec authored Aug 19, 2021
1 parent 53f0bb4 commit 94c659c
Show file tree
Hide file tree
Showing 44 changed files with 1,105 additions and 402 deletions.
2 changes: 1 addition & 1 deletion cpp/benchmarks/groupby/group_nth_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void BM_pre_sorted_nth(benchmark::State& state)
std::vector<cudf::groupby::aggregation_request> requests;
requests.emplace_back(cudf::groupby::aggregation_request());
requests[0].values = vals;
requests[0].aggregations.push_back(cudf::make_nth_element_aggregation(-1));
requests[0].aggregations.push_back(cudf::make_nth_element_aggregation<groupby_aggregation>(-1));

for (auto _ : state) {
cuda_event_timer timer(state, true);
Expand Down
4 changes: 2 additions & 2 deletions cpp/examples/basic/src/process_csv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ void write_csv(cudf::table_view const& tbl_view, std::string const& file_path)
}

std::vector<cudf::groupby::aggregation_request> make_single_aggregation_request(
std::unique_ptr<cudf::aggregation>&& agg, cudf::column_view value)
std::unique_ptr<cudf::groupby_aggregation>&& agg, cudf::column_view value)
{
std::vector<cudf::groupby::aggregation_request> requests;
requests.emplace_back(cudf::groupby::aggregation_request());
Expand All @@ -42,7 +42,7 @@ std::unique_ptr<cudf::table> average_closing_price(cudf::table_view stock_info_t

// Compute the average of each company's closing price with entire column
cudf::groupby::groupby grpby_obj(keys);
auto requests = make_single_aggregation_request(cudf::make_mean_aggregation(), val);
auto requests = make_single_aggregation_request(cudf::make_mean_aggregation<cudf::groupby_aggregation>(), val);

auto agg_results = grpby_obj.aggregate(requests);

Expand Down
25 changes: 23 additions & 2 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ class aggregation {
};

/**
* @brief Derived class intended for enforcing operation-specific restrictions
* when calling various cudf functions.
* @brief Derived class intended for rolling_window specific aggregation usage.
*
* As an example, rolling_window will only accept rolling_aggregation inputs,
* and the appropriate derived classes (sum_aggregation, mean_aggregation, etc)
Expand All @@ -121,6 +120,28 @@ class rolling_aggregation : public virtual aggregation {
rolling_aggregation() {}
};

/**
* @brief Derived class intended for groupby specific aggregation usage.
*
* `groupby::aggregation_request::aggregations` stores
* `std::unique_ptr<groupby_aggregation>`, so only aggregation classes that
* derive from this type can be placed in a groupby aggregation request.
*
* The inheritance from `aggregation` is virtual so that a concrete aggregation
* (e.g. `sum_aggregation`) can derive from several of these restriction
* classes at once while containing a single `aggregation` base subobject.
*/
class groupby_aggregation : public virtual aggregation {
public:
// NOTE(review): presumably virtual through `aggregation`'s destructor, which
// is not visible here -- confirm before deleting through this type.
~groupby_aggregation() = default;

protected:
// Protected: only derived aggregation classes may construct this type. The
// concrete derived classes initialize the virtual `aggregation` base
// themselves (e.g. `aggregation(SUM)`), so nothing is initialized here.
groupby_aggregation() {}
};

/**
* @brief Derived class intended for groupby specific scan usage.
*
* `groupby::scan_request::aggregations` stores
* `std::unique_ptr<groupby_scan_aggregation>`, so only aggregation classes that
* derive from this type can be placed in a groupby scan request.
*
* The inheritance from `aggregation` is virtual so that a concrete aggregation
* (e.g. `sum_aggregation`) can derive from several of these restriction
* classes at once while containing a single `aggregation` base subobject.
*/
class groupby_scan_aggregation : public virtual aggregation {
public:
// NOTE(review): presumably virtual through `aggregation`'s destructor, which
// is not visible here -- confirm before deleting through this type.
~groupby_scan_aggregation() = default;

protected:
// Protected: only derived aggregation classes may construct this type. The
// concrete derived classes initialize the virtual `aggregation` base
// themselves (e.g. `aggregation(SUM)`), so nothing is initialized here.
groupby_scan_aggregation() {}
};

enum class udf_type : bool { CUDA, PTX };

/// Factory to create a SUM aggregation
Expand Down
63 changes: 38 additions & 25 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ class aggregation_finalizer { // Declares the interface for the finalizer
/**
* @brief Derived class for specifying a sum aggregation
*/
class sum_aggregation final : public rolling_aggregation {
class sum_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public groupby_scan_aggregation {
public:
sum_aggregation() : aggregation(SUM) {}

Expand All @@ -149,7 +151,7 @@ class sum_aggregation final : public rolling_aggregation {
/**
* @brief Derived class for specifying a product aggregation
*/
class product_aggregation final : public aggregation {
class product_aggregation final : public groupby_aggregation {
public:
product_aggregation() : aggregation(PRODUCT) {}

Expand All @@ -168,7 +170,9 @@ class product_aggregation final : public aggregation {
/**
* @brief Derived class for specifying a min aggregation
*/
class min_aggregation final : public rolling_aggregation {
class min_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public groupby_scan_aggregation {
public:
min_aggregation() : aggregation(MIN) {}

Expand All @@ -187,7 +191,9 @@ class min_aggregation final : public rolling_aggregation {
/**
* @brief Derived class for specifying a max aggregation
*/
class max_aggregation final : public rolling_aggregation {
class max_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public groupby_scan_aggregation {
public:
max_aggregation() : aggregation(MAX) {}

Expand All @@ -206,7 +212,9 @@ class max_aggregation final : public rolling_aggregation {
/**
* @brief Derived class for specifying a count aggregation
*/
class count_aggregation final : public rolling_aggregation {
class count_aggregation final : public rolling_aggregation,
public groupby_aggregation,
public groupby_scan_aggregation {
public:
count_aggregation(aggregation::Kind kind) : aggregation(kind) {}

Expand Down Expand Up @@ -263,7 +271,7 @@ class all_aggregation final : public aggregation {
/**
* @brief Derived class for specifying a sum_of_squares aggregation
*/
class sum_of_squares_aggregation final : public aggregation {
class sum_of_squares_aggregation final : public groupby_aggregation {
public:
sum_of_squares_aggregation() : aggregation(SUM_OF_SQUARES) {}

Expand All @@ -282,7 +290,7 @@ class sum_of_squares_aggregation final : public aggregation {
/**
* @brief Derived class for specifying a mean aggregation
*/
class mean_aggregation final : public rolling_aggregation {
class mean_aggregation final : public rolling_aggregation, public groupby_aggregation {
public:
mean_aggregation() : aggregation(MEAN) {}

Expand All @@ -301,7 +309,7 @@ class mean_aggregation final : public rolling_aggregation {
/**
* @brief Derived class for specifying a m2 aggregation
*/
class m2_aggregation : public aggregation {
class m2_aggregation : public groupby_aggregation {
public:
m2_aggregation() : aggregation{M2} {}

Expand All @@ -320,7 +328,7 @@ class m2_aggregation : public aggregation {
/**
* @brief Derived class for specifying a standard deviation/variance aggregation
*/
class std_var_aggregation : public aggregation {
class std_var_aggregation : public groupby_aggregation {
public:
size_type _ddof; ///< Delta degrees of freedom

Expand All @@ -339,7 +347,6 @@ class std_var_aggregation : public aggregation {
CUDF_EXPECTS(k == aggregation::STD or k == aggregation::VARIANCE,
"std_var_aggregation can accept only STD, VARIANCE");
}

size_type hash_impl() const { return std::hash<size_type>{}(_ddof); }
};

Expand All @@ -348,7 +355,10 @@ class std_var_aggregation : public aggregation {
*/
class var_aggregation final : public std_var_aggregation {
public:
var_aggregation(size_type ddof) : std_var_aggregation{aggregation::VARIANCE, ddof} {}
var_aggregation(size_type ddof)
: aggregation{aggregation::VARIANCE}, std_var_aggregation{aggregation::VARIANCE, ddof}
{
}

std::unique_ptr<aggregation> clone() const override
{
Expand All @@ -367,7 +377,10 @@ class var_aggregation final : public std_var_aggregation {
*/
class std_aggregation final : public std_var_aggregation {
public:
std_aggregation(size_type ddof) : std_var_aggregation{aggregation::STD, ddof} {}
std_aggregation(size_type ddof)
: aggregation{aggregation::STD}, std_var_aggregation{aggregation::STD, ddof}
{
}

std::unique_ptr<aggregation> clone() const override
{
Expand All @@ -384,7 +397,7 @@ class std_aggregation final : public std_var_aggregation {
/**
* @brief Derived class for specifying a median aggregation
*/
class median_aggregation final : public aggregation {
class median_aggregation final : public groupby_aggregation {
public:
median_aggregation() : aggregation(MEDIAN) {}

Expand All @@ -403,7 +416,7 @@ class median_aggregation final : public aggregation {
/**
* @brief Derived class for specifying a quantile aggregation
*/
class quantile_aggregation final : public aggregation {
class quantile_aggregation final : public groupby_aggregation {
public:
quantile_aggregation(std::vector<double> const& q, interpolation i)
: aggregation{QUANTILE}, _quantiles{q}, _interpolation{i}
Expand Down Expand Up @@ -449,7 +462,7 @@ class quantile_aggregation final : public aggregation {
/**
* @brief Derived class for specifying an argmax aggregation
*/
class argmax_aggregation final : public rolling_aggregation {
class argmax_aggregation final : public rolling_aggregation, public groupby_aggregation {
public:
argmax_aggregation() : aggregation(ARGMAX) {}

Expand All @@ -468,7 +481,7 @@ class argmax_aggregation final : public rolling_aggregation {
/**
* @brief Derived class for specifying an argmin aggregation
*/
class argmin_aggregation final : public rolling_aggregation {
class argmin_aggregation final : public rolling_aggregation, public groupby_aggregation {
public:
argmin_aggregation() : aggregation(ARGMIN) {}

Expand All @@ -487,7 +500,7 @@ class argmin_aggregation final : public rolling_aggregation {
/**
* @brief Derived class for specifying a nunique aggregation
*/
class nunique_aggregation final : public aggregation {
class nunique_aggregation final : public groupby_aggregation {
public:
nunique_aggregation(null_policy null_handling)
: aggregation{NUNIQUE}, _null_handling{null_handling}
Expand Down Expand Up @@ -523,7 +536,7 @@ class nunique_aggregation final : public aggregation {
/**
* @brief Derived class for specifying a nth element aggregation
*/
class nth_element_aggregation final : public aggregation {
class nth_element_aggregation final : public groupby_aggregation {
public:
nth_element_aggregation(size_type n, null_policy null_handling)
: aggregation{NTH_ELEMENT}, _n{n}, _null_handling{null_handling}
Expand Down Expand Up @@ -582,7 +595,7 @@ class row_number_aggregation final : public rolling_aggregation {
/**
* @brief Derived class for specifying a rank aggregation
*/
class rank_aggregation final : public rolling_aggregation {
class rank_aggregation final : public rolling_aggregation, public groupby_scan_aggregation {
public:
rank_aggregation() : aggregation{RANK} {}

Expand All @@ -601,7 +614,7 @@ class rank_aggregation final : public rolling_aggregation {
/**
* @brief Derived class for specifying a dense rank aggregation
*/
class dense_rank_aggregation final : public rolling_aggregation {
class dense_rank_aggregation final : public rolling_aggregation, public groupby_scan_aggregation {
public:
dense_rank_aggregation() : aggregation{DENSE_RANK} {}

Expand All @@ -620,7 +633,7 @@ class dense_rank_aggregation final : public rolling_aggregation {
/**
* @brief Derived aggregation class for specifying COLLECT_LIST aggregation
*/
class collect_list_aggregation final : public rolling_aggregation {
class collect_list_aggregation final : public rolling_aggregation, public groupby_aggregation {
public:
explicit collect_list_aggregation(null_policy null_handling = null_policy::INCLUDE)
: aggregation{COLLECT_LIST}, _null_handling{null_handling}
Expand Down Expand Up @@ -656,7 +669,7 @@ class collect_list_aggregation final : public rolling_aggregation {
/**
* @brief Derived aggregation class for specifying COLLECT_SET aggregation
*/
class collect_set_aggregation final : public rolling_aggregation {
class collect_set_aggregation final : public rolling_aggregation, public groupby_aggregation {
public:
explicit collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
null_equality nulls_equal = null_equality::EQUAL,
Expand Down Expand Up @@ -795,7 +808,7 @@ class udf_aggregation final : public rolling_aggregation {
/**
* @brief Derived aggregation class for specifying MERGE_LISTS aggregation
*/
class merge_lists_aggregation final : public aggregation {
class merge_lists_aggregation final : public groupby_aggregation {
public:
explicit merge_lists_aggregation() : aggregation{MERGE_LISTS} {}

Expand All @@ -814,7 +827,7 @@ class merge_lists_aggregation final : public aggregation {
/**
* @brief Derived aggregation class for specifying MERGE_SETS aggregation
*/
class merge_sets_aggregation final : public aggregation {
class merge_sets_aggregation final : public groupby_aggregation {
public:
explicit merge_sets_aggregation(null_equality nulls_equal, nan_equality nans_equal)
: aggregation{MERGE_SETS}, _nulls_equal(nulls_equal), _nans_equal(nans_equal)
Expand Down Expand Up @@ -855,7 +868,7 @@ class merge_sets_aggregation final : public aggregation {
/**
* @brief Derived aggregation class for specifying MERGE_M2 aggregation
*/
class merge_m2_aggregation final : public aggregation {
class merge_m2_aggregation final : public groupby_aggregation {
public:
explicit merge_m2_aggregation() : aggregation{MERGE_M2} {}

Expand Down
23 changes: 19 additions & 4 deletions cpp/include/cudf/groupby.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,23 @@ class sort_groupby_helper;
* `values.size()` column must equal `keys.num_rows()`.
*/
struct aggregation_request {
column_view values; ///< The elements to aggregate
std::vector<std::unique_ptr<aggregation>> aggregations; ///< Desired aggregations
column_view values; ///< The elements to aggregate
std::vector<std::unique_ptr<groupby_aggregation>> aggregations; ///< Desired aggregations
};

/**
* @brief Request for groupby scan aggregation(s) on a column.
*
* This is the request type accepted by `groupby::scan()`. Its `aggregations`
* are restricted to `groupby_scan_aggregation`, so aggregations that only
* derive from `groupby_aggregation` cannot be placed in a scan request.
*
* The group membership of each `values[i]` is determined by the corresponding
* row `i` in the original order of `keys` used to construct the
* `groupby`. I.e., for each `aggregation`, `values[i]` is aggregated with all
* other `values[j]` where rows `i` and `j` in `keys` are equivalent.
*
* `values.size()` must equal `keys.num_rows()`.
*/
struct scan_request {
column_view values; ///< The elements to scan, one per row of `keys`
std::vector<std::unique_ptr<groupby_scan_aggregation>> aggregations; ///< Desired scan aggregations
};

/**
Expand Down Expand Up @@ -222,7 +237,7 @@ class groupby {
* specified in `requests`.
*/
std::pair<std::unique_ptr<table>, std::vector<aggregation_result>> scan(
host_span<aggregation_request const> requests,
host_span<scan_request const> requests,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -388,7 +403,7 @@ class groupby {
rmm::mr::device_memory_resource* mr);

std::pair<std::unique_ptr<table>, std::vector<aggregation_result>> sort_scan(
host_span<aggregation_request const> requests,
host_span<scan_request const> requests,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);
};
Expand Down
Loading

0 comments on commit 94c659c

Please sign in to comment.