Skip to content

Commit

Permalink
Fix SUM/MEAN aggregation type support. (#12503)
Browse files Browse the repository at this point in the history
This PR closes #8399. We cleaned up the logic by fixing SUM/MEAN aggregation type support, which also eliminated `TODO` comments in the target type definitions.

We kept the restriction that rolling min/max requires fixed-width types because min/max aggregations do support non-fixed-width types in other aggregation implementations (groupby uses an argmin-and-gather approach on strings, for instance).

This PR is collaborative work with @karthikeyann.

Authors:
  - Bradley Dice (https://github.com/bdice)
  - Karthikeyan (https://github.com/karthikeyann)

Approvers:
  - Mark Harris (https://github.com/harrism)
  - David Wendt (https://github.com/davidwendt)

URL: #12503
  • Loading branch information
bdice authored Jan 24, 2023
1 parent 8e9ccc3 commit 2bcdb54
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 26 deletions.
6 changes: 3 additions & 3 deletions cpp/include/cudf/detail/aggregation/aggregation.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -232,8 +232,8 @@ struct update_target_element<
aggregation::SUM,
target_has_nulls,
source_has_nulls,
std::enable_if_t<is_fixed_width<Source>() && cudf::has_atomic_support<Source>() &&
!is_fixed_point<Source>()>> {
std::enable_if_t<cudf::is_fixed_width<Source>() && cudf::has_atomic_support<Source>() &&
!cudf::is_fixed_point<Source>() && !cudf::is_timestamp<Source>()>> {
__device__ void operator()(mutable_column_device_view target,
size_type target_index,
column_device_view source,
Expand Down
21 changes: 10 additions & 11 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -1154,9 +1154,7 @@ struct target_type_impl<Source, aggregation::ALL> {
using type = bool;
};

// Always use `double` for MEAN
// Except for chrono types where result is chrono. (Use FloorDiv)
// TODO: MEAN should be only be enabled for duration types - not for timestamps
// Always use `double` for MEAN except for durations and fixed point types.
template <typename Source, aggregation::Kind k>
struct target_type_impl<
Source,
Expand All @@ -1167,10 +1165,10 @@ struct target_type_impl<
};

template <typename Source, aggregation::Kind k>
struct target_type_impl<
Source,
k,
std::enable_if_t<(is_chrono<Source>() or is_fixed_point<Source>()) && (k == aggregation::MEAN)>> {
struct target_type_impl<Source,
k,
std::enable_if_t<(is_duration<Source>() or is_fixed_point<Source>()) &&
(k == aggregation::MEAN)>> {
using type = Source;
};

Expand Down Expand Up @@ -1206,10 +1204,11 @@ struct target_type_impl<
using type = Source;
};

// Summing/Multiplying chrono types, use same type accumulator
// TODO: Sum/Product should only be enabled for duration types - not for timestamps
// Summing duration types, use same type accumulator
template <typename Source, aggregation::Kind k>
struct target_type_impl<Source, k, std::enable_if_t<is_chrono<Source>() && is_sum_product_agg(k)>> {
struct target_type_impl<Source,
k,
std::enable_if_t<is_duration<Source>() && (k == aggregation::SUM)>> {
using type = Source;
};

Expand Down
11 changes: 2 additions & 9 deletions cpp/src/rolling/detail/rolling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,9 @@ struct DeviceRolling {
static constexpr bool is_supported()
{
return cudf::detail::is_valid_aggregation<T, O>() && has_corresponding_operator<O>() &&
// TODO: Delete all this extra logic once is_valid_aggregation<> cleans up some edge
// cases it isn't handling.
// MIN/MAX supports all fixed width types
// MIN/MAX only supports fixed width types
(((O == aggregation::MIN || O == aggregation::MAX) && cudf::is_fixed_width<T>()) ||

// SUM supports all fixed width types except timestamps
((O == aggregation::SUM) && (cudf::is_fixed_width<T>() && !cudf::is_timestamp<T>())) ||

// MEAN supports numeric and duration
((O == aggregation::MEAN) && (cudf::is_numeric<T>() || cudf::is_duration<T>())));
(O == aggregation::SUM) || (O == aggregation::MEAN));
}

// operations we do support
Expand Down
20 changes: 17 additions & 3 deletions cpp/tests/rolling/empty_input_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,22 +183,36 @@ TYPED_TEST(TypedRollingEmptyInputTest, EmptyFixedWidthInputs)

/// `SUM` returns 64-bit promoted types for integral/decimal input.
/// For other fixed-width input types, the same type is returned.
/// Timestamp types are not supported.
{
auto aggs = agg_vector_t{};
aggs.emplace_back(sum());

using expected_type = cudf::detail::target_type_t<InputType, cudf::aggregation::SUM>;
rolling_output_type_matches(empty_input, aggs, cudf::type_to_id<expected_type>());
if constexpr (cudf::is_timestamp<InputType>()) {
EXPECT_THROW(
rolling_output_type_matches(empty_input, aggs, cudf::type_to_id<expected_type>()),
cudf::logic_error);
} else {
rolling_output_type_matches(empty_input, aggs, cudf::type_to_id<expected_type>());
}
}

/// `MEAN` returns float64 for all numeric types,
/// except for chrono-types, which yield the same chrono-type.
/// except for duration-types, which yield the same duration-type.
/// Timestamp types are not supported.
{
auto aggs = agg_vector_t{};
aggs.emplace_back(mean());

using expected_type = cudf::detail::target_type_t<InputType, cudf::aggregation::MEAN>;
rolling_output_type_matches(empty_input, aggs, cudf::type_to_id<expected_type>());
if constexpr (cudf::is_timestamp<InputType>()) {
EXPECT_THROW(
rolling_output_type_matches(empty_input, aggs, cudf::type_to_id<expected_type>()),
cudf::logic_error);
} else {
rolling_output_type_matches(empty_input, aggs, cudf::type_to_id<expected_type>());
}
}

/// For an input type `T`, `COLLECT_LIST` returns a column of type `list<T>`.
Expand Down

0 comments on commit 2bcdb54

Please sign in to comment.