Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance improvement in cudf::strings::find/rfind for long strings #13226

Merged
merged 16 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/benchmarks/string/find.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ static void generate_bench_args(benchmark::internal::Benchmark* b)
int const row_mult = 8;
int const min_rowlen = 1 << 5;
int const max_rowlen = 1 << 13;
int const len_mult = 4;
int const len_mult = 2;
for (int row_count = min_rows; row_count <= max_rows; row_count *= row_mult) {
for (int rowlen = min_rowlen; rowlen <= max_rowlen; rowlen *= len_mult) {
// avoid generating combinations that exceed the cudf column limit
Expand Down
36 changes: 27 additions & 9 deletions cpp/include/cudf/strings/string_view.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,32 @@ __device__ inline size_type characters_in_string(const char* str, size_type byte
#endif
}

/**
* @brief Count the bytes to a specified character position
*
* Returns the number of bytes and any left over position value.
* The returned position is > 0 if the given position would read past
* the end of the input string.
*
* @param d_str Input string to count bytes within
* @param pos Character position to count to
* @return The number of bytes and the left over non-counted position value
*/
__device__ inline std::pair<size_type, size_type> bytes_to_character_position(string_view d_str,
size_type pos)
{
size_type bytes = 0;
auto ptr = d_str.data();
auto const end_ptr = ptr + d_str.size_bytes();
while ((pos > 0) && (ptr < end_ptr)) {
auto const width = strings::detail::bytes_in_utf8_byte(static_cast<uint8_t>(*ptr));
if (width) { --pos; }
bytes += width;
++ptr;
}
return {bytes, pos};
}

/**
* @brief string value for sentinel which is used in min, max reduction
* operators
Expand Down Expand Up @@ -266,16 +292,8 @@ __device__ inline char_utf8 string_view::operator[](size_type pos) const

__device__ inline size_type string_view::byte_offset(size_type pos) const
{
size_type offset = 0;
const char* sptr = _data;
const char* eptr = sptr + _bytes;
if (length() == size_bytes()) return pos;
while ((pos > 0) && (sptr < eptr)) {
size_type charbytes = strings::detail::bytes_in_utf8_byte(static_cast<uint8_t>(*sptr++));
if (charbytes) --pos;
offset += charbytes;
}
return offset;
return std::get<0>(strings::detail::bytes_to_character_position(*this, pos));
}

__device__ inline int string_view::compare(const string_view& in) const
Expand Down
244 changes: 169 additions & 75 deletions cpp/src/strings/search/find.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -41,103 +41,207 @@ namespace cudf {
namespace strings {
namespace detail {
namespace {

/**
* @brief Utility to return integer column indicating the position of
* target string within each string in a strings column.
* @brief Threshold to decide on using string or warp parallel functions.
*
* Null string entries return corresponding null output column entries.
* If the average byte length of a string in a column exceeds this value then
* a warp-parallel function is used.
*
* @tparam FindFunction Returns integer character position value given a string and target.
* Note that this value is shared by find, rfind, and contains functions
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
*/
constexpr size_type AVG_CHAR_BYTES_THRESHOLD = 64;

/**
* @brief Find function handles a string per thread
*/
template <bool forward = true>
struct finder_fn {
column_device_view const d_strings;
string_view const d_target;
size_type const start;
size_type const stop;

__device__ size_type operator()(size_type idx)
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
{
if (d_strings.is_null(idx)) { return -1; }
auto d_str = d_strings.element<string_view>(idx);

auto const length = d_str.length();
auto const begin = (start > length) ? length : start;
auto const end = (stop < 0) || (stop > length) ? length : stop;
return forward ? d_str.find(d_target, begin, end - begin)
: d_str.rfind(d_target, begin, end - begin);
}
};

/**
* @brief Special logic handles an empty target for find/rfind
*
* @param strings Strings column to search for target.
* @param target String to search for in each string in the strings column.
* @param start First character position to start the search.
* @param stop Last character position (exclusive) to end the search.
* @param pfn Functor used for locating `target` in each string.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New integer column with character position values.
* where length = number of characters in the input string
* if forward = true:
* return start iff (start <= length), otherwise return -1
* if forward = false:
* return stop iff (0 <= stop <= length), otherwise return length
*/
template <typename FindFunction>
std::unique_ptr<column> find_fn(strings_column_view const& strings,
template <bool forward = true>
struct empty_target_fn {
column_device_view const d_strings;
size_type const start;
size_type const stop;

__device__ size_type operator()(size_type idx)
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
{
if (d_strings.is_null(idx)) { return -1; }
auto d_str = d_strings.element<string_view>(idx);

// common case shortcut
if (forward && start == 0) { return 0; }

auto const length = d_str.length();
if (start > length) { return -1; }
if constexpr (forward) { return start; }

return (stop < 0) || (stop > length) ? length : stop;
}
};

/**
* @brief String per warp function for find/rfind
*/
template <bool forward = true>
__global__ void finder_warp_parallel_fn(column_device_view const d_strings,
string_view const d_target,
size_type const start,
size_type const stop,
size_type* d_results)
{
size_type const idx = static_cast<size_type>(threadIdx.x + blockIdx.x * blockDim.x);

if (idx >= (d_strings.size() * cudf::detail::warp_size)) { return; }

auto const str_idx = idx / cudf::detail::warp_size;
auto const lane_idx = idx % cudf::detail::warp_size;

if (d_strings.is_null(str_idx)) { return; }

// initialize the output for the atomicMin/Max
if (lane_idx == 0) { d_results[str_idx] = forward ? std::numeric_limits<size_type>::max() : -1; }
__syncwarp();

auto const d_str = d_strings.element<string_view>(str_idx);

auto const [begin, left_over] = bytes_to_character_position(d_str, start);
auto const start_char_pos = start - left_over; // keep track of character position

auto const end = [d_str, start, stop, begin = begin] {
if (stop < 0) { return d_str.size_bytes(); }
if (stop <= start) { return begin; }
// we count from `begin` instead of recounting from the beginning of the string
return begin + std::get<0>(bytes_to_character_position(
string_view(d_str.data() + begin, d_str.size_bytes() - begin), stop - start));
}();

// each thread compares the target with the thread's individual starting byte
size_type position = forward ? std::numeric_limits<size_type>::max() : -1;
for (auto itr = begin + lane_idx; itr + d_target.size_bytes() <= end;
itr += cudf::detail::warp_size) {
if (d_target.compare(d_str.data() + itr, d_target.size_bytes()) == 0) {
position = itr;
if (forward) break;
}
}

// find stores the minimum position while rfind stores the maximum position
// note that this was slightly faster than using cub::WarpReduce
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
forward ? atomicMin(d_results + str_idx, position) : atomicMax(d_results + str_idx, position);
__syncwarp();

if (lane_idx == 0) {
// the final result needs to be fixed up convert max() to -1
// and a byte position to a character position
auto const result = d_results[str_idx];
d_results[str_idx] =
((result < std::numeric_limits<size_type>::max()) && (result >= begin))
? start_char_pos + characters_in_string(d_str.data() + begin, result - begin)
: -1;
}
}

template <bool forward = true>
std::unique_ptr<column> find_fn(strings_column_view const& input,
string_scalar const& target,
size_type start,
size_type stop,
FindFunction& pfn,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_EXPECTS(target.is_valid(stream), "Parameter target must be valid.");
CUDF_EXPECTS(start >= 0, "Parameter start must be positive integer or zero.");
if ((stop > 0) && (start > stop)) CUDF_FAIL("Parameter start must be less than stop.");
//
auto d_target = string_view(target.data(), target.size());
auto strings_column = column_device_view::create(strings.parent(), stream);
auto d_strings = *strings_column;
auto strings_count = strings.size();

auto d_target = string_view(target.data(), target.size());
auto d_strings = column_device_view::create(input.parent(), stream);

// create output column
auto results = make_numeric_column(data_type{type_id::INT32},
strings_count,
cudf::detail::copy_bitmask(strings.parent(), stream, mr),
strings.null_count(),
auto results = make_numeric_column(data_type{type_to_id<size_type>()},
input.size(),
cudf::detail::copy_bitmask(input.parent(), stream, mr),
input.null_count(),
stream,
mr);
auto results_view = results->mutable_view();
auto d_results = results_view.data<int32_t>();
// set the position values by evaluating the passed function
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
[d_strings, pfn, d_target, start, stop] __device__(size_type idx) {
int32_t position = -1;
if (!d_strings.is_null(idx))
position = static_cast<int32_t>(
pfn(d_strings.element<string_view>(idx), d_target, start, stop));
return position;
});
results->set_null_count(strings.null_count());
// if input is empty or all-null then we are done
if (input.size() == input.null_count()) { return results; }

auto d_results = results->mutable_view().data<size_type>();

if (d_target.empty()) {
// special logic for empty target results
thrust::transform(rmm::exec_policy(stream),
thrust::counting_iterator<size_type>(0),
thrust::counting_iterator<size_type>(input.size()),
d_results,
empty_target_fn<forward>{*d_strings, start, stop});
} else if ((input.chars_size() / (input.size() - input.null_count())) >
AVG_CHAR_BYTES_THRESHOLD) {
// warp-per-string runs faster for longer strings (but not shorter ones)
constexpr int block_size = 256;
cudf::detail::grid_1d grid{input.size() * cudf::detail::warp_size, block_size};
finder_warp_parallel_fn<forward>
<<<grid.num_blocks, grid.num_threads_per_block, 0, stream.value()>>>(
*d_strings, d_target, start, stop, d_results);
} else {
// string-per-thread function
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(input.size()),
d_results,
finder_fn<forward>{*d_strings, d_target, start, stop});
}

results->set_null_count(input.null_count());
return results;
}

} // namespace

std::unique_ptr<column> find(strings_column_view const& strings,
std::unique_ptr<column> find(strings_column_view const& input,
string_scalar const& target,
size_type start,
size_type stop,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto pfn = [] __device__(
string_view d_string, string_view d_target, size_type start, size_type stop) {
size_type length = d_string.length();
if (d_target.empty()) return start > length ? -1 : start;
size_type begin = (start > length) ? length : start;
size_type end = (stop < 0) || (stop > length) ? length : stop;
return d_string.find(d_target, begin, end - begin);
};

return find_fn(strings, target, start, stop, pfn, stream, mr);
return find_fn<true>(input, target, start, stop, stream, mr);
}

std::unique_ptr<column> rfind(strings_column_view const& strings,
std::unique_ptr<column> rfind(strings_column_view const& input,
string_scalar const& target,
size_type start,
size_type stop,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto pfn = [] __device__(
string_view d_string, string_view d_target, size_type start, size_type stop) {
size_type length = d_string.length();
size_type begin = (start > length) ? length : start;
size_type end = (stop < 0) || (stop > length) ? length : stop;
if (d_target.empty()) return start > length ? -1 : end;
return d_string.rfind(d_target, begin, end - begin);
};

return find_fn(strings, target, start, stop, pfn, stream, mr);
return find_fn<false>(input, target, start, stop, stream, mr);
}

} // namespace detail
Expand Down Expand Up @@ -167,17 +271,6 @@ std::unique_ptr<column> rfind(strings_column_view const& strings,
namespace detail {
namespace {

/**
* @brief Threshold to decide on using string or warp parallel functions.
*
* If the average byte length of a string in a column exceeds this value then
* the warp-parallel `contains_warp_fn` function is used.
* Otherwise, the string-parallel function in `contains_fn` is used.
*
* This is only used for the scalar version of `contains()` right now.
*/
constexpr size_type AVG_CHAR_BYTES_THRESHOLD = 64;

/**
* @brief Check if `d_target` appears in a row in `d_strings`.
*
Expand Down Expand Up @@ -370,7 +463,8 @@ std::unique_ptr<column> contains(strings_column_view const& input,
rmm::mr::device_memory_resource* mr)
{
// use warp parallel when the average string width is greater than the threshold
if (!input.is_empty() && ((input.chars_size() / input.size()) > AVG_CHAR_BYTES_THRESHOLD)) {
if ((input.null_count() < input.size()) &&
((input.chars_size() / input.size()) > AVG_CHAR_BYTES_THRESHOLD)) {
return contains_warp_parallel(input, target, stream, mr);
}

Expand Down
Loading