diff --git a/cpp/src/strings/search/find.cu b/cpp/src/strings/search/find.cu index 0f33fcb6fe1..94bc81ec933 100644 --- a/cpp/src/strings/search/find.cu +++ b/cpp/src/strings/search/find.cu @@ -32,6 +32,7 @@ #include #include +#include #include #include #include @@ -347,13 +348,15 @@ CUDF_KERNEL void contains_warp_parallel_fn(column_device_view const d_strings, string_view const d_target, bool* d_results) { - auto const idx = cudf::detail::grid_1d::global_thread_id(); - using warp_reduce = cub::WarpReduce; - __shared__ typename warp_reduce::TempStorage temp_storage; + auto const idx = cudf::detail::grid_1d::global_thread_id(); auto const str_idx = idx / cudf::detail::warp_size; if (str_idx >= d_strings.size()) { return; } - auto const lane_idx = idx % cudf::detail::warp_size; + + namespace cg = cooperative_groups; + auto const warp = cg::tiled_partition(cg::this_thread_block()); + auto const lane_idx = warp.thread_rank(); + if (d_strings.is_null(str_idx)) { return; } // get the string for this warp auto const d_str = d_strings.element(str_idx); @@ -373,7 +376,7 @@ CUDF_KERNEL void contains_warp_parallel_fn(column_device_view const d_strings, } } - auto const result = warp_reduce(temp_storage).Reduce(found, cub::Max()); + auto const result = warp.any(found); if (lane_idx == 0) { d_results[str_idx] = result; } }