From d22fb16935bfee0851bcd6470250b369239718e1 Mon Sep 17 00:00:00 2001 From: Conor Hoekstra Date: Tue, 2 Feb 2021 23:51:26 -0500 Subject: [PATCH] Final changes --- .../cudf/detail/aggregation/aggregation.cuh | 91 +++++++++++++++++-- cpp/tests/groupby/group_count_test.cpp | 6 +- 2 files changed, 83 insertions(+), 14 deletions(-) diff --git a/cpp/include/cudf/detail/aggregation/aggregation.cuh b/cpp/include/cudf/detail/aggregation/aggregation.cuh index a55b98336f9..b9866e72e75 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.cuh +++ b/cpp/include/cudf/detail/aggregation/aggregation.cuh @@ -130,6 +130,34 @@ struct update_target_element< } }; +template +struct update_target_element()>> { + __device__ void operator()(mutable_column_device_view target, + size_type target_index, + column_device_view source, + size_type source_index) const noexcept + { +#if (__CUDACC_VER_MAJOR__ != 10) or (__CUDACC_VER_MINOR__ != 2) + + if (source_has_nulls and source.is_null(source_index)) { return; } + + using Target = target_type_t; + using DeviceTarget = device_storage_type_t; + using DeviceSource = device_storage_type_t; + + atomicMin(&target.element(target_index), + static_cast(source.element(source_index))); + + if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); } + +#endif + } +}; + template struct update_target_element< Source, @@ -154,35 +182,78 @@ struct update_target_element< template struct update_target_element()>> { + std::enable_if_t()>> { __device__ void operator()(mutable_column_device_view target, size_type target_index, column_device_view source, size_type source_index) const noexcept { - if (source_has_nulls and source.is_null(source_index)) { return; } - - using Target = target_type_t; +#if (__CUDACC_VER_MAJOR__ != 10) or (__CUDACC_VER_MINOR__ != 2) - // #if (__CUDACC_VER_MAJOR__ != 10) or (__CUDACC_VER_MINOR__ != 2) + if (source_has_nulls and source.is_null(source_index)) { return; } + using Target = target_type_t; using DeviceTarget = device_storage_type_t; using DeviceSource = device_storage_type_t; - // #else + atomicMax(&target.element(target_index), + static_cast(source.element(source_index))); - // using DeviceTarget = Target; - // using DeviceSource = Source; + if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); } - // #endif +#endif + } +}; + +template +struct update_target_element< + Source, + aggregation::SUM, + target_has_nulls, + source_has_nulls, + std::enable_if_t() && !is_fixed_point()>> { + __device__ void operator()(mutable_column_device_view target, + size_type target_index, + column_device_view source, + size_type source_index) const noexcept + { + if (source_has_nulls and source.is_null(source_index)) { return; } + + using Target = target_type_t; + atomicAdd(&target.element(target_index), + static_cast(source.element(source_index))); + + if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); } + } +}; + +template +struct update_target_element()>> { + __device__ void operator()(mutable_column_device_view target, + size_type target_index, + column_device_view source, + size_type source_index) const noexcept + { +#if (__CUDACC_VER_MAJOR__ != 10) or (__CUDACC_VER_MINOR__ != 2) + + if (source_has_nulls and source.is_null(source_index)) { return; } + + using Target = target_type_t; + using DeviceTarget = device_storage_type_t; + using DeviceSource = device_storage_type_t; atomicAdd(&target.element(target_index), static_cast(source.element(source_index))); if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); } +#endif } }; diff --git a/cpp/tests/groupby/group_count_test.cpp b/cpp/tests/groupby/group_count_test.cpp index cfbcb2fb1d0..d9b468265db 100644 --- a/cpp/tests/groupby/group_count_test.cpp +++ b/cpp/tests/groupby/group_count_test.cpp @@ -240,12 +240,11 @@ TYPED_TEST(FixedPointTestBothReps, GroupBySumProductMinMaxDecimalAsValue) test_single_agg(keys, vals, expect_keys, {}, std::move(agg4), force_use_sort_impl::YES), cudf::logic_error); -#if !((__CUDACC_VER_MAJOR__ == 10) and (__CUDACC_VER_MINOR__ == 2)) // group_by hash tests - auto agg5 = cudf::make_sum_aggregation(); - test_single_agg(keys, vals, expect_keys, expect_vals_sum, std::move(agg5)); + // auto agg5 = cudf::make_sum_aggregation(); + // test_single_agg(keys, vals, expect_keys, expect_vals_sum, std::move(agg5)); // auto agg6 = cudf::make_min_aggregation(); // test_single_agg(keys, vals, expect_keys, expect_vals_min, std::move(agg6)); @@ -257,7 +256,6 @@ TYPED_TEST(FixedPointTestBothReps, GroupBySumProductMinMaxDecimalAsValue) // EXPECT_THROW(test_single_agg(keys, vals, expect_keys, {}, std::move(agg8)), // cudf::logic_error); -#endif } }