From 86db9804746fb20554c1900b311a228dc1d6e349 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Mon, 14 Oct 2024 16:30:22 -0700 Subject: [PATCH] Clean up hash-groupby `var_hash_functor` (#17034) This work is part of splitting the original bulk shared memory groupby PR #16619. This PR renames the file originally titled `multi_pass_kernels.cuh`, which contains the `var_hash_functor`, to `var_hash_functor.cuh`. It also includes cleanups such as utilizing `cuda::std::` utilities in device code and removing redundant template parameters. Authors: - Yunsong Wang (https://github.com/PointKernel) Approvers: - Vukasin Milovanovic (https://github.com/vuule) - David Wendt (https://github.com/davidwendt) URL: https://github.com/rapidsai/cudf/pull/17034 --- cpp/src/groupby/hash/groupby.cu | 5 +- cpp/src/groupby/hash/groupby_kernels.cuh | 2 - ..._pass_kernels.cuh => var_hash_functor.cuh} | 51 ++++++++----------- 3 files changed, 25 insertions(+), 33 deletions(-) rename cpp/src/groupby/hash/{multi_pass_kernels.cuh => var_hash_functor.cuh} (69%) diff --git a/cpp/src/groupby/hash/groupby.cu b/cpp/src/groupby/hash/groupby.cu index 75767786272..0432b9d120a 100644 --- a/cpp/src/groupby/hash/groupby.cu +++ b/cpp/src/groupby/hash/groupby.cu @@ -16,7 +16,8 @@ #include "flatten_single_pass_aggs.hpp" #include "groupby/common/utils.hpp" -#include "groupby/hash/groupby_kernels.cuh" +#include "groupby_kernels.cuh" +#include "var_hash_functor.cuh" #include #include @@ -261,7 +262,7 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final rmm::exec_policy(stream), thrust::make_counting_iterator(0), col.size(), - ::cudf::detail::var_hash_functor{ + var_hash_functor{ set, row_bitmask, *var_result_view, *values_view, *sum_view, *count_view, agg._ddof}); sparse_results->add_result(col, agg, std::move(var_result)); dense_results->add_result(col, agg, to_dense_agg_result(agg)); diff --git a/cpp/src/groupby/hash/groupby_kernels.cuh b/cpp/src/groupby/hash/groupby_kernels.cuh index 188d0cff3f1..86f4d76487f 100644 --- a/cpp/src/groupby/hash/groupby_kernels.cuh +++ b/cpp/src/groupby/hash/groupby_kernels.cuh @@ -16,8 +16,6 @@ #pragma once -#include "multi_pass_kernels.cuh" - #include #include #include diff --git a/cpp/src/groupby/hash/multi_pass_kernels.cuh b/cpp/src/groupby/hash/var_hash_functor.cuh similarity index 69% rename from cpp/src/groupby/hash/multi_pass_kernels.cuh rename to cpp/src/groupby/hash/var_hash_functor.cuh index 7043eafdc10..bb55cc9188c 100644 --- a/cpp/src/groupby/hash/multi_pass_kernels.cuh +++ b/cpp/src/groupby/hash/var_hash_functor.cuh @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once #include @@ -21,17 +20,14 @@ #include #include #include -#include #include +#include #include +#include -#include - -namespace cudf { -namespace detail { - -template +namespace cudf::groupby::detail::hash { +template struct var_hash_functor { SetType set; bitmask_type const* __restrict__ row_bitmask; @@ -47,13 +43,13 @@ struct var_hash_functor { column_device_view sum, column_device_view count, size_type ddof) - : set(set), - row_bitmask(row_bitmask), - target(target), - source(source), - sum(sum), - count(count), - ddof(ddof) + : set{set}, + row_bitmask{row_bitmask}, + target{target}, + source{source}, + sum{sum}, + count{count}, + ddof{ddof} { } @@ -64,23 +60,21 @@ struct var_hash_functor { } template - __device__ std::enable_if_t()> operator()(column_device_view const& source, - size_type source_index, - size_type target_index) noexcept + __device__ cuda::std::enable_if_t()> operator()( + column_device_view const& source, size_type source_index, size_type target_index) noexcept { CUDF_UNREACHABLE("Invalid source type for std, var aggregation combination."); } template - __device__ std::enable_if_t()> operator()(column_device_view const& source, - size_type source_index, - size_type target_index) noexcept + __device__ cuda::std::enable_if_t()> operator()( + column_device_view const& source, size_type source_index, size_type target_index) noexcept { - using Target = target_type_t; - using SumType = target_type_t; - using CountType = target_type_t; + using Target = cudf::detail::target_type_t; + using SumType = cudf::detail::target_type_t; + using CountType = cudf::detail::target_type_t; - if (source_has_nulls and source.is_null(source_index)) return; + if (source.is_null(source_index)) return; CountType group_size = count.element(target_index); if (group_size == 0 or group_size - ddof <= 0) return; @@ -91,8 +85,9 @@ struct var_hash_functor { ref.fetch_add(result, cuda::std::memory_order_relaxed); // STD sqrt is applied in finalize() - if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); } + if (target.is_null(target_index)) { target.set_valid(target_index); } } + __device__ inline void operator()(size_type source_index) { if (row_bitmask == nullptr or cudf::bit_is_set(row_bitmask, source_index)) { @@ -110,6 +105,4 @@ struct var_hash_functor { } } }; - -} // namespace detail -} // namespace cudf +} // namespace cudf::groupby::detail::hash