Skip to content

Commit

Permalink
Improve strings contains/find performance for smaller strings (#17330)
Browse files Browse the repository at this point in the history
Replaces usage of `cudf::string_view::find()` with loop and call to `cudf::string_view::compare()` where possible.
This showed significant performance improvement. 
This was also slightly faster than a KMP prototype implementation.
Also updates the find/contains benchmarks to remove the 2GB limit and include column versions of the find APIs.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Basit Ayantunde (https://github.com/lamarrr)
  - Bradley Dice (https://github.com/bdice)

URL: #17330
  • Loading branch information
davidwendt authored Dec 6, 2024
1 parent 169a45a commit 38261f8
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 39 deletions.
59 changes: 39 additions & 20 deletions cpp/benchmarks/string/find.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,19 @@

static void bench_find_string(nvbench::state& state)
{
auto const n_rows = static_cast<cudf::size_type>(state.get_int64("num_rows"));
auto const row_width = static_cast<cudf::size_type>(state.get_int64("row_width"));
auto const num_rows = static_cast<cudf::size_type>(state.get_int64("num_rows"));
auto const max_width = static_cast<cudf::size_type>(state.get_int64("max_width"));
auto const hit_rate = static_cast<cudf::size_type>(state.get_int64("hit_rate"));
auto const api = state.get_string("api");

if (static_cast<std::size_t>(n_rows) * static_cast<std::size_t>(row_width) >=
static_cast<std::size_t>(std::numeric_limits<cudf::size_type>::max())) {
state.skip("Skip benchmarks greater than size_type limit");
}
auto const tgt_type = state.get_string("target");

auto const stream = cudf::get_default_stream();
auto const col = create_string_column(n_rows, row_width, hit_rate);
auto const col = create_string_column(num_rows, max_width, hit_rate);
auto const input = cudf::strings_column_view(col->view());

cudf::string_scalar target("0987 5W43");
auto target = cudf::string_scalar("0987 5W43");
auto targets_col = cudf::make_column_from_scalar(target, num_rows);
auto const targets = cudf::strings_column_view(targets_col->view());

state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value()));
auto const chars_size = input.chars_size(stream);
Expand All @@ -55,23 +53,44 @@ static void bench_find_string(nvbench::state& state)
}

if (api == "find") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::find(input, target); });
if (tgt_type == "scalar") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::find(input, target); });
} else if (tgt_type == "column") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::find(input, targets); });
}
} else if (api == "contains") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::contains(input, target); });
if (tgt_type == "scalar") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::contains(input, target); });
} else if (tgt_type == "column") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::contains(input, targets); });
}
} else if (api == "starts_with") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::starts_with(input, target); });
if (tgt_type == "scalar") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::starts_with(input, target); });
} else if (tgt_type == "column") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::starts_with(input, targets); });
}
} else if (api == "ends_with") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::ends_with(input, target); });
if (tgt_type == "scalar") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::ends_with(input, target); });
} else if (tgt_type == "column") {
state.exec(nvbench::exec_tag::sync,
[&](nvbench::launch& launch) { cudf::strings::ends_with(input, targets); });
}
}
}

NVBENCH_BENCH(bench_find_string)
.set_name("find_string")
.add_int64_axis("max_width", {32, 64, 128, 256})
.add_int64_axis("num_rows", {32768, 262144, 2097152})
.add_int64_axis("hit_rate", {20, 80}) // percentage
.add_string_axis("api", {"find", "contains", "starts_with", "ends_with"})
.add_int64_axis("row_width", {32, 64, 128, 256, 512, 1024})
.add_int64_axis("num_rows", {260'000, 1'953'000, 16'777'216})
.add_int64_axis("hit_rate", {20, 80}); // percentage
.add_string_axis("target", {"scalar", "column"});
17 changes: 8 additions & 9 deletions cpp/include/cudf/strings/string_view.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -373,24 +373,23 @@ __device__ inline size_type string_view::find_impl(char const* str,
size_type pos,
size_type count) const
{
auto const nchars = length();
if (!str || pos < 0 || pos > nchars) return npos;
if (count < 0) count = nchars;
if (!str || pos < 0) { return npos; }
if (pos > 0 && pos > length()) { return npos; }

// use iterator to help reduce character/byte counting
auto itr = begin() + pos;
auto const itr = begin() + pos;
auto const spos = itr.byte_offset();
auto const epos = ((pos + count) < nchars) ? (itr + count).byte_offset() : size_bytes();
auto const epos =
(count >= 0) && ((pos + count) < length()) ? (itr + count).byte_offset() : size_bytes();

auto const find_length = (epos - spos) - bytes + 1;
auto const d_target = string_view{str, bytes};

auto ptr = data() + (forward ? spos : (epos - bytes));
for (size_type idx = 0; idx < find_length; ++idx) {
bool match = true;
for (size_type jdx = 0; match && (jdx < bytes); ++jdx) {
match = (ptr[jdx] == str[jdx]);
if (d_target.compare(ptr, bytes) == 0) {
return forward ? pos : character_offset(epos - bytes - idx);
}
if (match) { return forward ? pos : character_offset(epos - bytes - idx); }
// use pos to record the current find position
pos += strings::detail::is_begin_utf8_char(*ptr);
forward ? ++ptr : --ptr;
Expand Down
24 changes: 14 additions & 10 deletions cpp/src/strings/search/find.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,11 @@ struct finder_fn {
if (d_strings.is_null(idx)) { return -1; }
auto const d_str = d_strings.element<string_view>(idx);
if (d_str.empty() && (start > 0)) { return -1; }
if (stop >= 0 && start > stop) { return -1; }
auto const d_target = d_targets[idx];

auto const length = d_str.length();
auto const begin = (start > length) ? length : start;
auto const end = (stop < 0) || (stop > length) ? length : stop;
return forward ? d_str.find(d_target, begin, end - begin)
: d_str.rfind(d_target, begin, end - begin);
auto const count = (stop < 0) ? stop : (stop - start);
return forward ? d_str.find(d_target, start, count) : d_str.rfind(d_target, start, count);
}
};

Expand Down Expand Up @@ -367,7 +365,7 @@ CUDF_KERNEL void contains_warp_parallel_fn(column_device_view const d_strings,
i += cudf::detail::warp_size * bytes_per_warp) {
// check the target matches this part of the d_str data
// this is definitely faster for very long strings > 128B
for (auto j = 0; j < bytes_per_warp; j++) {
for (auto j = 0; !found && (j < bytes_per_warp); j++) {
if (((i + j + d_target.size_bytes()) <= d_str.size_bytes()) &&
d_target.compare(d_str.data() + i + j, d_target.size_bytes()) == 0) {
found = true;
Expand Down Expand Up @@ -531,7 +529,6 @@ std::unique_ptr<column> contains_fn(strings_column_view const& strings,
results->set_null_count(strings.null_count());
return results;
}

} // namespace

std::unique_ptr<column> contains(strings_column_view const& input,
Expand All @@ -541,13 +538,17 @@ std::unique_ptr<column> contains(strings_column_view const& input,
{
// use warp parallel when the average string width is greater than the threshold
if ((input.null_count() < input.size()) &&
((input.chars_size(stream) / input.size()) > AVG_CHAR_BYTES_THRESHOLD)) {
((input.chars_size(stream) / (input.size() - input.null_count())) >
AVG_CHAR_BYTES_THRESHOLD)) {
return contains_warp_parallel(input, target, stream, mr);
}

// benchmark measurements showed this to be faster for smaller strings
auto pfn = [] __device__(string_view d_string, string_view d_target) {
return d_string.find(d_target) != string_view::npos;
for (size_type i = 0; i <= (d_string.size_bytes() - d_target.size_bytes()); ++i) {
if (d_target.compare(d_string.data() + i, d_target.size_bytes()) == 0) { return true; }
}
return false;
};
return contains_fn(input, target, pfn, stream, mr);
}
Expand All @@ -558,7 +559,10 @@ std::unique_ptr<column> contains(strings_column_view const& strings,
rmm::device_async_resource_ref mr)
{
auto pfn = [] __device__(string_view d_string, string_view d_target) {
return d_string.find(d_target) != string_view::npos;
for (size_type i = 0; i <= (d_string.size_bytes() - d_target.size_bytes()); ++i) {
if (d_target.compare(d_string.data() + i, d_target.size_bytes()) == 0) { return true; }
}
return false;
};
return contains_fn(strings, targets, pfn, stream, mr);
}
Expand Down

0 comments on commit 38261f8

Please sign in to comment.