diff --git a/cpp/src/reductions/segmented/simple.cuh b/cpp/src/reductions/segmented/simple.cuh index 05a871ed4fb..370724035df 100644 --- a/cpp/src/reductions/segmented/simple.cuh +++ b/cpp/src/reductions/segmented/simple.cuh @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include @@ -114,6 +113,24 @@ std::unique_ptr simple_segmented_reduction( return result; } +template +struct reduce_argminmax_fn { + column_device_view const d_col; // column data + bool const arg_min; // true if argmin, otherwise argmax + null_policy null_handler; // include or exclude nulls + + __device__ inline auto operator()(size_type const& lhs_idx, size_type const& rhs_idx) const + { + // CUB segmented reduce calls with OOB indices + if (lhs_idx < 0 || lhs_idx >= d_col.size()) { return rhs_idx; } + if (rhs_idx < 0 || rhs_idx >= d_col.size()) { return lhs_idx; } + if (d_col.is_null(lhs_idx)) { return null_handler == null_policy::INCLUDE ? lhs_idx : rhs_idx; } + if (d_col.is_null(rhs_idx)) { return null_handler == null_policy::INCLUDE ? rhs_idx : lhs_idx; } + auto const less = d_col.element(lhs_idx) < d_col.element(rhs_idx); + return less == arg_min ? lhs_idx : rhs_idx; + } +}; + /** * @brief String segmented reduction for 'min', 'max'. * @@ -130,7 +147,6 @@ std::unique_ptr simple_segmented_reduction( * @param mr Device memory resource used to allocate the returned column's device memory * @return Output column in device memory */ - template || @@ -148,8 +164,7 @@ std::unique_ptr string_segmented_reduction(column_view const& col, auto const num_segments = static_cast(offsets.size()) - 1; bool constexpr is_argmin = std::is_same_v; - auto string_comparator = - cudf::detail::element_argminmax_fn{*device_col, col.has_nulls(), is_argmin}; + auto string_comparator = reduce_argminmax_fn{*device_col, is_argmin, null_handling}; auto constexpr identity = is_argmin ? cudf::detail::ARGMIN_SENTINEL : cudf::detail::ARGMAX_SENTINEL;