diff --git a/cpp/src/strings/contains.cu b/cpp/src/strings/contains.cu index 773430953c9..c4ffa7f0fb1 100644 --- a/cpp/src/strings/contains.cu +++ b/cpp/src/strings/contains.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include #include @@ -114,6 +115,26 @@ std::unique_ptr matches_re( return regex_dispatcher(*d_prog, contains_dispatch_fn{*d_prog, true}, input, stream, mr); } +std::unique_ptr count_re(strings_column_view const& input, + std::string const& pattern, + regex_flags const flags, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + // compile regex into device object + auto d_prog = + reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); + + auto const d_strings = column_device_view::create(input.parent(), stream); + + auto result = count_matches(*d_strings, *d_prog, input.size(), stream, mr); + if (input.has_nulls()) { + result->set_null_mask(cudf::detail::copy_bitmask(input.parent(), stream, mr), + input.null_count()); + } + return result; +} + } // namespace detail // external APIs @@ -136,78 +157,6 @@ std::unique_ptr matches_re(strings_column_view const& strings, return detail::matches_re(strings, pattern, flags, rmm::cuda_stream_default, mr); } -namespace detail { -namespace { -/** - * @brief This counts the number of times the regex pattern matches in each string. - */ -template -struct count_fn { - reprog_device prog; - column_device_view const d_strings; - - __device__ int32_t operator()(unsigned int idx) - { - if (d_strings.is_null(idx)) return 0; - auto const d_str = d_strings.element(idx); - auto const nchars = d_str.length(); - int32_t find_count = 0; - int32_t begin = 0; - while (begin < nchars) { - auto end = static_cast(nchars); - if (prog.find(idx, d_str, begin, end) <= 0) break; - ++find_count; - begin = end > begin ? end : begin + 1; - } - return find_count; - } -}; - -struct count_dispatch_fn { - reprog_device d_prog; - - template - std::unique_ptr operator()(strings_column_view const& input, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - auto results = make_numeric_column(data_type{type_id::INT32}, - input.size(), - cudf::detail::copy_bitmask(input.parent(), stream, mr), - input.null_count(), - stream, - mr); - - auto const d_strings = column_device_view::create(input.parent(), stream); - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(input.size()), - results->mutable_view().data(), - count_fn{d_prog, *d_strings}); - return results; - } -}; - -} // namespace - -std::unique_ptr count_re( - strings_column_view const& input, - std::string const& pattern, - regex_flags const flags, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) -{ - // compile regex into device object - auto d_prog = - reprog_device::create(pattern, flags, get_character_flags_table(), input.size(), stream); - - return regex_dispatcher(*d_prog, count_dispatch_fn{*d_prog}, input, stream, mr); -} - -} // namespace detail - -// external API - std::unique_ptr count_re(strings_column_view const& strings, std::string const& pattern, regex_flags const flags, diff --git a/cpp/src/strings/count_matches.cu b/cpp/src/strings/count_matches.cu index 5057df7f92b..a850315dfec 100644 --- a/cpp/src/strings/count_matches.cu +++ b/cpp/src/strings/count_matches.cu @@ -43,15 +43,16 @@ struct count_matches_fn { __device__ size_type operator()(size_type idx) { if (d_strings.is_null(idx)) { return 0; } - size_type count = 0; - auto const d_str = d_strings.element(idx); + size_type count = 0; + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); int32_t begin = 0; - int32_t end = d_str.length(); + int32_t end = nchars; while ((begin < end) && (prog.find(idx, d_str, begin, end) > 0)) { ++count; begin = end + (begin == end); - end = d_str.length(); + end = nchars; } return count; } @@ -62,11 +63,14 @@ struct count_dispatch_fn { template std::unique_ptr operator()(column_device_view const& d_strings, + size_type output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { + assert(output_size >= d_strings.size() and "Unexpected output size"); + auto results = make_numeric_column( - data_type{type_id::INT32}, d_strings.size() + 1, mask_state::UNALLOCATED, stream, mr); + data_type{type_id::INT32}, output_size, mask_state::UNALLOCATED, stream, mr); thrust::transform(rmm::exec_policy(stream), thrust::make_counting_iterator(0), @@ -80,21 +84,15 @@ struct count_dispatch_fn { } // namespace /** - * @brief Returns a column of regex match counts for each string in the given column. - * - * A null entry will result in a zero count for that output row. - * - * @param d_strings Device view of the input strings column. - * @param d_prog Regex instance to evaluate on each string. - * @param stream CUDA stream used for device memory operations and kernel launches. - * @param mr Device memory resource used to allocate the returned column's device memory. + * @copydoc cudf::strings::detail::count_matches */ std::unique_ptr count_matches(column_device_view const& d_strings, reprog_device const& d_prog, + size_type output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - return regex_dispatcher(d_prog, count_dispatch_fn{d_prog}, d_strings, stream, mr); + return regex_dispatcher(d_prog, count_dispatch_fn{d_prog}, d_strings, output_size, stream, mr); } } // namespace detail diff --git a/cpp/src/strings/count_matches.hpp b/cpp/src/strings/count_matches.hpp index 1339f2b1ebd..efff3958c65 100644 --- a/cpp/src/strings/count_matches.hpp +++ b/cpp/src/strings/count_matches.hpp @@ -36,12 +36,14 @@ class reprog_device; * * @param d_strings Device view of the input strings column. * @param d_prog Regex instance to evaluate on each string. + * @param output_size Number of rows for the output column. * @param stream CUDA stream used for device memory operations and kernel launches. * @param mr Device memory resource used to allocate the returned column's device memory. */ std::unique_ptr count_matches( column_device_view const& d_strings, reprog_device const& d_prog, + size_type output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); diff --git a/cpp/src/strings/extract/extract_all.cu b/cpp/src/strings/extract/extract_all.cu index fd2d280c5bc..7dce369a24f 100644 --- a/cpp/src/strings/extract/extract_all.cu +++ b/cpp/src/strings/extract/extract_all.cu @@ -137,7 +137,7 @@ std::unique_ptr extract_all_record( // Get the match counts for each string. // This column will become the output lists child offsets column. - auto offsets = count_matches(*d_strings, *d_prog, stream, mr); + auto offsets = count_matches(*d_strings, *d_prog, strings_count + 1, stream, mr); auto d_offsets = offsets->mutable_view().data(); // Compute null output rows diff --git a/cpp/src/strings/search/findall.cu b/cpp/src/strings/search/findall.cu index 2f35a7e5ef5..323ad2cbc09 100644 --- a/cpp/src/strings/search/findall.cu +++ b/cpp/src/strings/search/findall.cu @@ -119,7 +119,7 @@ std::unique_ptr findall(strings_column_view const& input, reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); auto const d_strings = column_device_view::create(input.parent(), stream); - auto find_counts = count_matches(*d_strings, *d_prog, stream); + auto find_counts = count_matches(*d_strings, *d_prog, strings_count + 1, stream); auto d_find_counts = find_counts->view().data(); size_type const columns_count = thrust::reduce( diff --git a/cpp/src/strings/search/findall_record.cu b/cpp/src/strings/search/findall_record.cu index 7fb5982b307..46155bd7cf5 100644 --- a/cpp/src/strings/search/findall_record.cu +++ b/cpp/src/strings/search/findall_record.cu @@ -117,7 +117,7 @@ std::unique_ptr findall_record( reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); // Create lists offsets column - auto offsets = count_matches(*d_strings, *d_prog, stream, mr); + auto offsets = count_matches(*d_strings, *d_prog, strings_count + 1, stream, mr); auto d_offsets = offsets->mutable_view().data(); // Convert counts into offsets diff --git a/cpp/src/strings/split/split_re.cu b/cpp/src/strings/split/split_re.cu index 286492e53c5..3ec6df058c6 100644 --- a/cpp/src/strings/split/split_re.cu +++ b/cpp/src/strings/split/split_re.cu @@ -225,7 +225,7 @@ std::unique_ptr
split_re(strings_column_view const& input, auto d_strings = column_device_view::create(input.parent(), stream); // count the number of delimiters matched in each string - auto offsets = count_matches(*d_strings, *d_prog, stream, rmm::mr::get_current_device_resource()); + auto offsets = count_matches(*d_strings, *d_prog, strings_count + 1, stream); auto offsets_view = offsets->mutable_view(); auto d_offsets = offsets_view.data(); @@ -287,7 +287,7 @@ std::unique_ptr split_record_re(strings_column_view const& input, auto d_strings = column_device_view::create(input.parent(), stream); // count the number of delimiters matched in each string - auto offsets = count_matches(*d_strings, *d_prog, stream, mr); + auto offsets = count_matches(*d_strings, *d_prog, strings_count + 1, stream, mr); auto offsets_view = offsets->mutable_view(); // get the split tokens from the input column; this also converts the counts into offsets