Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support strip, lstrip, and rstrip in strings_udf #12091

Merged
merged 66 commits into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from 62 commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
40eb1e2
Add strings udf C++ classes and function for phase II
davidwendt Oct 12, 2022
5317db8
fix style error
davidwendt Oct 12, 2022
8e531a5
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 13, 2022
b8d7868
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 13, 2022
2e36b6a
support returning strings within strings_udf library
brandon-b-miller Oct 13, 2022
238c862
returning strings working
brandon-b-miller Oct 14, 2022
0544c23
clean up code a bit
brandon-b-miller Oct 17, 2022
ae1bbdc
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 17, 2022
edcaaf2
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 18, 2022
a5661bc
change void* to udf_string*
davidwendt Oct 18, 2022
9661c4e
update doxygens
davidwendt Oct 18, 2022
a6f03a3
remove unnecessary explicit casting
brandon-b-miller Oct 18, 2022
ece495f
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 18, 2022
5554ed9
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 19, 2022
ebaf088
add pad utility functions
davidwendt Oct 19, 2022
c3e17ac
fix doxygen for udf_apis.hpp
davidwendt Oct 19, 2022
2dae45d
fix to_string to use count_digits
davidwendt Oct 20, 2022
3467f34
add ALL_FLAGS
davidwendt Oct 20, 2022
4f63c54
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 20, 2022
7639039
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 21, 2022
84721d4
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 24, 2022
4c72149
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 25, 2022
cf72fc8
add noexcept decl to appropriate member functions
davidwendt Oct 25, 2022
28e917b
fix return types for split
davidwendt Oct 25, 2022
f82c454
fix doxygen for various functions
davidwendt Oct 25, 2022
3b513a3
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 26, 2022
7b9718c
create free_udf_strings_array function
davidwendt Oct 31, 2022
68e54e8
fix compare returns, null assignment, reuse ctors
davidwendt Oct 31, 2022
6eef0a4
fix some doxygen wording
davidwendt Oct 31, 2022
02aa5b4
Merge branch 'branch-22.12' into udf-string-class
davidwendt Oct 31, 2022
69e0d7c
remove string_view const parameter decl
davidwendt Oct 31, 2022
a95c030
fix default-stream
davidwendt Oct 31, 2022
e0526e6
remove lstrip and rstrip
davidwendt Oct 31, 2022
bc903d6
reword split doxygen text for result=nullptr
davidwendt Oct 31, 2022
229c1f2
Merge branch 'branch-22.12' into udf-string-class
davidwendt Nov 1, 2022
eb6532e
add cuda_runtime.h to resolve device refs
davidwendt Nov 1, 2022
a8fca12
fix doxygen wording for pad()
davidwendt Nov 1, 2022
a249d13
refactor split; add count_tokens function
davidwendt Nov 1, 2022
96b06f6
refactor append, replace for better reuse
davidwendt Nov 1, 2022
7849307
expand spos/epos var names
davidwendt Nov 1, 2022
cadcf79
add more doc to replace() for count parm
davidwendt Nov 1, 2022
b3a43b8
Merge branch 'branch-22.12' into udf-string-class
davidwendt Nov 1, 2022
e0d1374
Merge remote-tracking branch 'david/udf-string-class' into fea-string…
brandon-b-miller Nov 1, 2022
1e02c26
adjust for changes
brandon-b-miller Nov 1, 2022
c9ef3ec
Merge branch 'branch-22.12' into fea-strings-udf-return-strings
brandon-b-miller Nov 2, 2022
1218c08
fix up cython
brandon-b-miller Nov 2, 2022
b9aabdd
merge the latest, resolve conflicts, pass tests
brandon-b-miller Nov 3, 2022
e864dea
from_udf_string_array -> column_from_udf_string_array, to_string_view…
brandon-b-miller Nov 3, 2022
9fccc9b
refactor
brandon-b-miller Nov 3, 2022
d5c37a8
prune imports
brandon-b-miller Nov 3, 2022
b7c1b1d
cleanup
brandon-b-miller Nov 3, 2022
267b904
begin to address reviews
brandon-b-miller Nov 4, 2022
8b7a412
Update python/strings_udf/strings_udf/_lib/cudf_jit_udf.pyx
brandon-b-miller Nov 4, 2022
4f821ca
finish addressing reviews, walrus everywhere!
brandon-b-miller Nov 4, 2022
b0a8681
support strip
brandon-b-miller Nov 7, 2022
18aee5a
updates
brandon-b-miller Nov 8, 2022
2cefbe4
merge 22.12
brandon-b-miller Nov 8, 2022
d7556b0
fix bad merge
brandon-b-miller Nov 8, 2022
c4f8847
add tests to cudf
brandon-b-miller Nov 8, 2022
7030108
plumb to maskedtype
brandon-b-miller Nov 8, 2022
11e966c
cleanup
brandon-b-miller Nov 8, 2022
9991c76
more cleanup
brandon-b-miller Nov 8, 2022
837a49c
Update python/strings_udf/strings_udf/lowering.py
brandon-b-miller Nov 9, 2022
302fe60
address reviews
brandon-b-miller Nov 9, 2022
02167d3
fix copypaste bug
brandon-b-miller Nov 9, 2022
3539ed7
small refactor
brandon-b-miller Nov 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions python/cudf/cudf/core/udf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
column_from_udf_string_array,
column_to_string_view_array,
)
from strings_udf._typing import str_view_arg_handler, string_view
from strings_udf._typing import (
str_view_arg_handler,
string_view,
udf_string,
)

from . import strings_typing # isort: skip
from . import strings_lowering # isort: skip
Expand All @@ -41,14 +45,17 @@
masked_lowering.masked_constructor
)
utils.JIT_SUPPORTED_TYPES |= STRING_TYPES
_supported_masked_types |= {string_view}
_supported_masked_types |= {string_view, udf_string}

utils.launch_arg_getters[cudf_str_dtype] = column_to_string_view_array
utils.output_col_getters[cudf_str_dtype] = column_from_udf_string_array
utils.masked_array_types[cudf_str_dtype] = string_view
row_function.itemsizes[cudf_str_dtype] = string_view.size_bytes

utils.arg_handlers.append(str_view_arg_handler)

masked_typing.MASKED_INIT_MAP[udf_string] = udf_string

_STRING_UDFS_ENABLED = True

except ImportError as e:
Expand Down
12 changes: 11 additions & 1 deletion python/cudf/cudf/core/udf/strings_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from numba.core.typing import signature as nb_signature
from numba.cuda.cudaimpl import lower as cuda_lower

from strings_udf._typing import size_type, string_view
from strings_udf._typing import size_type, string_view, udf_string
from strings_udf.lowering import (
contains_impl,
count_impl,
Expand All @@ -22,8 +22,11 @@
istitle_impl,
isupper_impl,
len_impl,
lstrip_impl,
rfind_impl,
rstrip_impl,
startswith_impl,
strip_impl,
)

from cudf.core.udf.masked_typing import MaskedType
Expand Down Expand Up @@ -79,6 +82,13 @@ def masked_binary_func_impl(context, builder, sig, args):
)


create_binary_string_func("MaskedType.strip", strip_impl, udf_string)

create_binary_string_func("MaskedType.lstrip", lstrip_impl, udf_string)

create_binary_string_func("MaskedType.rstrip", rstrip_impl, udf_string)


create_binary_string_func(
"MaskedType.startswith",
startswith_impl,
Expand Down
9 changes: 9 additions & 0 deletions python/cudf/cudf/core/udf/strings_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
id_unary_funcs,
int_binary_funcs,
size_type,
string_binary_funcs,
string_view,
udf_string,
)

from cudf.core.udf import masked_typing
Expand Down Expand Up @@ -172,6 +174,13 @@ def resolve_valid(self, mod):
create_masked_binary_attr(f"MaskedType.{func}", size_type),
)

for func in string_binary_funcs:
setattr(
MaskedStringViewAttrs,
f"resolve_{func}",
create_masked_binary_attr(f"MaskedType.{func}", udf_string),
)

for func in id_unary_funcs:
setattr(
MaskedStringViewAttrs,
Expand Down
24 changes: 24 additions & 0 deletions python/cudf/cudf/tests/test_udf_masked_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,30 @@ def func(row):
run_masked_udf_test(func, str_udf_data, check_dtype=False)


@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", "@"])
def test_string_udf_strip(str_udf_data, strip_char):
def func(row):
return row["str_col"].strip(strip_char)

run_masked_udf_test(func, str_udf_data, check_dtype=False)


@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", "@"])
def test_string_udf_lstrip(str_udf_data, strip_char):
def func(row):
return row["str_col"].lstrip(strip_char)

run_masked_udf_test(func, str_udf_data, check_dtype=False)


@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", "@"])
def test_string_udf_rstrip(str_udf_data, strip_char):
def func(row):
return row["str_col"].rstrip(strip_char)

run_masked_udf_test(func, str_udf_data, check_dtype=False)


@pytest.mark.parametrize(
"data", [[1.0, 0.0, 1.5], [1, 0, 2], [True, False, True]]
)
Expand Down
43 changes: 43 additions & 0 deletions python/strings_udf/cpp/src/strings/udf/shim.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cudf/strings/udf/char_types.cuh>
#include <cudf/strings/udf/search.cuh>
#include <cudf/strings/udf/starts_with.cuh>
#include <cudf/strings/udf/strip.cuh>
#include <cudf/strings/udf/udf_string.cuh>

using namespace cudf::strings::udf;
Expand Down Expand Up @@ -227,3 +228,45 @@ extern "C" __device__ int udf_string_from_string_view(int* nb_retbal,

return 0;
}

extern "C" __device__ int strip(int* nb_retval,
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
void* udf_str,
void* const* to_strip,
void* const* strip_str)
{
auto to_strip_ptr = reinterpret_cast<cudf::string_view const*>(to_strip);
auto strip_str_ptr = reinterpret_cast<cudf::string_view const*>(strip_str);
auto udf_str_ptr = reinterpret_cast<udf_string*>(udf_str);

*udf_str_ptr = strip(*to_strip_ptr, *strip_str_ptr);

return 0;
}

extern "C" __device__ int lstrip(int* nb_retval,
void* udf_str,
void* const* to_strip,
void* const* strip_str)
{
auto to_strip_ptr = reinterpret_cast<cudf::string_view const*>(to_strip);
auto strip_str_ptr = reinterpret_cast<cudf::string_view const*>(strip_str);
auto udf_str_ptr = reinterpret_cast<udf_string*>(udf_str);

*udf_str_ptr = strip(*to_strip_ptr, *strip_str_ptr, cudf::strings::side_type::LEFT);

return 0;
}

extern "C" __device__ int rstrip(int* nb_retval,
void* udf_str,
void* const* to_strip,
void* const* strip_str)
{
auto to_strip_ptr = reinterpret_cast<cudf::string_view const*>(to_strip);
auto strip_str_ptr = reinterpret_cast<cudf::string_view const*>(strip_str);
auto udf_str_ptr = reinterpret_cast<udf_string*>(udf_str);

*udf_str_ptr = strip(*to_strip_ptr, *strip_str_ptr, cudf::strings::side_type::RIGHT);

return 0;
}
19 changes: 16 additions & 3 deletions python/strings_udf/strings_udf/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def attr(self, mod):
return attr


def create_identifier_attr(attrname):
def create_identifier_attr(attrname, retty):
"""
Helper function wrapping numba's low level extension API. Provides
the boilerplate needed to register a unary function of a string
Expand All @@ -192,7 +192,7 @@ class StringViewIdentifierAttr(AbstractTemplate):
key = f"StringView.{attrname}"

def generic(self, args, kws):
return nb_signature(types.boolean, recvr=self.this)
return nb_signature(retty, recvr=self.this)

def attr(self, mod):
return types.BoundFunction(StringViewIdentifierAttr, string_view)
Expand Down Expand Up @@ -229,6 +229,7 @@ def resolve_count(self, mod):
"isnumeric",
"istitle",
]
string_binary_funcs = ["strip", "lstrip", "rstrip"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name is somewhat misleading. There are other string binary functions that we already have implemented, including operators and things like find and contains. Is this list meant to contain binary string functions that also return a string, or is it even more specific than that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renaming this to string_return_attrs


for func in bool_binary_funcs:
setattr(
Expand All @@ -237,12 +238,24 @@ def resolve_count(self, mod):
create_binary_attr(func, types.boolean),
)

for func in string_binary_funcs:
setattr(
StringViewAttrs,
f"resolve_{func}",
create_binary_attr(func, udf_string),
)


for func in int_binary_funcs:
setattr(
StringViewAttrs, f"resolve_{func}", create_binary_attr(func, size_type)
)

for func in id_unary_funcs:
setattr(StringViewAttrs, f"resolve_{func}", create_identifier_attr(func))
setattr(
StringViewAttrs,
f"resolve_{func}",
create_identifier_attr(func, types.boolean),
)

cuda_decl_registry.register_attr(StringViewAttrs)
65 changes: 56 additions & 9 deletions python/strings_udf/strings_udf/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
character_flags_table_ptr = get_character_flags_table_ptr()

_STR_VIEW_PTR = types.CPointer(string_view)
_UDF_STRING_PTR = types.CPointer(udf_string)


# CUDA function declarations
Expand Down Expand Up @@ -55,7 +56,15 @@ def _declare_binary_func(lhs, rhs, out, name):
_string_view_find = _declare_size_type_str_str_func("find")
_string_view_rfind = _declare_size_type_str_str_func("rfind")
_string_view_contains = _declare_bool_str_str_func("contains")

_string_view_strip = cuda.declare_device(
"strip", types.int32(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR)
)
_string_view_lstrip = cuda.declare_device(
"strip", types.int32(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR)
)
_string_view_rstrip = cuda.declare_device(
"strip", types.int32(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR)
)

# A binary function of the form f(string, int) -> bool
_declare_bool_str_int_func = partial(
Expand Down Expand Up @@ -162,17 +171,40 @@ def deco(cuda_func):
def binary_func_impl(context, builder, sig, args):
lhs_ptr = builder.alloca(args[0].type)
rhs_ptr = builder.alloca(args[1].type)

builder.store(args[0], lhs_ptr)
builder.store(args[1], rhs_ptr)
result = context.compile_internal(
builder,
cuda_func,
nb_signature(retty, _STR_VIEW_PTR, _STR_VIEW_PTR),
(lhs_ptr, rhs_ptr),
)

return result
# these conditional statements should compile out
if retty != udf_string:
# binary function of two strings yielding a fixed-width type
# example: str.startswith(other) -> bool
# shim functions can return the value through nb_retval
result = context.compile_internal(
builder,
cuda_func,
nb_signature(retty, _STR_VIEW_PTR, _STR_VIEW_PTR),
(lhs_ptr, rhs_ptr),
)
return result
else:
# binary function of two strings yielding a new string
# example: str.strip(other) -> str
# shim functions can not return a struct due to C linkage
# so we operate on an extra void ptr and throw away nb_retval
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
udf_str_ptr = builder.alloca(
default_manager[udf_string].get_value_type()
)

_ = context.compile_internal(
builder,
cuda_func,
size_type(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR),
(udf_str_ptr, lhs_ptr, rhs_ptr),
)
result = cgutils.create_struct_proxy(udf_string)(
context, builder, value=builder.load(udf_str_ptr)
)
return result._getvalue()

return binary_func_impl

Expand Down Expand Up @@ -214,6 +246,21 @@ def lt_impl(st, rhs):
return _string_view_lt(st, rhs)


@create_binary_string_func("StringView.strip", udf_string)
def strip_impl(result, to_strip, strip_char):
return _string_view_strip(result, to_strip, strip_char)


@create_binary_string_func("StringView.lstrip", udf_string)
def lstrip_impl(result, to_strip, strip_char):
return _string_view_lstrip(result, to_strip, strip_char)


@create_binary_string_func("StringView.rstrip", udf_string)
def rstrip_impl(result, to_strip, strip_char):
return _string_view_rstrip(result, to_strip, strip_char)


@create_binary_string_func("StringView.startswith", types.boolean)
def startswith_impl(sv, substr):
return _string_view_startswith(sv, substr)
Expand Down
24 changes: 24 additions & 0 deletions python/strings_udf/strings_udf/tests/test_string_udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,27 @@ def func(st):
return st

run_udf_test(data, func, "str")


@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", "@"])
def test_string_udf_strip(data, strip_char):
def func(st):
return st.strip(strip_char)

run_udf_test(data, func, "str")


@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", "@"])
def test_string_udf_lstrip(data, strip_char):
def func(st):
return st.lstrip(strip_char)

run_udf_test(data, func, "str")


@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", "@"])
def test_string_udf_rstrip(data, strip_char):
def func(st):
return st.rstrip(strip_char)

run_udf_test(data, func, "str")