diff --git a/cpp/src/groupby/sort/aggregate.cpp b/cpp/src/groupby/sort/aggregate.cpp index 83c6c1bca57..7bce5a3e0fa 100644 --- a/cpp/src/groupby/sort/aggregate.cpp +++ b/cpp/src/groupby/sort/aggregate.cpp @@ -117,14 +117,22 @@ void aggregate_result_functor::operator()(aggregation const { if (cache.has_result(values, agg)) return; - cache.add_result(values, - agg, - detail::group_argmax(get_grouped_values(), - helper.num_groups(stream), - helper.group_labels(stream), - helper.key_sort_order(stream), - stream, - mr)); + auto result = values.type().id() == type_id::STRUCT + ? detail::group_argminmax_struct(aggregation::ARGMAX, + get_grouped_values(), + helper.num_groups(stream), + helper.group_labels(stream), + helper.key_sort_order(stream), + stream, + mr) + : detail::group_argmax(get_grouped_values(), + helper.num_groups(stream), + helper.group_labels(stream), + helper.key_sort_order(stream), + stream, + mr); + + cache.add_result(values, agg, std::move(result)); }; template <> @@ -132,14 +140,22 @@ void aggregate_result_functor::operator()(aggregation const { if (cache.has_result(values, agg)) return; - cache.add_result(values, - agg, - detail::group_argmin(get_grouped_values(), - helper.num_groups(stream), - helper.group_labels(stream), - helper.key_sort_order(stream), - stream, - mr)); + auto result = values.type().id() == type_id::STRUCT + ? detail::group_argminmax_struct(aggregation::ARGMIN, + get_grouped_values(), + helper.num_groups(stream), + helper.group_labels(stream), + helper.key_sort_order(stream), + stream, + mr) + : detail::group_argmin(get_grouped_values(), + helper.num_groups(stream), + helper.group_labels(stream), + helper.key_sort_order(stream), + stream, + mr); + + cache.add_result(values, agg, std::move(result)); }; template <> diff --git a/cpp/src/groupby/sort/scan.cpp b/cpp/src/groupby/sort/scan.cpp index b22f82ce7e4..eed6bd52faf 100644 --- a/cpp/src/groupby/sort/scan.cpp +++ b/cpp/src/groupby/sort/scan.cpp @@ -81,11 +81,18 @@ void scan_result_functor::operator()(aggregation const& agg) { if (cache.has_result(values, agg)) return; - cache.add_result( - values, - agg, - detail::min_scan( - get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr)); + auto result = + values.type().id() == type_id::STRUCT + ? detail::minmax_scan_struct(aggregation::MIN, + get_grouped_values(), + helper.num_groups(stream), + helper.group_labels(stream), + stream, + mr) + : detail::min_scan( + get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr); + + cache.add_result(values, agg, std::move(result)); } template <> @@ -93,11 +100,18 @@ void scan_result_functor::operator()(aggregation const& agg) { if (cache.has_result(values, agg)) return; - cache.add_result( - values, - agg, - detail::max_scan( - get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr)); + auto result = + values.type().id() == type_id::STRUCT + ? detail::minmax_scan_struct(aggregation::MAX, + get_grouped_values(), + helper.num_groups(stream), + helper.group_labels(stream), + stream, + mr) + : detail::max_scan( + get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr); + + cache.add_result(values, agg, std::move(result)); } template <>