Skip to content

Commit

Permalink
Add regex flags to strings findall functions (#10208)
Browse files Browse the repository at this point in the history
Add the `regex_flags` parameter to the strings `findall()` and `findall_record()` functions so that matching regex patterns is consistent with the other strings regex APIs.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Bradley Dice (https://github.com/bdice)

URL: #10208
  • Loading branch information
davidwendt authored Feb 8, 2022
1 parent 10faad9 commit bd98bfe
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 33 deletions.
15 changes: 10 additions & 5 deletions cpp/include/cudf/strings/findall.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
*/
#pragma once

#include <cudf/strings/regex/flags.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/table/table.hpp>

Expand Down Expand Up @@ -47,14 +48,16 @@ namespace strings {
*
* See the @ref md_regex "Regex Features" page for details on patterns supported by this API.
*
* @param strings Strings instance for this operation.
* @param input Strings instance for this operation.
* @param pattern Regex pattern to match within each string.
* @param flags Regex flags for interpreting special characters in the pattern.
* @param mr Device memory resource used to allocate the returned table's device memory.
* @return New table of strings columns.
*/
std::unique_ptr<table> findall(
strings_column_view const& strings,
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -77,14 +80,16 @@ std::unique_ptr<table> findall(
*
* See the @ref md_regex "Regex Features" page for details on patterns supported by this API.
*
* @param strings Strings instance for this operation.
* @param input Strings instance for this operation.
* @param pattern Regex pattern to match within each string.
* @param flags Regex flags for interpreting special characters in the pattern.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New lists column of strings.
*/
std::unique_ptr<column> findall_record(
strings_column_view const& strings,
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of doxygen group
Expand Down
18 changes: 10 additions & 8 deletions cpp/src/strings/search/findall.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -110,17 +110,18 @@ struct findall_count_fn : public findall_fn<stack_size> {

//
std::unique_ptr<table> findall(
strings_column_view const& strings,
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto const strings_count = strings.size();
auto const d_strings = column_device_view::create(strings.parent(), stream);
auto const strings_count = input.size();
auto const d_strings = column_device_view::create(input.parent(), stream);

auto const d_flags = detail::get_character_flags_table();
// compile regex into device object
auto const d_prog = reprog_device::create(pattern, d_flags, strings_count, stream);
auto const d_prog =
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);
auto const regex_insts = d_prog->insts_counts();

rmm::device_uvector<size_type> find_counts(strings_count, stream);
Expand Down Expand Up @@ -205,12 +206,13 @@ std::unique_ptr<table> findall(

// external API

std::unique_ptr<table> findall(strings_column_view const& strings,
std::unique_ptr<table> findall(strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::findall(strings, pattern, rmm::cuda_stream_default, mr);
return detail::findall(input, pattern, flags, rmm::cuda_stream_default, mr);
}

} // namespace strings
Expand Down
14 changes: 8 additions & 6 deletions cpp/src/strings/search/findall_record.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,18 @@ struct findall_fn {

//
std::unique_ptr<column> findall_record(
strings_column_view const& strings,
strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto const strings_count = strings.size();
auto const d_strings = column_device_view::create(strings.parent(), stream);
auto const strings_count = input.size();
auto const d_strings = column_device_view::create(input.parent(), stream);

// compile regex into device object
auto const d_prog =
reprog_device::create(pattern, get_character_flags_table(), strings_count, stream);
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);

// Create lists offsets column
auto offsets = count_matches(*d_strings, *d_prog, stream, mr);
Expand Down Expand Up @@ -159,12 +160,13 @@ std::unique_ptr<column> findall_record(

// external API

std::unique_ptr<column> findall_record(strings_column_view const& strings,
std::unique_ptr<column> findall_record(strings_column_view const& input,
std::string const& pattern,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::findall_record(strings, pattern, rmm::cuda_stream_default, mr);
return detail::findall_record(input, pattern, flags, rmm::cuda_stream_default, mr);
}

} // namespace strings
Expand Down
45 changes: 45 additions & 0 deletions cpp/tests/strings/findall_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,51 @@ TEST_F(StringsFindallTests, FindallRecord)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected);
}

// Exercises findall() and findall_record() with regex_flags::MULTILINE using the
// anchored pattern "(^abc$)" on strings that contain embedded '\n' characters.
// The expected data requires "abc" to be matched once per line that is exactly "abc".
TEST_F(StringsFindallTests, Multiline)
{
  using LCW = cudf::test::lists_column_wrapper<cudf::string_view>;

  auto const pattern = "(^abc$)";
  auto const flags   = cudf::strings::regex_flags::MULTILINE;

  cudf::test::strings_column_wrapper input({"abc\nfff\nabc", "fff\nabc\nlll", "abc", "", "abc\n"});
  auto strings_view = cudf::strings_column_view(input);

  // Table form: one output column per match position; rows without an Nth match are null.
  {
    auto found = cudf::strings::findall(strings_view, pattern, flags);
    auto first_matches =
      cudf::test::strings_column_wrapper({"abc", "abc", "abc", "", "abc"}, {1, 1, 1, 0, 1});
    auto second_matches =
      cudf::test::strings_column_wrapper({"abc", "", "", "", ""}, {1, 0, 0, 0, 0});
    auto expected = cudf::table_view({first_matches, second_matches});
    CUDF_TEST_EXPECT_TABLES_EQUAL(found->view(), expected);
  }

  // Lists form: one list of matches per row; the empty-string row is expected to be null.
  {
    auto found       = cudf::strings::findall_record(strings_view, pattern, flags);
    bool row_valid[] = {1, 1, 1, 0, 1};
    LCW expected({LCW{"abc", "abc"}, LCW{"abc"}, LCW{"abc"}, LCW{}, LCW{"abc"}}, row_valid);
    CUDF_TEST_EXPECT_COLUMNS_EQUAL(found->view(), expected);
  }
}

// Exercises findall() and findall_record() with regex_flags::DOTALL using the
// pattern "(b.*f)". The expected matches span embedded '\n' characters, so the
// test requires '.' to match newlines when this flag is set.
TEST_F(StringsFindallTests, DotAll)
{
  using LCW = cudf::test::lists_column_wrapper<cudf::string_view>;

  auto const pattern = "(b.*f)";
  auto const flags   = cudf::strings::regex_flags::DOTALL;

  cudf::test::strings_column_wrapper input({"abc\nfa\nef", "fff\nabbc\nfff", "abcdef", ""});
  auto strings_view = cudf::strings_column_view(input);

  // Table form: a single output column because every row has at most one match.
  {
    auto found = cudf::strings::findall(strings_view, pattern, flags);
    auto first_matches =
      cudf::test::strings_column_wrapper({"bc\nfa\nef", "bbc\nfff", "bcdef", ""}, {1, 1, 1, 0});
    auto expected = cudf::table_view({first_matches});
    CUDF_TEST_EXPECT_TABLES_EQUAL(found->view(), expected);
  }

  // Lists form: one list of matches per row; the empty-string row is expected to be null.
  {
    auto found       = cudf::strings::findall_record(strings_view, pattern, flags);
    bool row_valid[] = {1, 1, 1, 0};
    LCW expected({LCW{"bc\nfa\nef"}, LCW{"bbc\nfff"}, LCW{"bcdef"}, LCW{}}, row_valid);
    CUDF_TEST_EXPECT_COLUMNS_EQUAL(found->view(), expected);
  }
}

TEST_F(StringsFindallTests, MediumRegex)
{
// This results in 15 regex instructions and falls in the 'medium' range.
Expand Down
7 changes: 5 additions & 2 deletions python/cudf/cudf/_lib/cpp/strings/findall.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@ from libcpp.string cimport string

from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.cpp.strings.contains cimport regex_flags
from cudf._lib.cpp.table.table cimport table


cdef extern from "cudf/strings/findall.hpp" namespace "cudf::strings" nogil:

cdef unique_ptr[table] findall(
const column_view& source_strings,
const string& pattern) except +
const string& pattern,
regex_flags flags) except +

cdef unique_ptr[column] findall_record(
const column_view& source_strings,
const string& pattern) except +
const string& pattern,
regex_flags flags) except +
16 changes: 11 additions & 5 deletions python/cudf/cudf/_lib/strings/findall.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.

from libc.stdint cimport uint32_t
from libcpp.memory cimport unique_ptr
from libcpp.string cimport string
from libcpp.utility cimport move
Expand All @@ -8,6 +9,7 @@ from cudf._lib.column cimport Column
from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.cpp.scalar.scalar cimport string_scalar
from cudf._lib.cpp.strings.contains cimport regex_flags
from cudf._lib.cpp.strings.findall cimport (
findall as cpp_findall,
findall_record as cpp_findall_record,
Expand All @@ -17,7 +19,7 @@ from cudf._lib.scalar cimport DeviceScalar
from cudf._lib.utils cimport data_from_unique_ptr


def findall(Column source_strings, pattern):
def findall(Column source_strings, object pattern, uint32_t flags):
"""
Returns data with all non-overlapping matches of `pattern`
in each string of `source_strings`.
Expand All @@ -26,11 +28,13 @@ def findall(Column source_strings, pattern):
cdef column_view source_view = source_strings.view()

cdef string pattern_string = <string>str(pattern).encode()
cdef regex_flags c_flags = <regex_flags>flags

with nogil:
c_result = move(cpp_findall(
source_view,
pattern_string
pattern_string,
c_flags
))

return data_from_unique_ptr(
Expand All @@ -39,7 +43,7 @@ def findall(Column source_strings, pattern):
)


def findall_record(Column source_strings, pattern):
def findall_record(Column source_strings, object pattern, uint32_t flags):
"""
Returns data with all non-overlapping matches of `pattern`
in each string of `source_strings` as a lists column.
Expand All @@ -48,11 +52,13 @@ def findall_record(Column source_strings, pattern):
cdef column_view source_view = source_strings.view()

cdef string pattern_string = <string>str(pattern).encode()
cdef regex_flags c_flags = <regex_flags>flags

with nogil:
c_result = move(cpp_findall_record(
source_view,
pattern_string
pattern_string,
c_flags
))

return Column.from_unique_ptr(move(c_result))
16 changes: 12 additions & 4 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -3410,6 +3410,8 @@ def findall(
----------
pat : str
Pattern or regular expression.
flags : int, default 0 (no flags)
Flags to pass through to the regex engine (e.g. re.MULTILINE)
Returns
-------
Expand All @@ -3419,7 +3421,8 @@ def findall(
Notes
-----
`flags` parameter is currently not supported.
The `flags` parameter currently only supports re.DOTALL and
re.MULTILINE.
Examples
--------
Expand Down Expand Up @@ -3462,10 +3465,15 @@ def findall(
1 <NA> <NA>
2 b b
"""
if flags != 0:
raise NotImplementedError("`flags` parameter is not yet supported")
if isinstance(pat, re.Pattern):
flags = pat.flags & ~re.U
pat = pat.pattern
if not _is_supported_regex_flags(flags):
raise NotImplementedError(
"unsupported value for `flags` parameter"
)

data, index = libstrings.findall(self._column, pat)
data, index = libstrings.findall(self._column, pat, flags)
return self._return_or_inplace(
cudf.core.frame.Frame(data, index), expand=expand
)
Expand Down
15 changes: 12 additions & 3 deletions python/cudf/cudf/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1775,14 +1775,23 @@ def test_string_count(data, pat, flags):


def test_string_findall():
ps = pd.Series(["Lion", "Monkey", "Rabbit"])
gs = cudf.Series(["Lion", "Monkey", "Rabbit"])
test_data = ["Lion", "Monkey", "Rabbit", "Don\nkey"]
ps = pd.Series(test_data)
gs = cudf.Series(test_data)

assert_eq(ps.str.findall("Monkey")[1][0], gs.str.findall("Monkey")[0][1])
assert_eq(ps.str.findall("on")[0][0], gs.str.findall("on")[0][0])
assert_eq(ps.str.findall("on")[1][0], gs.str.findall("on")[0][1])
assert_eq(ps.str.findall("on$")[0][0], gs.str.findall("on$")[0][0])
assert_eq(ps.str.findall("b")[2][1], gs.str.findall("b")[1][2])
assert_eq(ps.str.findall("on$")[0][0], gs.str.findall("on$")[0][0])
assert_eq(
ps.str.findall("on$", re.MULTILINE)[3][0],
gs.str.findall("on$", re.MULTILINE)[0][3],
)
assert_eq(
ps.str.findall("o.*k", re.DOTALL)[3][0],
gs.str.findall("o.*k", re.DOTALL)[0][3],
)


def test_string_replace_multi():
Expand Down

0 comments on commit bd98bfe

Please sign in to comment.