From 38261f8509245f88bdeab193a1357d9c73d765f0 Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Fri, 6 Dec 2024 08:32:41 -0500 Subject: [PATCH] Improve strings contains/find performance for smaller strings (#17330) Replaces usage of `cudf::string_view::find()` with loop and call to `cudf::string_view::compare()` where possible. This showed significant performance improvement. This was also slightly faster than a KMP prototype implementation. Also updates the find/contains benchmarks to remove the 2GB limit and include column versions of the find APIs. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Basit Ayantunde (https://github.com/lamarrr) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/17330 --- cpp/benchmarks/string/find.cpp | 59 ++++++++++++++++-------- cpp/include/cudf/strings/string_view.cuh | 17 ++++--- cpp/src/strings/search/find.cu | 24 ++++++---- 3 files changed, 61 insertions(+), 39 deletions(-) diff --git a/cpp/benchmarks/string/find.cpp b/cpp/benchmarks/string/find.cpp index 3ea3ff13a2f..2ba793e998e 100644 --- a/cpp/benchmarks/string/find.cpp +++ b/cpp/benchmarks/string/find.cpp @@ -28,21 +28,19 @@ static void bench_find_string(nvbench::state& state) { - auto const n_rows = static_cast(state.get_int64("num_rows")); - auto const row_width = static_cast(state.get_int64("row_width")); + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const max_width = static_cast(state.get_int64("max_width")); auto const hit_rate = static_cast(state.get_int64("hit_rate")); auto const api = state.get_string("api"); - - if (static_cast(n_rows) * static_cast(row_width) >= - static_cast(std::numeric_limits::max())) { - state.skip("Skip benchmarks greater than size_type limit"); - } + auto const tgt_type = state.get_string("target"); auto const stream = cudf::get_default_stream(); - auto const col = create_string_column(n_rows, row_width, hit_rate); + auto const col = create_string_column(num_rows, max_width, hit_rate); auto const input = cudf::strings_column_view(col->view()); - cudf::string_scalar target("0987 5W43"); + auto target = cudf::string_scalar("0987 5W43"); + auto targets_col = cudf::make_column_from_scalar(target, num_rows); + auto const targets = cudf::strings_column_view(targets_col->view()); state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); auto const chars_size = input.chars_size(stream); @@ -55,23 +53,44 @@ static void bench_find_string(nvbench::state& state) } if (api == "find") { - state.exec(nvbench::exec_tag::sync, - [&](nvbench::launch& launch) { cudf::strings::find(input, target); }); + if (tgt_type == "scalar") { + state.exec(nvbench::exec_tag::sync, + [&](nvbench::launch& launch) { cudf::strings::find(input, target); }); + } else if (tgt_type == "column") { + state.exec(nvbench::exec_tag::sync, + [&](nvbench::launch& launch) { cudf::strings::find(input, targets); }); + } } else if (api == "contains") { - state.exec(nvbench::exec_tag::sync, - [&](nvbench::launch& launch) { cudf::strings::contains(input, target); }); + if (tgt_type == "scalar") { + state.exec(nvbench::exec_tag::sync, + [&](nvbench::launch& launch) { cudf::strings::contains(input, target); }); + } else if (tgt_type == "column") { + state.exec(nvbench::exec_tag::sync, + [&](nvbench::launch& launch) { cudf::strings::contains(input, targets); }); + } } else if (api == "starts_with") { - state.exec(nvbench::exec_tag::sync, - [&](nvbench::launch& launch) { cudf::strings::starts_with(input, target); }); + if (tgt_type == "scalar") { + state.exec(nvbench::exec_tag::sync, + [&](nvbench::launch& launch) { cudf::strings::starts_with(input, target); }); + } else if (tgt_type == "column") { + state.exec(nvbench::exec_tag::sync, + [&](nvbench::launch& launch) { cudf::strings::starts_with(input, targets); }); + } } else if (api == "ends_with") { - state.exec(nvbench::exec_tag::sync, - [&](nvbench::launch& launch) { cudf::strings::ends_with(input, target); }); + if (tgt_type == "scalar") { + state.exec(nvbench::exec_tag::sync, + [&](nvbench::launch& launch) { cudf::strings::ends_with(input, target); }); + } else if (tgt_type == "column") { + state.exec(nvbench::exec_tag::sync, + [&](nvbench::launch& launch) { cudf::strings::ends_with(input, targets); }); + } } } NVBENCH_BENCH(bench_find_string) .set_name("find_string") + .add_int64_axis("max_width", {32, 64, 128, 256}) + .add_int64_axis("num_rows", {32768, 262144, 2097152}) + .add_int64_axis("hit_rate", {20, 80}) // percentage .add_string_axis("api", {"find", "contains", "starts_with", "ends_with"}) - .add_int64_axis("row_width", {32, 64, 128, 256, 512, 1024}) - .add_int64_axis("num_rows", {260'000, 1'953'000, 16'777'216}) - .add_int64_axis("hit_rate", {20, 80}); // percentage + .add_string_axis("target", {"scalar", "column"}); diff --git a/cpp/include/cudf/strings/string_view.cuh b/cpp/include/cudf/strings/string_view.cuh index 34ed3c5618e..1ae4c3703b2 100644 --- a/cpp/include/cudf/strings/string_view.cuh +++ b/cpp/include/cudf/strings/string_view.cuh @@ -373,24 +373,23 @@ __device__ inline size_type string_view::find_impl(char const* str, size_type pos, size_type count) const { - auto const nchars = length(); - if (!str || pos < 0 || pos > nchars) return npos; - if (count < 0) count = nchars; + if (!str || pos < 0) { return npos; } + if (pos > 0 && pos > length()) { return npos; } // use iterator to help reduce character/byte counting - auto itr = begin() + pos; + auto const itr = begin() + pos; auto const spos = itr.byte_offset(); - auto const epos = ((pos + count) < nchars) ? (itr + count).byte_offset() : size_bytes(); + auto const epos = + (count >= 0) && ((pos + count) < length()) ? (itr + count).byte_offset() : size_bytes(); auto const find_length = (epos - spos) - bytes + 1; + auto const d_target = string_view{str, bytes}; auto ptr = data() + (forward ? spos : (epos - bytes)); for (size_type idx = 0; idx < find_length; ++idx) { - bool match = true; - for (size_type jdx = 0; match && (jdx < bytes); ++jdx) { - match = (ptr[jdx] == str[jdx]); + if (d_target.compare(ptr, bytes) == 0) { + return forward ? pos : character_offset(epos - bytes - idx); } - if (match) { return forward ? pos : character_offset(epos - bytes - idx); } // use pos to record the current find position pos += strings::detail::is_begin_utf8_char(*ptr); forward ? ++ptr : --ptr; diff --git a/cpp/src/strings/search/find.cu b/cpp/src/strings/search/find.cu index 3cf4970d36e..0f33fcb6fe1 100644 --- a/cpp/src/strings/search/find.cu +++ b/cpp/src/strings/search/find.cu @@ -70,13 +70,11 @@ struct finder_fn { if (d_strings.is_null(idx)) { return -1; } auto const d_str = d_strings.element(idx); if (d_str.empty() && (start > 0)) { return -1; } + if (stop >= 0 && start > stop) { return -1; } auto const d_target = d_targets[idx]; - auto const length = d_str.length(); - auto const begin = (start > length) ? length : start; - auto const end = (stop < 0) || (stop > length) ? length : stop; - return forward ? d_str.find(d_target, begin, end - begin) - : d_str.rfind(d_target, begin, end - begin); + auto const count = (stop < 0) ? stop : (stop - start); + return forward ? d_str.find(d_target, start, count) : d_str.rfind(d_target, start, count); } }; @@ -367,7 +365,7 @@ CUDF_KERNEL void contains_warp_parallel_fn(column_device_view const d_strings, i += cudf::detail::warp_size * bytes_per_warp) { // check the target matches this part of the d_str data // this is definitely faster for very long strings > 128B - for (auto j = 0; j < bytes_per_warp; j++) { + for (auto j = 0; !found && (j < bytes_per_warp); j++) { if (((i + j + d_target.size_bytes()) <= d_str.size_bytes()) && d_target.compare(d_str.data() + i + j, d_target.size_bytes()) == 0) { found = true; @@ -531,7 +529,6 @@ std::unique_ptr contains_fn(strings_column_view const& strings, results->set_null_count(strings.null_count()); return results; } - } // namespace std::unique_ptr contains(strings_column_view const& input, @@ -541,13 +538,17 @@ std::unique_ptr contains(strings_column_view const& input, { // use warp parallel when the average string width is greater than the threshold if ((input.null_count() < input.size()) && - ((input.chars_size(stream) / input.size()) > AVG_CHAR_BYTES_THRESHOLD)) { + ((input.chars_size(stream) / (input.size() - input.null_count())) > + AVG_CHAR_BYTES_THRESHOLD)) { return contains_warp_parallel(input, target, stream, mr); } // benchmark measurements showed this to be faster for smaller strings auto pfn = [] __device__(string_view d_string, string_view d_target) { - return d_string.find(d_target) != string_view::npos; + for (size_type i = 0; i <= (d_string.size_bytes() - d_target.size_bytes()); ++i) { + if (d_target.compare(d_string.data() + i, d_target.size_bytes()) == 0) { return true; } + } + return false; }; return contains_fn(input, target, pfn, stream, mr); } @@ -558,7 +559,10 @@ std::unique_ptr contains(strings_column_view const& strings, rmm::device_async_resource_ref mr) { auto pfn = [] __device__(string_view d_string, string_view d_target) { - return d_string.find(d_target) != string_view::npos; + for (size_type i = 0; i <= (d_string.size_bytes() - d_target.size_bytes()); ++i) { + if (d_target.compare(d_string.data() + i, d_target.size_bytes()) == 0) { return true; } + } + return false; }; return contains_fn(strings, targets, pfn, stream, mr); }