Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor setting stack size in regex code #8358

Merged
merged 10 commits into from
Jun 7, 2021
30 changes: 19 additions & 11 deletions cpp/src/strings/contains.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -54,13 +54,11 @@ struct contains_fn {
__device__ bool operator()(size_type idx)
{
if (d_strings.is_null(idx)) return 0;
u_char data1[stack_size], data2[stack_size];
prog.set_stack_mem(data1, data2);
string_view d_str = d_strings.element<string_view>(idx);
int32_t begin = 0;
int32_t end = bmatch ? 1 // match only the beginning of the string;
: -1; // this handles empty strings too
return static_cast<bool>(prog.find(idx, d_str, begin, end));
return static_cast<bool>(prog.find<stack_size>(idx, d_str, begin, end));
}
};

Expand Down Expand Up @@ -91,7 +89,7 @@ std::unique_ptr<column> contains_util(

// fill the output column
int regex_insts = d_prog.insts_counts();
if ((regex_insts > MAX_STACK_INSTS) || (regex_insts <= RX_SMALL_INSTS))
if (regex_insts <= RX_SMALL_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
Expand All @@ -103,12 +101,18 @@ std::unique_ptr<column> contains_util(
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
contains_fn<RX_STACK_MEDIUM>{d_prog, d_column, beginning_only});
else
else if (regex_insts <= RX_LARGE_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
contains_fn<RX_STACK_LARGE>{d_prog, d_column, beginning_only});
else
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
contains_fn<RX_STACK_ANY>{d_prog, d_column, beginning_only});

results->set_null_count(strings.null_count());
return results;
Expand Down Expand Up @@ -166,16 +170,14 @@ struct count_fn {

__device__ int32_t operator()(unsigned int idx)
{
u_char data1[stack_size], data2[stack_size];
prog.set_stack_mem(data1, data2);
if (d_strings.is_null(idx)) return 0;
string_view d_str = d_strings.element<string_view>(idx);
auto const nchars = d_str.length();
int32_t find_count = 0;
int32_t begin = 0;
while (begin < nchars) {
auto end = static_cast<int32_t>(nchars);
if (prog.find(idx, d_str, begin, end) <= 0) break;
if (prog.find<stack_size>(idx, d_str, begin, end) <= 0) break;
++find_count;
begin = end > begin ? end : begin + 1;
}
Expand Down Expand Up @@ -210,7 +212,7 @@ std::unique_ptr<column> count_re(

// fill the output column
int regex_insts = d_prog.insts_counts();
if ((regex_insts > MAX_STACK_INSTS) || (regex_insts <= RX_SMALL_INSTS))
if (regex_insts <= RX_SMALL_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
Expand All @@ -222,12 +224,18 @@ std::unique_ptr<column> count_re(
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
count_fn<RX_STACK_MEDIUM>{d_prog, d_column});
else
else if (regex_insts <= RX_LARGE_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
count_fn<RX_STACK_LARGE>{d_prog, d_column});
else
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_results,
count_fn<RX_STACK_ANY>{d_prog, d_column});

results->set_null_count(strings.null_count());
return results;
Expand Down
16 changes: 10 additions & 6 deletions cpp/src/strings/extract.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,13 @@ struct extract_fn {

__device__ string_index_pair operator()(size_type idx)
{
u_char data1[stack_size], data2[stack_size];
prog.set_stack_mem(data1, data2);
if (d_strings.is_null(idx)) return string_index_pair{nullptr, 0};
string_view d_str = d_strings.element<string_view>(idx);
string_index_pair result{nullptr, 0};
int32_t begin = 0;
int32_t end = -1; // handles empty strings automatically
if ((prog.find(idx, d_str, begin, end) > 0) &&
(prog.extract(idx, d_str, begin, end, column_index) > 0)) {
if ((prog.find<stack_size>(idx, d_str, begin, end) > 0) &&
(prog.extract<stack_size>(idx, d_str, begin, end, column_index) > 0)) {
auto offset = d_str.byte_offset(begin);
// build index-pair
result = string_index_pair{d_str.data() + offset, d_str.byte_offset(end) - offset};
Expand Down Expand Up @@ -94,7 +92,7 @@ std::unique_ptr<table> extract(
for (int32_t column_index = 0; column_index < groups; ++column_index) {
rmm::device_uvector<string_index_pair> indices(strings_count, stream);

if ((regex_insts > MAX_STACK_INSTS) || (regex_insts <= RX_SMALL_INSTS))
if (regex_insts <= RX_SMALL_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
Expand All @@ -106,12 +104,18 @@ std::unique_ptr<table> extract(
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_MEDIUM>{d_prog, d_strings, column_index});
else
else if (regex_insts <= RX_LARGE_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_LARGE>{d_prog, d_strings, column_index});
else
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
extract_fn<RX_STACK_ANY>{d_prog, d_strings, column_index});

results.emplace_back(make_strings_column(indices, stream, mr));
}
Expand Down
30 changes: 20 additions & 10 deletions cpp/src/strings/findall.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace {
* @brief This functor handles extracting matched strings by applying the compiled regex pattern
* and creating string_index_pairs for all the substrings.
*/
template <size_t stack_size>
template <int stack_size>
struct findall_fn {
column_device_view const d_strings;
reprog_device prog;
Expand All @@ -64,17 +64,14 @@ struct findall_fn {
string_index_pair result{nullptr, 0};
if (d_strings.is_null(idx) || (d_counts && (column_index >= d_counts[idx])))
return findall_result{0, result};
u_char data1[stack_size];
u_char data2[stack_size];
prog.set_stack_mem(data1, data2);
string_view d_str = d_strings.element<string_view>(idx);
auto const nchars = d_str.length();
int32_t spos = 0;
int32_t epos = static_cast<int32_t>(nchars);
size_type column_count = 0;
while (spos <= nchars) {
if (prog.find(idx, d_str, spos, epos) <= 0) break; // no more matches found
if (column_count == column_index) break; // found our column
if (prog.find<stack_size>(idx, d_str, spos, epos) <= 0) break; // no more matches found
if (column_count == column_index) break; // found our column
spos = epos > spos ? epos : spos + 1;
epos = static_cast<int32_t>(nchars);
++column_count;
Expand Down Expand Up @@ -129,7 +126,7 @@ std::unique_ptr<table> findall_re(
rmm::device_uvector<size_type> find_counts(strings_count, stream);
auto d_find_counts = find_counts.data();

if ((regex_insts > MAX_STACK_INSTS) || (regex_insts <= RX_SMALL_INSTS))
if (regex_insts <= RX_SMALL_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
Expand All @@ -141,12 +138,18 @@ std::unique_ptr<table> findall_re(
thrust::make_counting_iterator<size_type>(strings_count),
d_find_counts,
findall_count_fn<RX_STACK_MEDIUM>{*d_strings, *d_prog});
else
else if (regex_insts <= RX_LARGE_INSTS)
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_find_counts,
findall_count_fn<RX_STACK_LARGE>{*d_strings, *d_prog});
else
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
d_find_counts,
findall_count_fn<RX_STACK_ANY>{*d_strings, *d_prog});

std::vector<std::unique_ptr<column>> results;

Expand All @@ -167,7 +170,7 @@ std::unique_ptr<table> findall_re(
for (int32_t column_index = 0; column_index < columns; ++column_index) {
rmm::device_uvector<string_index_pair> indices(strings_count, stream);

if ((regex_insts > MAX_STACK_INSTS) || (regex_insts <= RX_SMALL_INSTS))
if (regex_insts <= RX_SMALL_INSTS)
thrust::transform(
rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
Expand All @@ -181,13 +184,20 @@ std::unique_ptr<table> findall_re(
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
findall_fn<RX_STACK_MEDIUM>{*d_strings, *d_prog, column_index, d_find_counts});
else
else if (regex_insts <= RX_LARGE_INSTS)
thrust::transform(
rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
findall_fn<RX_STACK_LARGE>{*d_strings, *d_prog, column_index, d_find_counts});
else
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
indices.begin(),
findall_fn<RX_STACK_ANY>{*d_strings, *d_prog, column_index, d_find_counts});

//
results.emplace_back(make_strings_column(indices.begin(), indices.end(), stream, mr));
}
Expand Down
53 changes: 27 additions & 26 deletions cpp/src/strings/regex/regex.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,6 +33,24 @@ struct reljunk;
struct reinst;
class reprog;

constexpr int32_t RX_STACK_SMALL = 112; ///< fastest stack size
constexpr int32_t RX_STACK_MEDIUM = 1104; ///< faster stack size
constexpr int32_t RX_STACK_LARGE = 10128; ///< fast stack size
constexpr int32_t RX_STACK_ANY = 8; ///< slowest: uses global memory

/**
* @brief Mapping the number of instructions to device code stack memory size.
*
* ```
* 10128 ≈ 1000 instructions
* Formula is based on relist::data_size_for() calculation;
* Stack ≈ (8+2)*x + (x/8) = 10.125x < 11x where x is number of instructions
* ```
*/
constexpr int32_t RX_SMALL_INSTS = (RX_STACK_SMALL / 11);
constexpr int32_t RX_MEDIUM_INSTS = (RX_STACK_MEDIUM / 11);
constexpr int32_t RX_LARGE_INSTS = (RX_STACK_LARGE / 11);

/**
* @brief Regex class stored on the device and executed by reprog_device.
*
Expand Down Expand Up @@ -99,14 +117,7 @@ class reprog_device {
/**
* @brief Returns the number of regex groups found in the expression.
*/
int32_t group_counts() const { return _num_capturing_groups; }

/**
* @brief This sets up the memory used for keeping track of the regex progress.
*
* Call this for each string before calling find or extract.
*/
__device__ inline void set_stack_mem(u_char* s1, u_char* s2);
__host__ __device__ inline int32_t group_counts() const { return _num_capturing_groups; }

/**
* @brief Returns the regex instruction object for a given index.
Expand All @@ -126,6 +137,7 @@ class reprog_device {
/**
* @brief Does a find evaluation using the compiled expression on the given string.
*
* @tparam stack_size One of the `RX_STACK_` values based on the `insts_count`.
* @param idx The string index used for mapping the state memory for this string in global memory
* (if necessary).
* @param d_str The string to search.
Expand All @@ -135,6 +147,7 @@ class reprog_device {
* matching in the string.
* @return Returns 0 if no match is found.
*/
template <int stack_size>
__device__ inline int32_t find(int32_t idx,
string_view const& d_str,
int32_t& begin,
Expand All @@ -145,18 +158,20 @@ class reprog_device {
*
* This will find a specific match within the string when more than match occurs.
*
* @tparam stack_size One of the `RX_STACK_` values based on the `insts_count`.
* @param idx The string index used for mapping the state memory for this string in global memory
* (if necessary).
* @param d_str The string to search.
* @param[in,out] begin Position index to begin the search. If found, returns the position found
* in the string.
* @param[in,out] end Position index to end the search. If found, returns the last position
* matching in the string.
* @param column The specific instance to return if more than one match is found.
* @param group_id The specific instance to return if more than one match is found.
* @return Returns 0 if no match is found.
*/
template <int stack_size>
__device__ inline int32_t extract(
int32_t idx, string_view const& d_str, int32_t& begin, int32_t& end, int32_t column);
int32_t idx, string_view const& d_str, int32_t& begin, int32_t& end, int32_t group_id);

private:
int32_t _startinst_id, _num_capturing_groups;
Expand All @@ -166,8 +181,6 @@ class reprog_device {
int32_t* _startinst_ids{}; // array of start instruction ids
reclass_device* _classes{}; // array of regex classes
void* _relists_mem{}; // runtime relist memory for regexec
u_char* _stack_mem1{}; // memory for relist object 1
u_char* _stack_mem2{}; // memory for relist object 2

/**
* @brief Executes the regex pattern on the given string.
Expand All @@ -178,25 +191,13 @@ class reprog_device {
/**
* @brief Utility wrapper to setup state memory structures for calling regexec
*/
template <int stack_size>
__device__ inline int32_t call_regexec(
int32_t idx, string_view const& d_str, int32_t& begin, int32_t& end, int32_t groupid = 0);

reprog_device(reprog&); // must use create()
};

// 10128 ≈ 1000 instructions
// Formula is based on relist::data_size_for() calculation;
// Stack ≈ (8+2)*x + (x/8) = 10.125x < 11x where x is number of instructions
constexpr int32_t MAX_STACK_INSTS = 1000;

constexpr int32_t RX_STACK_SMALL = 112;
constexpr int32_t RX_STACK_MEDIUM = 1104;
constexpr int32_t RX_STACK_LARGE = 10128;

constexpr int32_t RX_SMALL_INSTS = (RX_STACK_SMALL / 11);
constexpr int32_t RX_MEDIUM_INSTS = (RX_STACK_MEDIUM / 11);
constexpr int32_t RX_LARGE_INSTS = (RX_STACK_LARGE / 11);

} // namespace detail
} // namespace strings
} // namespace cudf
Expand Down
Loading