Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SUM/MEAN aggregation type support. #12503

Merged
Merged
6 changes: 3 additions & 3 deletions cpp/include/cudf/detail/aggregation/aggregation.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -232,8 +232,8 @@ struct update_target_element<
aggregation::SUM,
target_has_nulls,
source_has_nulls,
std::enable_if_t<is_fixed_width<Source>() && cudf::has_atomic_support<Source>() &&
!is_fixed_point<Source>()>> {
std::enable_if_t<cudf::is_fixed_width<Source>() && cudf::has_atomic_support<Source>() &&
!cudf::is_fixed_point<Source>() && !cudf::is_timestamp<Source>()>> {
__device__ void operator()(mutable_column_device_view target,
size_type target_index,
column_device_view source,
Expand Down
21 changes: 10 additions & 11 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -1154,9 +1154,7 @@ struct target_type_impl<Source, aggregation::ALL> {
using type = bool;
};

// Always use `double` for MEAN
// Except for chrono types where result is chrono. (Use FloorDiv)
// TODO: MEAN should be only be enabled for duration types - not for timestamps
// Always use `double` for MEAN except for durations and fixed point types.
template <typename Source, aggregation::Kind k>
struct target_type_impl<
Source,
Expand All @@ -1167,10 +1165,10 @@ struct target_type_impl<
};

template <typename Source, aggregation::Kind k>
struct target_type_impl<
Source,
k,
std::enable_if_t<(is_chrono<Source>() or is_fixed_point<Source>()) && (k == aggregation::MEAN)>> {
struct target_type_impl<Source,
k,
std::enable_if_t<(is_duration<Source>() or is_fixed_point<Source>()) &&
(k == aggregation::MEAN)>> {
using type = Source;
};

Expand Down Expand Up @@ -1206,10 +1204,11 @@ struct target_type_impl<
using type = Source;
};

// Summing/Multiplying chrono types, use same type accumulator
// TODO: Sum/Product should only be enabled for duration types - not for timestamps
// Summing duration types, use same type accumulator
template <typename Source, aggregation::Kind k>
struct target_type_impl<Source, k, std::enable_if_t<is_chrono<Source>() && is_sum_product_agg(k)>> {
struct target_type_impl<Source,
k,
std::enable_if_t<is_duration<Source>() && (k == aggregation::SUM)>> {
using type = Source;
};

Expand Down
11 changes: 2 additions & 9 deletions cpp/src/rolling/detail/rolling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,9 @@ struct DeviceRolling {
static constexpr bool is_supported()
{
return cudf::detail::is_valid_aggregation<T, O>() && has_corresponding_operator<O>() &&
// TODO: Delete all this extra logic once is_valid_aggregation<> cleans up some edge
// cases it isn't handling.
// MIN/MAX supports all fixed width types
// MIN/MAX only supports fixed width types
(((O == aggregation::MIN || O == aggregation::MAX) && cudf::is_fixed_width<T>()) ||

// SUM supports all fixed width types except timestamps
((O == aggregation::SUM) && (cudf::is_fixed_width<T>() && !cudf::is_timestamp<T>())) ||

// MEAN supports numeric and duration
((O == aggregation::MEAN) && (cudf::is_numeric<T>() || cudf::is_duration<T>())));
(O == aggregation::SUM) || (O == aggregation::MEAN));
}

// operations we do support
Expand Down
20 changes: 17 additions & 3 deletions cpp/tests/rolling/empty_input_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,22 +183,36 @@ TYPED_TEST(TypedRollingEmptyInputTest, EmptyFixedWidthInputs)

/// `SUM` returns 64-bit promoted types for integral/decimal input.
/// For other fixed-width input types, the same type is returned.
/// Timestamp types are not supported.
{
auto aggs = agg_vector_t{};
aggs.emplace_back(sum());

using expected_type = cudf::detail::target_type_t<InputType, cudf::aggregation::SUM>;
rolling_output_type_matches(empty_input, aggs, cudf::type_to_id<expected_type>());
if constexpr (cudf::is_timestamp<InputType>()) {
EXPECT_THROW(
rolling_output_type_matches(empty_input, aggs, cudf::type_to_id<expected_type>()),
cudf::logic_error);
} else {
rolling_output_type_matches(empty_input, aggs, cudf::type_to_id<expected_type>());
}
}

/// `MEAN` returns float64 for all numeric types,
/// except for chrono-types, which yield the same chrono-type.
/// except for duration-types, which yield the same duration-type.
/// Timestamp types are not supported.
{
auto aggs = agg_vector_t{};
aggs.emplace_back(mean());

using expected_type = cudf::detail::target_type_t<InputType, cudf::aggregation::MEAN>;
rolling_output_type_matches(empty_input, aggs, cudf::type_to_id<expected_type>());
if constexpr (cudf::is_timestamp<InputType>()) {
EXPECT_THROW(
rolling_output_type_matches(empty_input, aggs, cudf::type_to_id<expected_type>()),
cudf::logic_error);
} else {
rolling_output_type_matches(empty_input, aggs, cudf::type_to_id<expected_type>());
}
}

/// For an input type `T`, `COLLECT_LIST` returns a column of type `list<T>`.
Expand Down