Skip to content

Commit

Permalink
Clean up hash-groupby var_hash_functor (#17034)
Browse files Browse the repository at this point in the history
This work is part of splitting the original, large shared-memory groupby PR #16619 into smaller pieces.

This PR renames the file originally titled `multi_pass_kernels.cuh`, which contains `var_hash_functor`, to `var_hash_functor.cuh`. It also includes cleanups such as using `cuda::std::` utilities in device code, removing the redundant `target_has_nulls`/`source_has_nulls` template parameters, and moving the functor into the `cudf::groupby::detail::hash` namespace.

Authors:
  - Yunsong Wang (https://github.com/PointKernel)

Approvers:
  - Vukasin Milovanovic (https://github.com/vuule)
  - David Wendt (https://github.com/davidwendt)

URL: #17034
  • Loading branch information
PointKernel authored Oct 14, 2024
1 parent 44afc51 commit 86db980
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 33 deletions.
5 changes: 3 additions & 2 deletions cpp/src/groupby/hash/groupby.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

#include "flatten_single_pass_aggs.hpp"
#include "groupby/common/utils.hpp"
#include "groupby/hash/groupby_kernels.cuh"
#include "groupby_kernels.cuh"
#include "var_hash_functor.cuh"

#include <cudf/aggregation.hpp>
#include <cudf/column/column.hpp>
Expand Down Expand Up @@ -261,7 +262,7 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final
rmm::exec_policy(stream),
thrust::make_counting_iterator(0),
col.size(),
::cudf::detail::var_hash_functor{
var_hash_functor{
set, row_bitmask, *var_result_view, *values_view, *sum_view, *count_view, agg._ddof});
sparse_results->add_result(col, agg, std::move(var_result));
dense_results->add_result(col, agg, to_dense_agg_result(agg));
Expand Down
2 changes: 0 additions & 2 deletions cpp/src/groupby/hash/groupby_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

#pragma once

#include "multi_pass_kernels.cuh"

#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/aggregation/device_aggregators.cuh>
#include <cudf/groupby.hpp>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,21 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/aggregation.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/utilities/assert.cuh>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/table/table_device_view.cuh>
#include <cudf/utilities/type_dispatcher.hpp>

#include <cuco/static_set_ref.cuh>
#include <cuda/atomic>
#include <cuda/std/type_traits>

#include <cmath>

namespace cudf {
namespace detail {

template <typename SetType, bool target_has_nulls = true, bool source_has_nulls = true>
namespace cudf::groupby::detail::hash {
template <typename SetType>
struct var_hash_functor {
SetType set;
bitmask_type const* __restrict__ row_bitmask;
Expand All @@ -47,13 +43,13 @@ struct var_hash_functor {
column_device_view sum,
column_device_view count,
size_type ddof)
: set(set),
row_bitmask(row_bitmask),
target(target),
source(source),
sum(sum),
count(count),
ddof(ddof)
: set{set},
row_bitmask{row_bitmask},
target{target},
source{source},
sum{sum},
count{count},
ddof{ddof}
{
}

Expand All @@ -64,23 +60,21 @@ struct var_hash_functor {
}

template <typename Source>
__device__ std::enable_if_t<!is_supported<Source>()> operator()(column_device_view const& source,
size_type source_index,
size_type target_index) noexcept
__device__ cuda::std::enable_if_t<!is_supported<Source>()> operator()(
column_device_view const& source, size_type source_index, size_type target_index) noexcept
{
CUDF_UNREACHABLE("Invalid source type for std, var aggregation combination.");
}

template <typename Source>
__device__ std::enable_if_t<is_supported<Source>()> operator()(column_device_view const& source,
size_type source_index,
size_type target_index) noexcept
__device__ cuda::std::enable_if_t<is_supported<Source>()> operator()(
column_device_view const& source, size_type source_index, size_type target_index) noexcept
{
using Target = target_type_t<Source, aggregation::VARIANCE>;
using SumType = target_type_t<Source, aggregation::SUM>;
using CountType = target_type_t<Source, aggregation::COUNT_VALID>;
using Target = cudf::detail::target_type_t<Source, aggregation::VARIANCE>;
using SumType = cudf::detail::target_type_t<Source, aggregation::SUM>;
using CountType = cudf::detail::target_type_t<Source, aggregation::COUNT_VALID>;

if (source_has_nulls and source.is_null(source_index)) return;
if (source.is_null(source_index)) return;
CountType group_size = count.element<CountType>(target_index);
if (group_size == 0 or group_size - ddof <= 0) return;

Expand All @@ -91,8 +85,9 @@ struct var_hash_functor {
ref.fetch_add(result, cuda::std::memory_order_relaxed);
// STD sqrt is applied in finalize()

if (target_has_nulls and target.is_null(target_index)) { target.set_valid(target_index); }
if (target.is_null(target_index)) { target.set_valid(target_index); }
}

__device__ inline void operator()(size_type source_index)
{
if (row_bitmask == nullptr or cudf::bit_is_set(row_bitmask, source_index)) {
Expand All @@ -110,6 +105,4 @@ struct var_hash_functor {
}
}
};

} // namespace detail
} // namespace cudf
} // namespace cudf::groupby::detail::hash

0 comments on commit 86db980

Please sign in to comment.