From b1ea304ec7d421c871061fdfab7f2e097fbd4a0f Mon Sep 17 00:00:00 2001 From: nvdbaranec <56695930+nvdbaranec@users.noreply.github.com> Date: Fri, 11 Mar 2022 16:54:33 -0600 Subject: [PATCH] Add scan_aggregation and reduce_aggregation derived types. (#10357) This PR adds the `scan_aggregation` and `reduce_aggregation` derived types. With it, all concrete aggregation types are now derived from algorithm-specific subtypes. Authors: - https://github.com/nvdbaranec Approvers: - Robert (Bobby) Evans (https://github.com/revans2) - Ram (Ramakrishna Prabhu) (https://github.com/rgsl888prabhu) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/10357 --- cpp/include/cudf/aggregation.hpp | 22 + .../cudf/detail/aggregation/aggregation.hpp | 57 +- cpp/include/cudf/detail/scan.hpp | 4 +- cpp/include/cudf/reduction.hpp | 4 +- cpp/src/aggregation/aggregation.cpp | 33 + cpp/src/reductions/reductions.cpp | 6 +- cpp/src/reductions/scan/scan.cpp | 2 +- cpp/src/reductions/scan/scan.cuh | 4 +- cpp/src/reductions/scan/scan_exclusive.cu | 2 +- cpp/src/reductions/scan/scan_inclusive.cu | 2 +- cpp/tests/reductions/collect_ops_tests.cpp | 109 +-- cpp/tests/reductions/rank_tests.cpp | 8 +- cpp/tests/reductions/reduction_tests.cpp | 682 +++++++++++------- cpp/tests/reductions/scan_tests.cpp | 303 +++++--- java/src/main/native/src/ColumnViewJni.cpp | 12 +- python/cudf/cudf/_lib/aggregation.pxd | 17 +- python/cudf/cudf/_lib/aggregation.pyx | 572 +++++++-------- python/cudf/cudf/_lib/cpp/aggregation.pxd | 6 + python/cudf/cudf/_lib/cpp/reduce.pxd | 8 +- python/cudf/cudf/_lib/reduce.pyx | 14 +- 20 files changed, 1103 insertions(+), 764 deletions(-) diff --git a/cpp/include/cudf/aggregation.hpp b/cpp/include/cudf/aggregation.hpp index 2a4b86588cb..539a7c04106 100644 --- a/cpp/include/cudf/aggregation.hpp +++ b/cpp/include/cudf/aggregation.hpp @@ -148,6 +148,28 @@ class groupby_scan_aggregation : public virtual aggregation { groupby_scan_aggregation() {} }; +/** + * 
@brief Derived class intended for reduction usage. + */ +class reduce_aggregation : public virtual aggregation { + public: + ~reduce_aggregation() override = default; + + protected: + reduce_aggregation() {} +}; + +/** + * @brief Derived class intended for scan usage. + */ +class scan_aggregation : public virtual aggregation { + public: + ~scan_aggregation() override = default; + + protected: + scan_aggregation() {} +}; + /** * @brief Derived class intended for segmented reduction usage. */ diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp index cb30a54fe84..eba24dd2d13 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.hpp +++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp @@ -148,6 +148,8 @@ class aggregation_finalizer { // Declares the interface for the finalizer class sum_aggregation final : public rolling_aggregation, public groupby_aggregation, public groupby_scan_aggregation, + public reduce_aggregation, + public scan_aggregation, public segmented_reduce_aggregation { public: sum_aggregation() : aggregation(SUM) {} @@ -167,7 +169,10 @@ class sum_aggregation final : public rolling_aggregation, /** * @brief Derived class for specifying a product aggregation */ -class product_aggregation final : public groupby_aggregation, public segmented_reduce_aggregation { +class product_aggregation final : public groupby_aggregation, + public reduce_aggregation, + public scan_aggregation, + public segmented_reduce_aggregation { public: product_aggregation() : aggregation(PRODUCT) {} @@ -189,6 +194,8 @@ class product_aggregation final : public groupby_aggregation, public segmented_r class min_aggregation final : public rolling_aggregation, public groupby_aggregation, public groupby_scan_aggregation, + public reduce_aggregation, + public scan_aggregation, public segmented_reduce_aggregation { public: min_aggregation() : aggregation(MIN) {} @@ -211,6 +218,8 @@ class min_aggregation final : 
public rolling_aggregation, class max_aggregation final : public rolling_aggregation, public groupby_aggregation, public groupby_scan_aggregation, + public reduce_aggregation, + public scan_aggregation, public segmented_reduce_aggregation { public: max_aggregation() : aggregation(MAX) {} @@ -251,7 +260,7 @@ class count_aggregation final : public rolling_aggregation, /** * @brief Derived class for specifying an any aggregation */ -class any_aggregation final : public segmented_reduce_aggregation { +class any_aggregation final : public reduce_aggregation, public segmented_reduce_aggregation { public: any_aggregation() : aggregation(ANY) {} @@ -270,7 +279,7 @@ class any_aggregation final : public segmented_reduce_aggregation { /** * @brief Derived class for specifying an all aggregation */ -class all_aggregation final : public segmented_reduce_aggregation { +class all_aggregation final : public reduce_aggregation, public segmented_reduce_aggregation { public: all_aggregation() : aggregation(ALL) {} @@ -289,7 +298,7 @@ class all_aggregation final : public segmented_reduce_aggregation { /** * @brief Derived class for specifying a sum_of_squares aggregation */ -class sum_of_squares_aggregation final : public groupby_aggregation { +class sum_of_squares_aggregation final : public groupby_aggregation, public reduce_aggregation { public: sum_of_squares_aggregation() : aggregation(SUM_OF_SQUARES) {} @@ -308,7 +317,9 @@ class sum_of_squares_aggregation final : public groupby_aggregation { /** * @brief Derived class for specifying a mean aggregation */ -class mean_aggregation final : public rolling_aggregation, public groupby_aggregation { +class mean_aggregation final : public rolling_aggregation, + public groupby_aggregation, + public reduce_aggregation { public: mean_aggregation() : aggregation(MEAN) {} @@ -346,7 +357,9 @@ class m2_aggregation : public groupby_aggregation { /** * @brief Derived class for specifying a standard deviation/variance aggregation */ -class 
std_var_aggregation : public rolling_aggregation, public groupby_aggregation { +class std_var_aggregation : public rolling_aggregation, + public groupby_aggregation, + public reduce_aggregation { public: size_type _ddof; ///< Delta degrees of freedom @@ -418,7 +431,7 @@ class std_aggregation final : public std_var_aggregation { /** * @brief Derived class for specifying a median aggregation */ -class median_aggregation final : public groupby_aggregation { +class median_aggregation final : public groupby_aggregation, public reduce_aggregation { public: median_aggregation() : aggregation(MEDIAN) {} @@ -437,7 +450,7 @@ class median_aggregation final : public groupby_aggregation { /** * @brief Derived class for specifying a quantile aggregation */ -class quantile_aggregation final : public groupby_aggregation { +class quantile_aggregation final : public groupby_aggregation, public reduce_aggregation { public: quantile_aggregation(std::vector const& q, interpolation i) : aggregation{QUANTILE}, _quantiles{q}, _interpolation{i} @@ -524,7 +537,7 @@ class argmin_aggregation final : public rolling_aggregation, public groupby_aggr /** * @brief Derived class for specifying a nunique aggregation */ -class nunique_aggregation final : public groupby_aggregation { +class nunique_aggregation final : public groupby_aggregation, public reduce_aggregation { public: nunique_aggregation(null_policy null_handling) : aggregation{NUNIQUE}, _null_handling{null_handling} @@ -563,7 +576,7 @@ class nunique_aggregation final : public groupby_aggregation { /** * @brief Derived class for specifying a nth element aggregation */ -class nth_element_aggregation final : public groupby_aggregation { +class nth_element_aggregation final : public groupby_aggregation, public reduce_aggregation { public: nth_element_aggregation(size_type n, null_policy null_handling) : aggregation{NTH_ELEMENT}, _n{n}, _null_handling{null_handling} @@ -625,7 +638,9 @@ class row_number_aggregation final : public 
rolling_aggregation { /** * @brief Derived class for specifying a rank aggregation */ -class rank_aggregation final : public rolling_aggregation, public groupby_scan_aggregation { +class rank_aggregation final : public rolling_aggregation, + public groupby_scan_aggregation, + public scan_aggregation { public: rank_aggregation() : aggregation{RANK} {} @@ -644,7 +659,9 @@ class rank_aggregation final : public rolling_aggregation, public groupby_scan_a /** * @brief Derived class for specifying a dense rank aggregation */ -class dense_rank_aggregation final : public rolling_aggregation, public groupby_scan_aggregation { +class dense_rank_aggregation final : public rolling_aggregation, + public groupby_scan_aggregation, + public scan_aggregation { public: dense_rank_aggregation() : aggregation{DENSE_RANK} {} @@ -660,7 +677,9 @@ class dense_rank_aggregation final : public rolling_aggregation, public groupby_ void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } }; -class percent_rank_aggregation final : public rolling_aggregation, public groupby_scan_aggregation { +class percent_rank_aggregation final : public rolling_aggregation, + public groupby_scan_aggregation, + public scan_aggregation { public: percent_rank_aggregation() : aggregation{PERCENT_RANK} {} @@ -679,7 +698,9 @@ class percent_rank_aggregation final : public rolling_aggregation, public groupb /** * @brief Derived aggregation class for specifying COLLECT_LIST aggregation */ -class collect_list_aggregation final : public rolling_aggregation, public groupby_aggregation { +class collect_list_aggregation final : public rolling_aggregation, + public groupby_aggregation, + public reduce_aggregation { public: explicit collect_list_aggregation(null_policy null_handling = null_policy::INCLUDE) : aggregation{COLLECT_LIST}, _null_handling{null_handling} @@ -718,7 +739,9 @@ class collect_list_aggregation final : public rolling_aggregation, public groupb /** * @brief Derived 
aggregation class for specifying COLLECT_SET aggregation */ -class collect_set_aggregation final : public rolling_aggregation, public groupby_aggregation { +class collect_set_aggregation final : public rolling_aggregation, + public groupby_aggregation, + public reduce_aggregation { public: explicit collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE, null_equality nulls_equal = null_equality::EQUAL, @@ -866,7 +889,7 @@ class udf_aggregation final : public rolling_aggregation { /** * @brief Derived aggregation class for specifying MERGE_LISTS aggregation */ -class merge_lists_aggregation final : public groupby_aggregation { +class merge_lists_aggregation final : public groupby_aggregation, public reduce_aggregation { public: explicit merge_lists_aggregation() : aggregation{MERGE_LISTS} {} @@ -885,7 +908,7 @@ class merge_lists_aggregation final : public groupby_aggregation { /** * @brief Derived aggregation class for specifying MERGE_SETS aggregation */ -class merge_sets_aggregation final : public groupby_aggregation { +class merge_sets_aggregation final : public groupby_aggregation, public reduce_aggregation { public: explicit merge_sets_aggregation(null_equality nulls_equal, nan_equality nans_equal) : aggregation{MERGE_SETS}, _nulls_equal(nulls_equal), _nans_equal(nans_equal) diff --git a/cpp/include/cudf/detail/scan.hpp b/cpp/include/cudf/detail/scan.hpp index 36dce6caf0b..fc829617c2d 100644 --- a/cpp/include/cudf/detail/scan.hpp +++ b/cpp/include/cudf/detail/scan.hpp @@ -47,7 +47,7 @@ namespace detail { * @returns Column with scan results. */ std::unique_ptr scan_exclusive(column_view const& input, - std::unique_ptr const& agg, + std::unique_ptr const& agg, null_policy null_handling, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr); @@ -73,7 +73,7 @@ std::unique_ptr scan_exclusive(column_view const& input, * @returns Column with scan results. 
*/ std::unique_ptr scan_inclusive(column_view const& input, - std::unique_ptr const& agg, + std::unique_ptr const& agg, null_policy null_handling, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr); diff --git a/cpp/include/cudf/reduction.hpp b/cpp/include/cudf/reduction.hpp index 367814cda8e..f140ba7d4a9 100644 --- a/cpp/include/cudf/reduction.hpp +++ b/cpp/include/cudf/reduction.hpp @@ -68,7 +68,7 @@ enum class scan_type : bool { INCLUSIVE, EXCLUSIVE }; */ std::unique_ptr reduce( column_view const& col, - std::unique_ptr const& agg, + std::unique_ptr const& agg, data_type output_dtype, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -142,7 +142,7 @@ std::unique_ptr segmented_reduce( */ std::unique_ptr scan( const column_view& input, - std::unique_ptr const& agg, + std::unique_ptr const& agg, scan_type inclusive, null_policy null_handling = null_policy::EXCLUDE, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); diff --git a/cpp/src/aggregation/aggregation.cpp b/cpp/src/aggregation/aggregation.cpp index 68fe6d86103..9405b4c37ac 100644 --- a/cpp/src/aggregation/aggregation.cpp +++ b/cpp/src/aggregation/aggregation.cpp @@ -417,6 +417,8 @@ template std::unique_ptr make_sum_aggregation(); template std::unique_ptr make_sum_aggregation(); template std::unique_ptr make_sum_aggregation(); template std::unique_ptr make_sum_aggregation(); +template std::unique_ptr make_sum_aggregation(); +template std::unique_ptr make_sum_aggregation(); template std::unique_ptr make_sum_aggregation(); @@ -428,6 +430,8 @@ std::unique_ptr make_product_aggregation() } template std::unique_ptr make_product_aggregation(); template std::unique_ptr make_product_aggregation(); +template std::unique_ptr make_product_aggregation(); +template std::unique_ptr make_product_aggregation(); template std::unique_ptr make_product_aggregation(); @@ -441,6 +445,8 @@ template std::unique_ptr make_min_aggregation(); template 
std::unique_ptr make_min_aggregation(); template std::unique_ptr make_min_aggregation(); template std::unique_ptr make_min_aggregation(); +template std::unique_ptr make_min_aggregation(); +template std::unique_ptr make_min_aggregation(); template std::unique_ptr make_min_aggregation(); @@ -454,6 +460,8 @@ template std::unique_ptr make_max_aggregation(); template std::unique_ptr make_max_aggregation(); template std::unique_ptr make_max_aggregation(); template std::unique_ptr make_max_aggregation(); +template std::unique_ptr make_max_aggregation(); +template std::unique_ptr make_max_aggregation(); template std::unique_ptr make_max_aggregation(); @@ -481,6 +489,7 @@ std::unique_ptr make_any_aggregation() return std::make_unique(); } template std::unique_ptr make_any_aggregation(); +template std::unique_ptr make_any_aggregation(); template std::unique_ptr make_any_aggregation(); @@ -491,6 +500,7 @@ std::unique_ptr make_all_aggregation() return std::make_unique(); } template std::unique_ptr make_all_aggregation(); +template std::unique_ptr make_all_aggregation(); template std::unique_ptr make_all_aggregation(); @@ -503,6 +513,7 @@ std::unique_ptr make_sum_of_squares_aggregation() template std::unique_ptr make_sum_of_squares_aggregation(); template std::unique_ptr make_sum_of_squares_aggregation(); +template std::unique_ptr make_sum_of_squares_aggregation(); /// Factory to create a MEAN aggregation template @@ -513,6 +524,7 @@ std::unique_ptr make_mean_aggregation() template std::unique_ptr make_mean_aggregation(); template std::unique_ptr make_mean_aggregation(); template std::unique_ptr make_mean_aggregation(); +template std::unique_ptr make_mean_aggregation(); /// Factory to create a M2 aggregation template @@ -534,6 +546,8 @@ template std::unique_ptr make_variance_aggregation make_variance_aggregation( size_type ddof); +template std::unique_ptr make_variance_aggregation( + size_type ddof); /// Factory to create a STD aggregation template @@ -546,6 +560,8 @@ template 
std::unique_ptr make_std_aggregation make_std_aggregation( size_type ddof); +template std::unique_ptr make_std_aggregation( + size_type ddof); /// Factory to create a MEDIAN aggregation template @@ -555,6 +571,7 @@ std::unique_ptr make_median_aggregation() } template std::unique_ptr make_median_aggregation(); template std::unique_ptr make_median_aggregation(); +template std::unique_ptr make_median_aggregation(); /// Factory to create a QUANTILE aggregation template @@ -567,6 +584,8 @@ template std::unique_ptr make_quantile_aggregation( std::vector const& quantiles, interpolation interp); template std::unique_ptr make_quantile_aggregation( std::vector const& quantiles, interpolation interp); +template std::unique_ptr make_quantile_aggregation( + std::vector const& quantiles, interpolation interp); /// Factory to create an ARGMAX aggregation template @@ -598,6 +617,8 @@ template std::unique_ptr make_nunique_aggregation( null_policy null_handling); template std::unique_ptr make_nunique_aggregation( null_policy null_handling); +template std::unique_ptr make_nunique_aggregation( + null_policy null_handling); /// Factory to create an NTH_ELEMENT aggregation template @@ -609,6 +630,8 @@ template std::unique_ptr make_nth_element_aggregation( size_type n, null_policy null_handling); template std::unique_ptr make_nth_element_aggregation( size_type n, null_policy null_handling); +template std::unique_ptr make_nth_element_aggregation( + size_type n, null_policy null_handling); /// Factory to create a ROW_NUMBER aggregation template @@ -628,6 +651,7 @@ std::unique_ptr make_rank_aggregation() template std::unique_ptr make_rank_aggregation(); template std::unique_ptr make_rank_aggregation(); +template std::unique_ptr make_rank_aggregation(); /// Factory to create a DENSE_RANK aggregation template @@ -638,6 +662,7 @@ std::unique_ptr make_dense_rank_aggregation() template std::unique_ptr make_dense_rank_aggregation(); template std::unique_ptr make_dense_rank_aggregation(); 
+template std::unique_ptr make_dense_rank_aggregation(); /// Factory to create a PERCENT_RANK aggregation template @@ -648,6 +673,7 @@ std::unique_ptr make_percent_rank_aggregation() template std::unique_ptr make_percent_rank_aggregation(); template std::unique_ptr make_percent_rank_aggregation(); +template std::unique_ptr make_percent_rank_aggregation(); /// Factory to create a COLLECT_LIST aggregation template @@ -661,6 +687,8 @@ template std::unique_ptr make_collect_list_aggregation make_collect_list_aggregation( null_policy null_handling); +template std::unique_ptr make_collect_list_aggregation( + null_policy null_handling); /// Factory to create a COLLECT_SET aggregation template @@ -676,6 +704,8 @@ template std::unique_ptr make_collect_set_aggregation make_collect_set_aggregation( null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal); +template std::unique_ptr make_collect_set_aggregation( + null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal); /// Factory to create a LAG aggregation template @@ -722,6 +752,7 @@ std::unique_ptr make_merge_lists_aggregation() } template std::unique_ptr make_merge_lists_aggregation(); template std::unique_ptr make_merge_lists_aggregation(); +template std::unique_ptr make_merge_lists_aggregation(); /// Factory to create a MERGE_SETS aggregation template @@ -734,6 +765,8 @@ template std::unique_ptr make_merge_sets_aggregation(n nan_equality); template std::unique_ptr make_merge_sets_aggregation( null_equality, nan_equality); +template std::unique_ptr make_merge_sets_aggregation( + null_equality, nan_equality); /// Factory to create a MERGE_M2 aggregation template diff --git a/cpp/src/reductions/reductions.cpp b/cpp/src/reductions/reductions.cpp index 13574f83d4e..bd8c8342708 100644 --- a/cpp/src/reductions/reductions.cpp +++ b/cpp/src/reductions/reductions.cpp @@ -45,7 +45,7 @@ struct reduce_dispatch_functor { } template - std::unique_ptr operator()(std::unique_ptr const& agg) + 
std::unique_ptr operator()(std::unique_ptr const& agg) { switch (k) { case aggregation::SUM: return reduction::sum(col, output_dtype, stream, mr); break; @@ -125,7 +125,7 @@ struct reduce_dispatch_functor { std::unique_ptr reduce( column_view const& col, - std::unique_ptr const& agg, + std::unique_ptr const& agg, data_type output_dtype, rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) @@ -145,7 +145,7 @@ std::unique_ptr reduce( } // namespace detail std::unique_ptr reduce(column_view const& col, - std::unique_ptr const& agg, + std::unique_ptr const& agg, data_type output_dtype, rmm::mr::device_memory_resource* mr) { diff --git a/cpp/src/reductions/scan/scan.cpp b/cpp/src/reductions/scan/scan.cpp index d73fc862380..52aaad5ddcf 100644 --- a/cpp/src/reductions/scan/scan.cpp +++ b/cpp/src/reductions/scan/scan.cpp @@ -25,7 +25,7 @@ namespace cudf { std::unique_ptr scan(column_view const& input, - std::unique_ptr const& agg, + std::unique_ptr const& agg, scan_type inclusive, null_policy null_handling, rmm::mr::device_memory_resource* mr) diff --git a/cpp/src/reductions/scan/scan.cuh b/cpp/src/reductions/scan/scan.cuh index 84387aba914..127f2ae95b4 100644 --- a/cpp/src/reductions/scan/scan.cuh +++ b/cpp/src/reductions/scan/scan.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ rmm::device_buffer mask_scan(column_view const& input_view, template