diff --git a/cpp/src/strings/search/find.cu b/cpp/src/strings/search/find.cu
index 5e6a273958c..f71c4b6b49e 100644
--- a/cpp/src/strings/search/find.cu
+++ b/cpp/src/strings/search/find.cu
@@ -274,30 +274,52 @@ namespace {
 /**
  * @brief Check if `d_target` appears in a row in `d_strings`.
  *
- * This executes as a warp per string/row.
+ * This executes as a warp per string/row and performs well for longer strings.
+ * @see AVG_CHAR_BYTES_THRESHOLD
+ *
+ * Expects a 1-D launch with at least `d_strings.size() * warp_size` threads and a
+ * block size that is a multiple of the warp size, so that both early exits are
+ * taken by whole warps and all lanes of a warp reach the reduction together.
+ * Rows that are null are left unwritten in `d_results`.
+ *
+ * @param d_strings Column of input strings
+ * @param d_target String to search for in each row of `d_strings`
+ * @param d_results Indicates which rows contain `d_target`
  */
-struct contains_warp_fn {
-  column_device_view const d_strings;
-  string_view const d_target;
-  bool* d_results;
+__global__ void contains_warp_parallel_fn(column_device_view const d_strings,
+                                          string_view const d_target,
+                                          bool* d_results)
+{
+  // NOTE(review): the removed functor indexed with std::size_t; confirm that
+  // `d_strings.size() * warp_size` cannot overflow size_type for large columns
+  size_type const idx = static_cast<size_type>(threadIdx.x + blockIdx.x * blockDim.x);
+  using warp_reduce   = cub::WarpReduce<bool>;
+  // NOTE(review): one TempStorage is shared by every warp in the block; CUB's
+  // documented usage allocates one per warp -- confirm this is safe here
+  __shared__ typename warp_reduce::TempStorage temp_storage;
 
-  __device__ void operator()(std::size_t idx)
-  {
-    auto const str_idx = static_cast<size_type>(idx / cudf::detail::warp_size);
-    if (d_strings.is_null(str_idx)) { return; }
-    // get the string for this warp
-    auto const d_str = d_strings.element<string_view>(str_idx);
-    // each thread of the warp will check just part of the string
-    auto found = false;
-    for (auto i = static_cast<size_type>(idx % cudf::detail::warp_size);
-         !found && (i + d_target.size_bytes()) < d_str.size_bytes();
-         i += cudf::detail::warp_size) {
-      // check the target matches this part of the d_str data
-      if (d_target.compare(d_str.data() + i, d_target.size_bytes()) == 0) { found = true; }
-    }
-    if (found) { atomicOr(d_results + str_idx, true); }
+  // threads past the last row have no string to search
+  if (idx >= (d_strings.size() * cudf::detail::warp_size)) { return; }
+
+  auto const str_idx  = idx / cudf::detail::warp_size;
+  auto const lane_idx = idx % cudf::detail::warp_size;
+  if (d_strings.is_null(str_idx)) { return; }
+  // get the string for this warp
+  auto const d_str = d_strings.element<string_view>(str_idx);
+  // each thread of the warp will check just part of the string
+  auto found = false;
+  // `<=` (not `<`): the last candidate offset is size_bytes() - target bytes,
+  // so a match that ends exactly at the end of the string is also tested
+  for (auto i = static_cast<size_type>(idx % cudf::detail::warp_size);
+       !found && (i + d_target.size_bytes()) <= d_str.size_bytes();
+       i += cudf::detail::warp_size) {
+    // check the target matches this part of the d_str data
+    if (d_target.compare(d_str.data() + i, d_target.size_bytes()) == 0) { found = true; }
   }
-};
+  // max over the per-lane flags (true if any lane matched); lane 0 writes the row
+  auto const result = warp_reduce(temp_storage).Reduce(found, cub::Max());
+  if (lane_idx == 0) { d_results[str_idx] = result; }
+}
 
 std::unique_ptr<column> contains_warp_parallel(strings_column_view const& input,
                                                string_scalar const& target,
@@ -324,11 +346,11 @@ std::unique_ptr<column> contains_warp_parallel(strings_column_view const& input,
 
   if (!d_target.empty()) {
     // launch warp per string
-    auto d_strings = column_device_view::create(input.parent(), stream);
-    thrust::for_each_n(rmm::exec_policy(stream),
-                       thrust::make_counting_iterator<std::size_t>(0),
-                       static_cast<std::size_t>(input.size()) * cudf::detail::warp_size,
-                       contains_warp_fn{*d_strings, d_target, results_view.data<bool>()});
+    auto const d_strings     = column_device_view::create(input.parent(), stream);
+    constexpr int block_size = 256;
+    cudf::detail::grid_1d grid{input.size() * cudf::detail::warp_size, block_size};
+    contains_warp_parallel_fn<<<grid.num_blocks, grid.num_threads_per_block, 0, stream.value()>>>(
+      *d_strings, d_target, results_view.data<bool>());
   }
   results->set_null_count(input.null_count());
   return results;