Skip to content

Commit

Permalink
Call the specialized functions for struct type values
Browse files Browse the repository at this point in the history
  • Loading branch information
ttnghia committed Nov 4, 2021
1 parent d4d4644 commit ad7998f
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 26 deletions.
48 changes: 32 additions & 16 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,29 +117,45 @@ void aggregate_result_functor::operator()<aggregation::ARGMAX>(aggregation const
{
if (cache.has_result(values, agg)) return;

cache.add_result(values,
agg,
detail::group_argmax(get_grouped_values(),
helper.num_groups(stream),
helper.group_labels(stream),
helper.key_sort_order(stream),
stream,
mr));
auto result = values.type().id() == type_id::STRUCT
? detail::group_argminmax_struct(aggregation::ARGMAX,
get_grouped_values(),
helper.num_groups(stream),
helper.group_labels(stream),
helper.key_sort_order(stream),
stream,
mr)
: detail::group_argmax(get_grouped_values(),
helper.num_groups(stream),
helper.group_labels(stream),
helper.key_sort_order(stream),
stream,
mr);

cache.add_result(values, agg, std::move(result));
};

template <>
void aggregate_result_functor::operator()<aggregation::ARGMIN>(aggregation const& agg)
{
if (cache.has_result(values, agg)) return;

cache.add_result(values,
agg,
detail::group_argmin(get_grouped_values(),
helper.num_groups(stream),
helper.group_labels(stream),
helper.key_sort_order(stream),
stream,
mr));
auto result = values.type().id() == type_id::STRUCT
? detail::group_argminmax_struct(aggregation::ARGMIN,
get_grouped_values(),
helper.num_groups(stream),
helper.group_labels(stream),
helper.key_sort_order(stream),
stream,
mr)
: detail::group_argmin(get_grouped_values(),
helper.num_groups(stream),
helper.group_labels(stream),
helper.key_sort_order(stream),
stream,
mr);

cache.add_result(values, agg, std::move(result));
};

template <>
Expand Down
34 changes: 24 additions & 10 deletions cpp/src/groupby/sort/scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,23 +81,37 @@ void scan_result_functor::operator()<aggregation::MIN>(aggregation const& agg)
{
if (cache.has_result(values, agg)) return;

cache.add_result(
values,
agg,
detail::min_scan(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr));
auto result =
values.type().id() == type_id::STRUCT
? detail::minmax_scan_struct(aggregation::MIN,
get_grouped_values(),
helper.num_groups(stream),
helper.group_labels(stream),
stream,
mr)
: detail::min_scan(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr);

cache.add_result(values, agg, std::move(result));
}

template <>
void scan_result_functor::operator()<aggregation::MAX>(aggregation const& agg)
{
if (cache.has_result(values, agg)) return;

cache.add_result(
values,
agg,
detail::max_scan(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr));
auto result =
values.type().id() == type_id::STRUCT
? detail::minmax_scan_struct(aggregation::MAX,
get_grouped_values(),
helper.num_groups(stream),
helper.group_labels(stream),
stream,
mr)
: detail::max_scan(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr);

cache.add_result(values, agg, std::move(result));
}

template <>
Expand Down

0 comments on commit ad7998f

Please sign in to comment.