Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Port NVStrings regex contains ops #3292

Merged
merged 58 commits into from
Jan 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
aba783a
Merge branch 'port-nvs-is-chars-types' into port-nvs-regex-contains
davidwendt Nov 4, 2019
14978a4
initial regex code
davidwendt Nov 4, 2019
7c966f6
Merge branch 'port-nvs-is-chars-types' into port-nvs-regex-contains
davidwendt Nov 4, 2019
94434ae
initial contains_re api
davidwendt Nov 4, 2019
4f3eacf
add contains_re test
davidwendt Nov 4, 2019
f6becfb
fix merge conflict
davidwendt Nov 4, 2019
b3110bb
Merge branch 'port-nvs-is-chars-types' into port-nvs-regex-contains
davidwendt Nov 5, 2019
e63c29f
updated changelog
davidwendt Nov 5, 2019
8eba349
add matches_re api
davidwendt Nov 5, 2019
57a50da
Merge branch 'branch-0.11' into port-nvs-regex-contains
davidwendt Nov 6, 2019
4046dd1
Merge branch 'port-nvs-is-chars-types' into port-nvs-regex-contains
davidwendt Nov 6, 2019
93c4be4
place-holder count_re api
davidwendt Nov 6, 2019
a745489
fix merge conflicts
davidwendt Nov 8, 2019
96bc112
added count_re API
davidwendt Nov 8, 2019
4863dcc
fix merge conflicts
davidwendt Nov 11, 2019
2a7bf25
improve Reprog scope
davidwendt Nov 11, 2019
98df936
fix merge conflicts
davidwendt Nov 13, 2019
64a73a2
factored out alloc-relist
davidwendt Nov 13, 2019
819e132
fix merge conflicts
davidwendt Nov 15, 2019
bf525d0
moved string-to-char32 to internal namespace utility
davidwendt Nov 15, 2019
b19bf9d
fix merge conflicts
davidwendt Nov 15, 2019
6157b20
update comments
davidwendt Nov 15, 2019
385cd70
remove unneeded utility change
davidwendt Nov 15, 2019
4f8bf3b
Merge branch 'branch-0.11' into port-nvs-regex-contains
davidwendt Nov 15, 2019
f95db96
fix merge conflicts
davidwendt Nov 18, 2019
fda83f4
Merge branch 'port-nvs-regex-contains' of github.com:davidwendt/cudf …
davidwendt Nov 18, 2019
68ced06
Merge branch 'branch-0.11' into port-nvs-regex-contains
davidwendt Nov 18, 2019
a172d29
fix merge errors
davidwendt Nov 18, 2019
7a2358b
fix merge conflicts
davidwendt Nov 19, 2019
24e3891
correct merge mistake
davidwendt Nov 19, 2019
8a58918
Merge branch 'branch-0.11' into port-nvs-regex-contains
davidwendt Nov 19, 2019
44d89e9
fix merge conflicts
davidwendt Nov 20, 2019
8978c7a
fix comments
davidwendt Nov 20, 2019
19ee203
Merge branch 'branch-0.11' into port-nvs-regex-contains
davidwendt Nov 20, 2019
6e2249c
Merge branch 'branch-0.11' into port-nvs-regex-contains
davidwendt Nov 22, 2019
ac8b46c
change test from .cu to .cpp
davidwendt Nov 22, 2019
a4bb3c6
Merge branch 'branch-0.11' into port-nvs-regex-contains
davidwendt Nov 25, 2019
bf612d4
fix include header
davidwendt Nov 25, 2019
96d89f7
Merge branch 'branch-0.11' into port-nvs-regex-contains
davidwendt Dec 2, 2019
6b60fdf
Merge branch 'branch-0.11' into port-nvs-regex-contains
davidwendt Dec 2, 2019
826d072
fix merge conflict
davidwendt Dec 3, 2019
66dcbc7
Merge branch 'branch-0.11' into port-nvs-regex-contains
davidwendt Dec 4, 2019
8d08dfa
lowercase class names
davidwendt Dec 4, 2019
d80e869
Merge branch 'branch-0.11' into port-nvs-regex-contains
davidwendt Dec 5, 2019
daa7b1b
fix class names; update comments
davidwendt Dec 5, 2019
482f0c0
Merge branch 'branch-0.11' into port-nvs-regex-contains
davidwendt Dec 5, 2019
cece52a
Merge branch 'branch-0.12' into port-nvs-regex-contains
davidwendt Dec 6, 2019
484c5d7
Merge branch 'branch-0.12' into port-nvs-regex-contains
davidwendt Dec 10, 2019
6ed8fb6
Merge branch 'branch-0.12' into port-nvs-regex-contains
davidwendt Dec 11, 2019
922c175
change some class names
davidwendt Dec 11, 2019
f854585
Merge branch 'branch-0.12' into port-nvs-regex-contains
davidwendt Dec 17, 2019
3e35d5d
update for-loop with any-of; also style changes
davidwendt Dec 17, 2019
6796765
initialize uninitialized variables
davidwendt Dec 17, 2019
25b0f9d
Merge branch 'branch-0.12' into port-nvs-regex-contains
davidwendt Dec 20, 2019
22cf14f
Merge branch 'branch-0.12' into port-nvs-regex-contains
davidwendt Jan 2, 2020
e81ac81
Merge branch 'branch-0.12' into port-nvs-regex-contains
davidwendt Jan 7, 2020
2b07837
update cast in contains_test.cpp
davidwendt Jan 7, 2020
4e295fe
Merge branch 'branch-0.12' into port-nvs-regex-contains
davidwendt Jan 9, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@
- PR #3314 Drop `cython` from run requirements
- PR #3301 Add tests for empty column wrapper.
- PR #3294 Update to arrow-cpp and pyarrow 0.15.1
- PR #3292 Port NVStrings regex contains function
- PR #3310 Add `row_hasher` and `element_hasher` utilities
- PR #3272 Support non-default streams when creating/destroying hash maps
- PR #3286 Clean up the starter code on README
Expand Down
3 changes: 3 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ add_library(cudf
src/strings/case.cu
src/strings/char_types/char_types.cu
src/strings/combine.cu
src/strings/contains.cu
src/strings/convert/convert_booleans.cu
src/strings/convert/convert_datetime.cu
src/strings/convert/convert_floats.cu
Expand All @@ -562,6 +563,8 @@ add_library(cudf
src/strings/find.cu
src/strings/find_multiple.cu
src/strings/padding.cu
src/strings/regex/regcomp.cpp
src/strings/regex/regexec.cu
src/strings/replace/replace.cu
src/strings/sorting/sorting.cu
src/strings/split/split.cu
Expand Down
90 changes: 90 additions & 0 deletions cpp/include/cudf/strings/contains.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cudf/strings/strings_column_view.hpp>
#include <cudf/column/column.hpp>

namespace cudf
{
namespace strings
{

/**
* @brief Returns a boolean column identifying rows which
* match the given regex pattern.
*
* The pattern may match at any position within a string, as the
* example below shows for "def456".
*
* ```
* s = ["abc","123","def456"]
* r = contains_re(s,"\\d+")
* r is now [false, true, true]
* ```
*
* Any null string entries return corresponding null output column entries.
*
* @param strings Strings instance for this operation.
* @param pattern Regex pattern to match to each string.
* @param mr Resource for allocating device memory.
* @return New column of boolean results for each string.
*/
std::unique_ptr<column> contains_re( strings_column_view const& strings,
std::string const& pattern,
rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource());

/**
* @brief Returns a boolean column identifying rows which
* match the given regex pattern, but only at the beginning of the string.
*
* ```
* s = ["abc","123","def456"]
* r = matches_re(s,"\\d+")
* r is now [false, true, false]
* ```
*
* Any null string entries return corresponding null output column entries.
*
* @param strings Strings instance for this operation.
* @param pattern Regex pattern to match to each string.
* @param mr Resource for allocating device memory.
* @return New column of boolean results for each string.
*/
std::unique_ptr<column> matches_re( strings_column_view const& strings,
std::string const& pattern,
rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource());

/**
* @brief Returns the number of times the given regex pattern
* matches in each string.
*
* ```
* s = ["abc","123","def45"]
* r = count_re(s,"\\d")
* r is now [0, 3, 2]
* ```
*
* Any null string entries return corresponding null output column entries.
*
* @param strings Strings instance for this operation.
* @param pattern Regex pattern to match within each string.
* @param mr Resource for allocating device memory.
* @return New INT32 column with counts for each string.
*/
std::unique_ptr<column> count_re( strings_column_view const& strings,
std::string const& pattern,
rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource());

} // namespace strings
} // namespace cudf
17 changes: 8 additions & 9 deletions cpp/src/strings/char_types/is_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@

//
// 8-bit flag for each code-point.
// Flags for each character are defined in char_flags.h
//
// NOTE(review): the next eight definitions look like the removed side of this
// diff hunk (the hunk header reports 8 additions & 9 deletions) -- confirm they
// are absent from the merged file, because every macro is defined again below.
// Their argument is not parenthesized, so an expression argument such as
// IS_ALPHA(a | b) would expand to (a | b & 8).
#define IS_DECIMAL(x) (x & 1)
#define IS_NUMERIC(x) (x & 2)
#define IS_DIGIT(x) (x & 4)
#define IS_ALPHA(x) (x & 8)
#define IS_ALPHANUM(x) (x & 15)
#define IS_SPACE(x) (x & 16)
#define IS_UPPER(x) (x & 32)
#define IS_LOWER(x) (x & 64)
// Flag tests with the argument and the whole expansion parenthesized.
// Each macro evaluates to the masked bits (non-zero when set, 0 otherwise),
// not to a bool:
//   bit 0 = decimal, bit 1 = numeric, bit 2 = digit, bit 3 = alpha,
//   bit 4 = space,   bit 5 = upper,   bit 6 = lower.
// IS_ALPHANUM masks with 0x0F, so it is non-zero when any of bits 0-3
// (decimal, numeric, digit, alpha) is set.
#define IS_DECIMAL(x) ((x) & (1 << 0))
#define IS_NUMERIC(x) ((x) & (1 << 1))
#define IS_DIGIT(x) ((x) & (1 << 2))
#define IS_ALPHA(x) ((x) & (1 << 3))
#define IS_SPACE(x) ((x) & (1 << 4))
#define IS_UPPER(x) ((x) & (1 << 5))
#define IS_LOWER(x) ((x) & (1 << 6))
#define IS_ALPHANUM(x) ((x) & (0x0F))
244 changes: 244 additions & 0 deletions cpp/src/strings/contains.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf/null_mask.hpp>
#include <cudf/column/column.hpp>
#include <cudf/column/column_factories.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/char_types/char_types.hpp>
#include <cudf/strings/contains.hpp>
#include <cudf/wrappers/bool.hpp>
#include <strings/utilities.hpp>
#include <strings/regex/regex.cuh>


namespace cudf
{
namespace strings
{
namespace detail
{
namespace
{

/**
* @brief This functor handles both contains_re and match_re to minimize the number
* of regex calls to find() to be inlined greatly reducing compile time.
*
* The stack is used to keep progress on evaluating the regex instructions on each string.
* So the size of the stack is in proportion to the number of instructions in the given regex pattern.
*
* There are three call types based on the number of regex instructions in the given pattern.
* Small to medium instruction lengths can use the stack effectively though smaller executes faster.
* Longer patterns require global memory.
*
*/
template<size_t stack_size>
struct contains_fn
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
{
reprog_device prog;
column_device_view d_strings;
bool bmatch{false}; // do not make this a template parameter to keep compile times down

__device__ cudf::experimental::bool8 operator()(size_type idx)
{
if( d_strings.is_null(idx) )
return 0;
u_char data1[stack_size], data2[stack_size];
prog.set_stack_mem(data1,data2);
string_view d_str = d_strings.element<string_view>(idx);
int32_t begin = 0;
int32_t end = bmatch ? 1 : d_str.length(); // 1=match only the beginning of the string
return static_cast<experimental::bool8>(prog.find(idx,d_str,begin,end));
}
};

/**
 * @brief Shared implementation for contains_re and matches_re.
 *
 * Compiles the pattern into a device regex program, creates a BOOL8 output
 * column whose null mask is a copy of the input's, and fills it by running
 * contains_fn over every row index.
 *
 * @param strings Strings column to search.
 * @param pattern Regex pattern to evaluate against each string.
 * @param beginning_only Passed to contains_fn as its match-at-beginning flag.
 * @param mr Resource for allocating device memory.
 * @param stream CUDA stream on which the work is issued.
 * @return New BOOL8 column with one result per input row.
 */
std::unique_ptr<column> contains_util( strings_column_view const& strings,
                                       std::string const& pattern,
                                       bool beginning_only = false,
                                       rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource(),
                                       cudaStream_t stream = 0)
{
    auto num_rows = strings.size();
    auto strings_column = column_device_view::create(strings.parent(),stream);
    auto d_column = *strings_column;

    // build the device-side regex program from the pattern
    auto prog = reprog_device::create(pattern,get_character_flags_table(),num_rows,stream);
    auto d_prog = *prog;

    // allocate the output; its null mask is copied from the input column
    auto results = make_numeric_column( data_type{BOOL8}, num_rows,
                                        copy_bitmask( strings.parent(), stream, mr),
                                        strings.null_count(), stream, mr);
    auto d_results = results->mutable_view().data<cudf::experimental::bool8>();

    // one functor call per row index in [0, num_rows)
    auto execpol = rmm::exec_policy(stream);
    auto const first_row = thrust::make_counting_iterator<size_type>(0);
    auto const last_row = thrust::make_counting_iterator<size_type>(num_rows);

    // Choose the scratch-buffer size from the instruction count. Programs above
    // MAX_STACK_INSTS also take the small instantiation.
    int const regex_insts = d_prog.insts_counts();
    bool const use_small_stack = (regex_insts <= RX_SMALL_INSTS) || (regex_insts > MAX_STACK_INSTS);
    if( use_small_stack )
    {
        thrust::transform(execpol->on(stream), first_row, last_row, d_results,
                          contains_fn<RX_STACK_SMALL>{d_prog, d_column, beginning_only} );
    }
    else if( regex_insts <= RX_MEDIUM_INSTS )
    {
        thrust::transform(execpol->on(stream), first_row, last_row, d_results,
                          contains_fn<RX_STACK_MEDIUM>{d_prog, d_column, beginning_only} );
    }
    else
    {
        thrust::transform(execpol->on(stream), first_row, last_row, d_results,
                          contains_fn<RX_STACK_LARGE>{d_prog, d_column, beginning_only} );
    }

    results->set_null_count(strings.null_count());
    return results;
}

} // namespace

/**
* @brief Stream-aware form of contains_re.
*
* Forwards to contains_util with beginning_only set to false.
*
* @param strings Strings instance for this operation.
* @param pattern Regex pattern to match to each string.
* @param mr Resource for allocating device memory.
* @param stream CUDA stream passed through to contains_util.
* @return The column produced by contains_util.
*/
std::unique_ptr<column> contains_re( strings_column_view const& strings,
std::string const& pattern,
rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource(),
cudaStream_t stream = 0)
{
return contains_util(strings, pattern, false, mr, stream);
}

/**
* @brief Stream-aware form of matches_re.
*
* Forwards to contains_util with beginning_only set to true.
*
* @param strings Strings instance for this operation.
* @param pattern Regex pattern to match to each string.
* @param mr Resource for allocating device memory.
* @param stream CUDA stream passed through to contains_util.
* @return The column produced by contains_util.
*/
std::unique_ptr<column> matches_re( strings_column_view const& strings,
std::string const& pattern,
rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource(),
cudaStream_t stream = 0)
{
return contains_util(strings, pattern, true, mr, stream);
}

} // namespace detail

// external APIs

// Public entry point: delegates to detail::contains_re without passing a
// stream, so the detail function's default stream argument (0) is used.
std::unique_ptr<column> contains_re( strings_column_view const& strings,
std::string const& pattern,
rmm::mr::device_memory_resource* mr)
{
return detail::contains_re(strings, pattern, mr);
}

// Public entry point: delegates to detail::matches_re without passing a
// stream, so the detail function's default stream argument (0) is used.
std::unique_ptr<column> matches_re( strings_column_view const& strings,
std::string const& pattern,
rmm::mr::device_memory_resource* mr)
{
return detail::matches_re(strings, pattern, mr);
}

namespace detail
{

namespace
{

/**
 * @brief This counts the number of times the regex pattern matches in each string.
 *
 * Members are listed in the order used by the aggregate initialization at the
 * call sites: the compiled regex program, then the strings column.
 *
 * @tparam stack_size Number of bytes in each of the two per-call scratch buffers
 *                    handed to the regex program.
 */
template<size_t stack_size>
struct count_fn
{
    reprog_device prog;           // compiled regex program
    column_device_view d_strings; // strings column being searched

    /**
     * @brief Counts the matches found in the string at row `idx`.
     *
     * @param idx Row index within `d_strings`.
     * @return Number of matches found; 0 for a null row.
     */
    __device__ int32_t operator()(size_type idx)
    {
        // null rows are skipped before any scratch memory is set up
        if( d_strings.is_null(idx) )
            return 0;

        // two scratch buffers sized by the template parameter
        u_char data1[stack_size], data2[stack_size];
        prog.set_stack_mem(data1,data2);

        string_view d_str = d_strings.element<string_view>(idx);
        size_type const nchars = d_str.length();
        int32_t find_count = 0;
        size_type begin = 0;
        while( begin <= nchars )
        {
            // begin/end are passed to find(), which presumably rewrites them
            // to the bounds of the match it found -- TODO confirm in regex.cuh
            auto end = nchars;
            if( prog.find(idx,d_str,begin,end) <= 0 )
                break;
            ++find_count;
            // resume at `end` when it lies past `begin`; otherwise step one
            // character forward so the loop always advances
            begin = end > begin ? end : begin + 1;
        }
        return find_count;
    }
};

}

/**
 * @brief Stream-aware implementation of count_re.
 *
 * Compiles the pattern into a device regex program, creates an INT32 output
 * column whose null mask is a copy of the input's, and fills it by running
 * count_fn over every row index.
 *
 * @param strings Strings column to search.
 * @param pattern Regex pattern to count within each string.
 * @param mr Resource for allocating device memory.
 * @param stream CUDA stream on which the work is issued.
 * @return New INT32 column with one count per input row.
 */
std::unique_ptr<column> count_re( strings_column_view const& strings,
                                  std::string const& pattern,
                                  rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource(),
                                  cudaStream_t stream = 0)
{
    auto num_rows = strings.size();
    auto strings_column = column_device_view::create(strings.parent(),stream);
    auto d_column = *strings_column;

    // build the device-side regex program from the pattern
    auto prog = reprog_device::create(pattern,get_character_flags_table(),num_rows,stream);
    auto d_prog = *prog;

    // allocate the output; its null mask is copied from the input column
    auto results = make_numeric_column( data_type{INT32}, num_rows,
                                        copy_bitmask( strings.parent(), stream, mr),
                                        strings.null_count(), stream, mr);
    auto d_results = results->mutable_view().data<int32_t>();

    // one functor call per row index in [0, num_rows)
    auto execpol = rmm::exec_policy(stream);
    auto const first_row = thrust::make_counting_iterator<size_type>(0);
    auto const last_row = thrust::make_counting_iterator<size_type>(num_rows);

    // Choose the scratch-buffer size from the instruction count. Programs above
    // MAX_STACK_INSTS also take the small instantiation.
    int const regex_insts = d_prog.insts_counts();
    bool const use_small_stack = (regex_insts <= RX_SMALL_INSTS) || (regex_insts > MAX_STACK_INSTS);
    if( use_small_stack )
    {
        thrust::transform(execpol->on(stream), first_row, last_row, d_results,
                          count_fn<RX_STACK_SMALL>{d_prog, d_column} );
    }
    else if( regex_insts <= RX_MEDIUM_INSTS )
    {
        thrust::transform(execpol->on(stream), first_row, last_row, d_results,
                          count_fn<RX_STACK_MEDIUM>{d_prog, d_column} );
    }
    else
    {
        thrust::transform(execpol->on(stream), first_row, last_row, d_results,
                          count_fn<RX_STACK_LARGE>{d_prog, d_column} );
    }

    results->set_null_count(strings.null_count());
    return results;
}

} // namespace detail

// external API

// Public entry point: delegates to detail::count_re without passing a
// stream, so the detail function's default stream argument (0) is used.
std::unique_ptr<column> count_re( strings_column_view const& strings,
std::string const& pattern,
rmm::mr::device_memory_resource* mr)
{
return detail::count_re(strings, pattern, mr);
}

} // namespace strings
} // namespace cudf
Loading