From ad7998fcdc4ee96d8ecc43cdb07f0e77071a4319 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Thu, 4 Nov 2021 13:30:18 -0600 Subject: [PATCH] Call the specialized functions for struct type values --- cpp/src/groupby/sort/aggregate.cpp | 48 ++++++++++++++++++++---------- cpp/src/groupby/sort/scan.cpp | 34 ++++++++++++++------- 2 files changed, 56 insertions(+), 26 deletions(-) diff --git a/cpp/src/groupby/sort/aggregate.cpp b/cpp/src/groupby/sort/aggregate.cpp index 83c6c1bca57..7bce5a3e0fa 100644 --- a/cpp/src/groupby/sort/aggregate.cpp +++ b/cpp/src/groupby/sort/aggregate.cpp @@ -117,14 +117,22 @@ void aggregate_result_functor::operator()(aggregation const { if (cache.has_result(values, agg)) return; - cache.add_result(values, - agg, - detail::group_argmax(get_grouped_values(), - helper.num_groups(stream), - helper.group_labels(stream), - helper.key_sort_order(stream), - stream, - mr)); + auto result = values.type().id() == type_id::STRUCT + ? detail::group_argminmax_struct(aggregation::ARGMAX, + get_grouped_values(), + helper.num_groups(stream), + helper.group_labels(stream), + helper.key_sort_order(stream), + stream, + mr) + : detail::group_argmax(get_grouped_values(), + helper.num_groups(stream), + helper.group_labels(stream), + helper.key_sort_order(stream), + stream, + mr); + + cache.add_result(values, agg, std::move(result)); }; template <> @@ -132,14 +140,22 @@ void aggregate_result_functor::operator()(aggregation const { if (cache.has_result(values, agg)) return; - cache.add_result(values, - agg, - detail::group_argmin(get_grouped_values(), - helper.num_groups(stream), - helper.group_labels(stream), - helper.key_sort_order(stream), - stream, - mr)); + auto result = values.type().id() == type_id::STRUCT + ? detail::group_argminmax_struct(aggregation::ARGMIN, + get_grouped_values(), + helper.num_groups(stream), + helper.group_labels(stream), + helper.key_sort_order(stream), + stream, + mr) + : detail::group_argmin(get_grouped_values(), + helper.num_groups(stream), + helper.group_labels(stream), + helper.key_sort_order(stream), + stream, + mr); + + cache.add_result(values, agg, std::move(result)); }; template <> diff --git a/cpp/src/groupby/sort/scan.cpp b/cpp/src/groupby/sort/scan.cpp index b22f82ce7e4..eed6bd52faf 100644 --- a/cpp/src/groupby/sort/scan.cpp +++ b/cpp/src/groupby/sort/scan.cpp @@ -81,11 +81,18 @@ void scan_result_functor::operator()(aggregation const& agg) { if (cache.has_result(values, agg)) return; - cache.add_result( - values, - agg, - detail::min_scan( - get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr)); + auto result = + values.type().id() == type_id::STRUCT + ? detail::minmax_scan_struct(aggregation::MIN, + get_grouped_values(), + helper.num_groups(stream), + helper.group_labels(stream), + stream, + mr) + : detail::min_scan( + get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr); + + cache.add_result(values, agg, std::move(result)); } template <> @@ -93,11 +100,18 @@ void scan_result_functor::operator()(aggregation const& agg) { if (cache.has_result(values, agg)) return; - cache.add_result( - values, - agg, - detail::max_scan( - get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr)); + auto result = + values.type().id() == type_id::STRUCT + ? detail::minmax_scan_struct(aggregation::MAX, + get_grouped_values(), + helper.num_groups(stream), + helper.group_labels(stream), + stream, + mr) + : detail::max_scan( + get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr); + + cache.add_result(values, agg, std::move(result)); } template <>