Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cudf::strings::extract_all API #9909

Merged
merged 11 commits into from
Jan 5, 2022
3 changes: 2 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,8 @@ add_library(
src/strings/copying/concatenate.cu
src/strings/copying/copying.cu
src/strings/copying/shift.cu
src/strings/extract.cu
src/strings/extract/extract.cu
src/strings/extract/extract_all.cu
src/strings/filling/fill.cu
src/strings/filter_chars.cu
src/strings/findall.cu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ namespace cudf {
namespace strings {
namespace detail {

/**
* @brief Basic type expected for iterators passed to `make_strings_column` that represent string
* data in device memory.
*/
using string_index_pair = thrust::pair<const char*, size_type>;

/**
* @brief Average string byte-length threshold for deciding character-level
* vs. row-level parallel algorithm.
Expand Down Expand Up @@ -64,8 +70,6 @@ std::unique_ptr<column> make_strings_column(IndexPairIterator begin,
size_type strings_count = thrust::distance(begin, end);
if (strings_count == 0) return make_empty_column(type_id::STRING);

using string_index_pair = thrust::pair<const char*, size_type>;

// check total size is not too large for cudf column
auto size_checker = [] __device__(string_index_pair const& item) {
return (item.first != nullptr) ? item.second : 0;
Expand Down
50 changes: 42 additions & 8 deletions cpp/include/cudf/strings/extract.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,20 +27,21 @@ namespace strings {
*/

/**
* @brief Returns a vector of strings columns for each matching group specified in the given regular
* expression pattern.
* @brief Returns a table of strings columns where each column corresponds to the matching
* group specified in the given regular expression pattern.
*
* All the strings for the first group will go in the first output column; the second group
* go in the second column and so on. Null entries are added if the string does match.
* go in the second column and so on. Null entries are added to the columns in row `i` if
* the string at row `i` does not match.
*
* Any null string entries return corresponding null output column entries.
*
* @code{.pseudo}
* Example:
* s = ["a1","b2","c3"]
* r = extract(s,"([ab])(\\d)")
* r is now [["a","b",null],
* ["1","2",null]]
* s = ["a1", "b2", "c3"]
* r = extract(s, "([ab])(\\d)")
* r is now [ ["a", "b", null],
* ["1", "2", null] ]
* @endcode
*
* See the @ref md_regex "Regex Features" page for details on patterns supported by this API.
Expand All @@ -55,6 +56,39 @@ std::unique_ptr<table> extract(
std::string const& pattern,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Returns a lists column of strings where each string column row corresponds to the
* matching group specified in the given regular expression pattern.
*
* All the matching groups for the first row will go in the first row output column; the second
* row results will go into the second row output column and so on.
*
* A null output row will result if the corresponding input string row does not match or
* that input row is null.
*
* @code{.pseudo}
* Example:
* s = ["a1 b4", "b2", "c3 a5", "b", null]
* r = extract_all(s,"([ab])(\\d)")
* r is now [ ["a", "1", "b", "4"],
* ["b", "2"],
* ["a", "5"],
* null,
* null ]
* @endcode
*
* See the @ref md_regex "Regex Features" page for details on patterns supported by this API.
*
* @param strings Strings instance for this operation.
* @param pattern The regular expression pattern with group indicators.
* @param mr Device memory resource used to allocate any returned device memory.
* @return Lists column containing strings columns extracted from the input column.
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
*/
std::unique_ptr<column> extract_all(
strings_column_view const& strings,
std::string const& pattern,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of doxygen group
} // namespace strings
} // namespace cudf
105 changes: 105 additions & 0 deletions cpp/src/strings/count_matches.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/strings/string_view.cuh>

#include <strings/regex/regex.cuh>
davidwendt marked this conversation as resolved.
Show resolved Hide resolved

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/transform.h>

namespace cudf {
namespace strings {
namespace detail {

/**
* @brief Functor counts the total matches to the given regex in each string.
*/
template <int stack_size>
struct count_matches_fn {
column_device_view const d_strings;
reprog_device prog;

__device__ size_type operator()(size_type idx)
{
if (d_strings.is_null(idx)) { return 0; }
size_type count = 0;
auto const d_str = d_strings.element<string_view>(idx);

int32_t begin = 0;
int32_t end = d_str.length();
while ((begin < end) && (prog.find<stack_size>(idx, d_str, begin, end) > 0)) {
++count;
begin = end;
end = d_str.length();
}
return count;
}
};

/**
* @brief Returns a column of regex match counts for each string in the given column.
*
* A null entry will result in a zero count for that output row.
*
* @param d_strings Device view of the input strings column.
* @param d_prog Regex instance to evaluate on each string.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory.
*/
std::unique_ptr<column> count_matches(
column_device_view const& d_strings,
reprog_device const& d_prog,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
// Create output column
auto counts = make_numeric_column(
data_type{type_id::INT32}, d_strings.size() + 1, mask_state::UNALLOCATED, stream, mr);
auto d_counts = counts->mutable_view().data<offset_type>();

auto begin = thrust::make_counting_iterator<size_type>(0);
auto end = thrust::make_counting_iterator<size_type>(d_strings.size());

// Count matches
auto const regex_insts = d_prog.insts_counts();
if (regex_insts <= RX_SMALL_INSTS) {
count_matches_fn<RX_STACK_SMALL> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
} else if (regex_insts <= RX_MEDIUM_INSTS) {
count_matches_fn<RX_STACK_MEDIUM> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
} else if (regex_insts <= RX_LARGE_INSTS) {
count_matches_fn<RX_STACK_LARGE> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
} else {
count_matches_fn<RX_STACK_ANY> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
}

return counts;
}

} // namespace detail
} // namespace strings
} // namespace cudf
File renamed without changes.
Loading