From 8b23436e0fc48706d21629003dcf144fb02045fb Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Fri, 8 Dec 2023 15:09:12 -0500 Subject: [PATCH] Fix unsanitized nulls from strings segmented-reduce (#14586) Fixes the string specialization logic in `cudf::segmented_reduce` to not produce unsanitized null entries. The functor used to build a gather map for argmin/argmax was corrected to handle include/exclude nulls correctly. Reference: #14559 Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Nghia Truong (https://github.com/ttnghia) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/14586 --- cpp/src/reductions/segmented/simple.cuh | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/cpp/src/reductions/segmented/simple.cuh b/cpp/src/reductions/segmented/simple.cuh index 05a871ed4fb..370724035df 100644 --- a/cpp/src/reductions/segmented/simple.cuh +++ b/cpp/src/reductions/segmented/simple.cuh @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include @@ -114,6 +113,24 @@ std::unique_ptr simple_segmented_reduction( return result; } +template +struct reduce_argminmax_fn { + column_device_view const d_col; // column data + bool const arg_min; // true if argmin, otherwise argmax + null_policy null_handler; // include or exclude nulls + + __device__ inline auto operator()(size_type const& lhs_idx, size_type const& rhs_idx) const + { + // CUB segmented reduce calls with OOB indices + if (lhs_idx < 0 || lhs_idx >= d_col.size()) { return rhs_idx; } + if (rhs_idx < 0 || rhs_idx >= d_col.size()) { return lhs_idx; } + if (d_col.is_null(lhs_idx)) { return null_handler == null_policy::INCLUDE ? lhs_idx : rhs_idx; } + if (d_col.is_null(rhs_idx)) { return null_handler == null_policy::INCLUDE ? rhs_idx : lhs_idx; } + auto const less = d_col.element(lhs_idx) < d_col.element(rhs_idx); + return less == arg_min ? lhs_idx : rhs_idx; + } +}; + /** * @brief String segmented reduction for 'min', 'max'. * @@ -130,7 +147,6 @@ std::unique_ptr simple_segmented_reduction( * @param mr Device memory resource used to allocate the returned column's device memory * @return Output column in device memory */ - template || @@ -148,8 +164,7 @@ std::unique_ptr string_segmented_reduction(column_view const& col, auto const num_segments = static_cast(offsets.size()) - 1; bool constexpr is_argmin = std::is_same_v; - auto string_comparator = - cudf::detail::element_argminmax_fn{*device_col, col.has_nulls(), is_argmin}; + auto string_comparator = reduce_argminmax_fn{*device_col, is_argmin, null_handling}; auto constexpr identity = is_argmin ? cudf::detail::ARGMIN_SENTINEL : cudf::detail::ARGMAX_SENTINEL;