Skip to content

Commit

Permalink
Fix and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
codereport committed Sep 24, 2021
1 parent c431650 commit 8f3dfd6
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
10 changes: 6 additions & 4 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -954,14 +954,16 @@ template <typename Source, aggregation::Kind k>
struct target_type_impl<
Source,
k,
std::enable_if_t<is_fixed_width<Source>() && !is_chrono<Source>() && (k == aggregation::MEAN)>> {
std::enable_if_t<is_fixed_width<Source>() && not is_chrono<Source>() &&
not is_fixed_point<Source>() && (k == aggregation::MEAN)>> {
using type = double;
};

template <typename Source, aggregation::Kind k>
struct target_type_impl<Source,
k,
std::enable_if_t<is_chrono<Source>() && (k == aggregation::MEAN)>> {
struct target_type_impl<
Source,
k,
std::enable_if_t<(is_chrono<Source>() or is_fixed_point<Source>()) && (k == aggregation::MEAN)>> {
using type = Source;
};

Expand Down
52 changes: 52 additions & 0 deletions cpp/tests/groupby/mean_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,57 @@ TEST_F(groupby_dictionary_mean_test, basic)
keys, vals, expect_keys, expect_vals, cudf::make_mean_aggregation<groupby_aggregation>());
}

template <typename T>
struct FixedPointTestBothReps : public cudf::test::BaseFixture {
};

TYPED_TEST_CASE(FixedPointTestBothReps, cudf::test::FixedPointTypes);

TYPED_TEST(FixedPointTestBothReps, GroupBySortMeanDecimalAsValue)
{
using namespace numeric;
using decimalXX = TypeParam;
using RepType = cudf::device_storage_type_t<decimalXX>;
using fp_wrapper = cudf::test::fixed_point_column_wrapper<RepType>;

for (auto const i : {2, 1, 0, -1, -2}) {
auto const scale = scale_type{i};
// clang-format off
auto const keys = fixed_width_column_wrapper<K>{1, 2, 3, 1, 2, 2, 1, 3, 3, 2};
auto const vals = fp_wrapper{ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, scale};
// clang-format on

auto const expect_keys = fixed_width_column_wrapper<K>{1, 2, 3};
auto const expect_vals_min = fp_wrapper{{3, 4, 5}, scale};

auto agg = cudf::make_mean_aggregation<cudf::groupby_aggregation>();
test_single_agg(
keys, vals, expect_keys, expect_vals_min, std::move(agg), force_use_sort_impl::YES);
}
}

TYPED_TEST(FixedPointTestBothReps, GroupByHashMeanDecimalAsValue)
{
using namespace numeric;
using decimalXX = TypeParam;
using RepType = cudf::device_storage_type_t<decimalXX>;
using fp_wrapper = cudf::test::fixed_point_column_wrapper<RepType>;
using K = int32_t;

for (auto const i : {2, 1, 0, -1, -2}) {
auto const scale = scale_type{i};
// clang-format off
auto const keys = fixed_width_column_wrapper<K>{1, 2, 3, 1, 2, 2, 1, 3, 3, 2};
auto const vals = fp_wrapper{ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, scale};
// clang-format on

auto const expect_keys = fixed_width_column_wrapper<K>{1, 2, 3};
auto const expect_vals_min = fp_wrapper{{3, 4, 5}, scale};

auto agg = cudf::make_mean_aggregation<cudf::groupby_aggregation>();
test_single_agg(keys, vals, expect_keys, expect_vals_min, std::move(agg));
}
}

} // namespace test
} // namespace cudf

0 comments on commit 8f3dfd6

Please sign in to comment.