Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add regex flags to strings extract function #10192

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions cpp/include/cudf/strings/extract.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
*/
#pragma once

#include <cudf/strings/regex/flags.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/table/table.hpp>

Expand Down Expand Up @@ -48,12 +49,14 @@ namespace strings {
*
* @param strings Strings instance for this operation.
* @param pattern The regular expression pattern with group indicators.
* @param flags Regex flags for interpreting special characters in the pattern.
* @param mr Device memory resource used to allocate the returned table's device memory.
* @return Columns of strings extracted from the input column.
*/
std::unique_ptr<table> extract(
strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -69,7 +72,7 @@ std::unique_ptr<table> extract(
* @code{.pseudo}
* Example:
* s = ["a1 b4", "b2", "c3 a5", "b", null]
* r = extract_all(s,"([ab])(\\d)")
* r = extract_all_record(s,"([ab])(\\d)")
* r is now [ ["a", "1", "b", "4"],
* ["b", "2"],
* ["a", "5"],
Expand All @@ -81,12 +84,14 @@ std::unique_ptr<table> extract(
*
* @param strings Strings instance for this operation.
* @param pattern The regular expression pattern with group indicators.
* @param flags Regex flags for interpreting special characters in the pattern.
* @param mr Device memory resource used to allocate any returned device memory.
* @return Lists column containing strings extracted from the input column.
*/
std::unique_ptr<column> extract_all(
std::unique_ptr<column> extract_all_record(
strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of doxygen group
Expand Down
9 changes: 6 additions & 3 deletions cpp/src/strings/extract/extract.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -83,6 +83,7 @@ struct extract_fn {
std::unique_ptr<table> extract(
strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
Expand All @@ -91,7 +92,8 @@ std::unique_ptr<table> extract(
auto const d_strings = *strings_column;

// compile regex into device object
auto prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream);
auto prog =
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);
auto d_prog = *prog;
// extract should include groups
auto const groups = d_prog.group_counts();
Expand Down Expand Up @@ -150,10 +152,11 @@ std::unique_ptr<table> extract(

std::unique_ptr<table> extract(strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::extract(strings, pattern, rmm::cuda_stream_default, mr);
return detail::extract(strings, pattern, flags, rmm::cuda_stream_default, mr);
}

} // namespace strings
Expand Down
17 changes: 10 additions & 7 deletions cpp/src/strings/extract/extract_all.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,23 @@ struct extract_fn {
} // namespace

/**
* @copydoc cudf::strings::extract_all
* @copydoc cudf::strings::extract_all_record
*
* @param stream CUDA stream used for device memory operations and kernel launches.
*/
std::unique_ptr<column> extract_all(
std::unique_ptr<column> extract_all_record(
strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto const strings_count = strings.size();
auto const d_strings = column_device_view::create(strings.parent(), stream);

// Compile regex into device object.
auto d_prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream);
auto d_prog =
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);
// The extract pattern should always include groups.
auto const groups = d_prog->group_counts();
CUDF_EXPECTS(groups > 0, "extract_all requires group indicators in the regex pattern.");
Expand Down Expand Up @@ -179,12 +181,13 @@ std::unique_ptr<column> extract_all(

// external API

std::unique_ptr<column> extract_all(strings_column_view const& strings,
std::string const& pattern,
rmm::mr::device_memory_resource* mr)
std::unique_ptr<column> extract_all_record(strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::extract_all(strings, pattern, rmm::cuda_stream_default, mr);
return detail::extract_all_record(strings, pattern, flags, rmm::cuda_stream_default, mr);
}

} // namespace strings
Expand Down
39 changes: 36 additions & 3 deletions cpp/tests/strings/extract_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -150,6 +150,39 @@ TEST_F(StringsExtractTests, ExtractEventTest)
}
}

// Exercises cudf::strings::extract on rows containing embedded newlines, once with
// regex_flags::MULTILINE and once with the default flags, and compares each result
// table against a hand-written expected column.
TEST_F(StringsExtractTests, MultiLine)
{
  cudf::test::strings_column_wrapper input(
    {"abc\nfff\nabc", "fff\nabc\nlll", "abc", "", "abc\n"});
  auto const sv = cudf::strings_column_view(input);

  {
    // MULTILINE case: every row except the empty string is expected to yield "abc";
    // the empty-string row is flagged invalid in the expected validity list.
    auto const multiline_result =
      cudf::strings::extract(sv, "(^[a-c]+$)", cudf::strings::regex_flags::MULTILINE);
    cudf::test::strings_column_wrapper multiline_strings({"abc", "abc", "abc", "", "abc"},
                                                         {1, 1, 1, 0, 1});
    auto const multiline_table = cudf::table_view{{multiline_strings}};
    CUDF_TEST_EXPECT_TABLES_EQUAL(*multiline_result, multiline_table);
  }

  {
    // Default-flags case: only the row that is exactly "abc" is expected to be valid.
    auto const default_result = cudf::strings::extract(sv, "^([a-c]+)$");
    cudf::test::strings_column_wrapper default_strings({"", "", "abc", "", ""},
                                                       {0, 0, 1, 0, 0});
    auto const default_table = cudf::table_view{{default_strings}};
    CUDF_TEST_EXPECT_TABLES_EQUAL(*default_result, default_table);
  }
}

// Exercises cudf::strings::extract with the pattern "(a.*f)" on rows containing embedded
// newlines, once with regex_flags::DOTALL and once with the default flags, and compares
// each result table against a hand-written expected column.
TEST_F(StringsExtractTests, DotAll)
{
  cudf::test::strings_column_wrapper input({"abc\nfa\nef", "fff\nabbc\nfff", "abcdef", ""});
  auto const sv = cudf::strings_column_view(input);

  {
    // DOTALL case: the expected matches span the newline characters; the empty-string
    // row is flagged invalid in the expected validity list.
    auto const dotall_result =
      cudf::strings::extract(sv, "(a.*f)", cudf::strings::regex_flags::DOTALL);
    cudf::test::strings_column_wrapper dotall_strings(
      {"abc\nfa\nef", "abbc\nfff", "abcdef", ""}, {1, 1, 1, 0});
    auto const dotall_table = cudf::table_view{{dotall_strings}};
    CUDF_TEST_EXPECT_TABLES_EQUAL(*dotall_result, dotall_table);
  }

  {
    // Default-flags case: only the single-line row "abcdef" is expected to be valid.
    auto const default_result = cudf::strings::extract(sv, "(a.*f)");
    cudf::test::strings_column_wrapper default_strings({"", "", "abcdef", ""}, {0, 0, 1, 0});
    auto const default_table = cudf::table_view{{default_strings}};
    CUDF_TEST_EXPECT_TABLES_EQUAL(*default_result, default_table);
  }
}

TEST_F(StringsExtractTests, EmptyExtractTest)
{
std::vector<const char*> h_strings{nullptr, "AAA", "AAA_A", "AAA_AAA_", "A__", ""};
Expand Down Expand Up @@ -181,7 +214,7 @@ TEST_F(StringsExtractTests, ExtractAllTest)
cudf::test::strings_column_wrapper input(h_input.begin(), h_input.end(), validity);
auto sv = cudf::strings_column_view(input);

auto results = cudf::strings::extract_all(sv, "(\\d+) (\\w+)");
auto results = cudf::strings::extract_all_record(sv, "(\\d+) (\\w+)");

bool valids[] = {true, true, true, false, false, false, true};
using LCW = cudf::test::lists_column_wrapper<cudf::string_view>;
Expand All @@ -201,7 +234,7 @@ TEST_F(StringsExtractTests, Errors)
cudf::test::strings_column_wrapper input({"this column intentionally left blank"});
auto sv = cudf::strings_column_view(input);
EXPECT_THROW(cudf::strings::extract(sv, "\\w+"), cudf::logic_error);
EXPECT_THROW(cudf::strings::extract_all(sv, "\\w+"), cudf::logic_error);
EXPECT_THROW(cudf::strings::extract_all_record(sv, "\\w+"), cudf::logic_error);
}

TEST_F(StringsExtractTests, MediumRegex)
Expand Down
11 changes: 2 additions & 9 deletions python/cudf/cudf/_lib/cpp/strings/contains.pxd
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr
from libcpp.string cimport string

from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.cpp.strings.regex_flags cimport regex_flags


cdef extern from "cudf/strings/regex/flags.hpp" \
namespace "cudf::strings" nogil:

ctypedef enum regex_flags:
DEFAULT 'cudf::strings::regex_flags::DEFAULT'
MULTILINE 'cudf::strings::regex_flags::MULTILINE'
DOTALL 'cudf::strings::regex_flags::DOTALL'

cdef extern from "cudf/strings/contains.hpp" namespace "cudf::strings" nogil:

cdef unique_ptr[column] contains_re(
Expand Down
6 changes: 4 additions & 2 deletions python/cudf/cudf/_lib/cpp/strings/extract.pxd
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr
from libcpp.string cimport string

from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.cpp.strings.contains cimport regex_flags
from cudf._lib.cpp.table.table cimport table


cdef extern from "cudf/strings/extract.hpp" namespace "cudf::strings" nogil:

cdef unique_ptr[table] extract(
column_view source_strings,
string pattern) except +
string pattern,
regex_flags flags) except +
9 changes: 9 additions & 0 deletions python/cudf/cudf/_lib/cpp/strings/regex_flags.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) 2022, NVIDIA CORPORATION.

# Cython declaration of the `regex_flags` enum from the C++ header
# "cudf/strings/regex/flags.hpp" (namespace `cudf::strings`), kept in its own
# .pxd so other declaration files can cimport it from a single place.
cdef extern from "cudf/strings/regex/flags.hpp" \
namespace "cudf::strings" nogil:

# Each enumerator is bound to its fully-qualified C++ name through the quoted
# string that follows it.
ctypedef enum regex_flags:
DEFAULT 'cudf::strings::regex_flags::DEFAULT'
MULTILINE 'cudf::strings::regex_flags::MULTILINE'
DOTALL 'cudf::strings::regex_flags::DOTALL'
4 changes: 2 additions & 2 deletions python/cudf/cudf/_lib/strings/contains.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.

from libc.stdint cimport uint32_t
from libcpp.memory cimport unique_ptr
Expand All @@ -12,8 +12,8 @@ from cudf._lib.cpp.strings.contains cimport (
contains_re as cpp_contains_re,
count_re as cpp_count_re,
matches_re as cpp_matches_re,
regex_flags as regex_flags,
)
from cudf._lib.cpp.strings.regex_flags cimport regex_flags
from cudf._lib.scalar cimport DeviceScalar


Expand Down
10 changes: 7 additions & 3 deletions python/cudf/cudf/_lib/strings/extract.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.

from libc.stdint cimport uint32_t
from libcpp.memory cimport unique_ptr
from libcpp.string cimport string
from libcpp.utility cimport move
Expand All @@ -8,12 +9,13 @@ from cudf._lib.column cimport Column
from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.cpp.strings.extract cimport extract as cpp_extract
from cudf._lib.cpp.strings.regex_flags cimport regex_flags
from cudf._lib.cpp.table.table cimport table
from cudf._lib.scalar cimport DeviceScalar
from cudf._lib.utils cimport data_from_unique_ptr


def extract(Column source_strings, object pattern):
def extract(Column source_strings, object pattern, uint32_t flags):
"""
Returns data which contains extracted capture groups provided in
`pattern` for all `source_strings`.
Expand All @@ -24,11 +26,13 @@ def extract(Column source_strings, object pattern):
cdef column_view source_view = source_strings.view()

cdef string pattern_string = <string>str(pattern).encode()
cdef regex_flags c_flags = <regex_flags>flags

with nogil:
c_result = move(cpp_extract(
source_view,
pattern_string
pattern_string,
c_flags
))

return data_from_unique_ptr(
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/_lib/strings/findall.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ from cudf._lib.column cimport Column
from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.cpp.scalar.scalar cimport string_scalar
from cudf._lib.cpp.strings.contains cimport regex_flags
from cudf._lib.cpp.strings.findall cimport (
findall as cpp_findall,
findall_record as cpp_findall_record,
)
from cudf._lib.cpp.strings.regex_flags cimport regex_flags
from cudf._lib.cpp.table.table cimport table
from cudf._lib.scalar cimport DeviceScalar
from cudf._lib.utils cimport data_from_unique_ptr
Expand Down
26 changes: 18 additions & 8 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,8 @@ def extract(
----------
pat : str
Regular expression pattern with capturing groups.
flags : int, default 0 (no flags)
Flags to pass through to the regex engine (e.g. re.MULTILINE)
expand : bool, default True
If True, return DataFrame with one column per capture group.
If False, return a Series/Index if there is one capture group or
Expand All @@ -588,8 +590,8 @@ def extract(

Notes
-----
The `flags` parameter is not yet supported and will raise a
NotImplementedError if anything other than the default value is passed.
The `flags` parameter currently only supports re.DOTALL and
re.MULTILINE.

Examples
--------
Expand Down Expand Up @@ -618,10 +620,12 @@ def extract(
2 <NA>
dtype: object
""" # noqa W605
if flags != 0:
raise NotImplementedError("`flags` parameter is not yet supported")
if not _is_supported_regex_flags(flags):
raise NotImplementedError(
"unsupported value for `flags` parameter"
)

data, index = libstrings.extract(self._column, pat)
data, index = libstrings.extract(self._column, pat, flags)
if len(data) == 1 and expand is False:
data = next(iter(data.values()))
else:
Expand Down Expand Up @@ -752,7 +756,9 @@ def contains(
flags = pat.flags & ~re.U
pat = pat.pattern
if not _is_supported_regex_flags(flags):
raise ValueError("invalid `flags` parameter value")
raise NotImplementedError(
"unsupported value for `flags` parameter"
)

if pat is None:
result_col = column.column_empty(
Expand Down Expand Up @@ -3393,7 +3399,9 @@ def count(self, pat: str, flags: int = 0) -> SeriesOrIndex:
flags = pat.flags & ~re.U
pat = pat.pattern
if not _is_supported_regex_flags(flags):
raise ValueError("invalid `flags` parameter value")
raise NotImplementedError(
"unsupported value for `flags` parameter"
)

return self._return_or_inplace(
libstrings.count_re(self._column, pat, flags)
Expand Down Expand Up @@ -3969,7 +3977,9 @@ def match(
flags = pat.flags & ~re.U
pat = pat.pattern
if not _is_supported_regex_flags(flags):
raise ValueError("invalid `flags` parameter value")
raise NotImplementedError(
"unsupported value for `flags` parameter"
)

return self._return_or_inplace(
libstrings.match_re(self._column, pat, flags)
Expand Down
Loading