From 74bbd2641aeecbb47981b5af50d59ac24928e81c Mon Sep 17 00:00:00 2001 From: David Wendt Date: Wed, 6 Dec 2023 10:37:35 -0500 Subject: [PATCH 1/2] Fix unsanitized nulls from strings segmented-reduce --- cpp/src/reductions/segmented/simple.cuh | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/cpp/src/reductions/segmented/simple.cuh b/cpp/src/reductions/segmented/simple.cuh index 05a871ed4fb..a24b23f5c5f 100644 --- a/cpp/src/reductions/segmented/simple.cuh +++ b/cpp/src/reductions/segmented/simple.cuh @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include @@ -114,6 +113,23 @@ std::unique_ptr simple_segmented_reduction( return result; } +template +struct reduce_argminmax_fn { + column_device_view const d_col; // column data + bool const arg_min; // true if argmin, otherwise argmax + null_policy null_handler; // include or exclude nulls + + __device__ inline auto operator()(size_type const& lhs_idx, size_type const& rhs_idx) const + { + if (lhs_idx < 0 || lhs_idx >= d_col.size()) { return rhs_idx; } // CUB segmented reduce + if (rhs_idx < 0 || rhs_idx >= d_col.size()) { return lhs_idx; } // calls with OOB indices + if (d_col.is_null(lhs_idx)) { return null_handler == null_policy::INCLUDE ? lhs_idx : rhs_idx; } + if (d_col.is_null(rhs_idx)) { return null_handler == null_policy::INCLUDE ? rhs_idx : lhs_idx; } + auto const less = d_col.element(lhs_idx) < d_col.element(rhs_idx); + return less == arg_min ? lhs_idx : rhs_idx; + } +}; + /** * @brief String segmented reduction for 'min', 'max'. * @@ -130,7 +146,6 @@ std::unique_ptr simple_segmented_reduction( * @param mr Device memory resource used to allocate the returned column's device memory * @return Output column in device memory */ - template || @@ -148,8 +163,7 @@ std::unique_ptr string_segmented_reduction(column_view const& col, auto const num_segments = static_cast(offsets.size()) - 1; bool constexpr is_argmin = std::is_same_v; - auto string_comparator = - cudf::detail::element_argminmax_fn{*device_col, col.has_nulls(), is_argmin}; + auto string_comparator = reduce_argminmax_fn{*device_col, is_argmin, null_handling}; auto constexpr identity = is_argmin ? cudf::detail::ARGMIN_SENTINEL : cudf::detail::ARGMAX_SENTINEL; From 21e80dc1c81f956ef9da046402dde323b5c1ce57 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Fri, 8 Dec 2023 12:18:58 -0500 Subject: [PATCH 2/2] move comments --- cpp/src/reductions/segmented/simple.cuh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/src/reductions/segmented/simple.cuh b/cpp/src/reductions/segmented/simple.cuh index a24b23f5c5f..370724035df 100644 --- a/cpp/src/reductions/segmented/simple.cuh +++ b/cpp/src/reductions/segmented/simple.cuh @@ -121,8 +121,9 @@ struct reduce_argminmax_fn { __device__ inline auto operator()(size_type const& lhs_idx, size_type const& rhs_idx) const { - if (lhs_idx < 0 || lhs_idx >= d_col.size()) { return rhs_idx; } // CUB segmented reduce - if (rhs_idx < 0 || rhs_idx >= d_col.size()) { return lhs_idx; } // calls with OOB indices + // CUB segmented reduce calls with OOB indices + if (lhs_idx < 0 || lhs_idx >= d_col.size()) { return rhs_idx; } + if (rhs_idx < 0 || rhs_idx >= d_col.size()) { return lhs_idx; } if (d_col.is_null(lhs_idx)) { return null_handler == null_policy::INCLUDE ? lhs_idx : rhs_idx; } if (d_col.is_null(rhs_idx)) { return null_handler == null_policy::INCLUDE ? rhs_idx : lhs_idx; } auto const less = d_col.element(lhs_idx) < d_col.element(rhs_idx);