diff --git a/python/cudf/cudf/core/udf/__init__.py b/python/cudf/cudf/core/udf/__init__.py index 926d2ea6cbf..8421d763167 100644 --- a/python/cudf/cudf/core/udf/__init__.py +++ b/python/cudf/cudf/core/udf/__init__.py @@ -32,7 +32,11 @@ column_from_udf_string_array, column_to_string_view_array, ) - from strings_udf._typing import str_view_arg_handler, string_view + from strings_udf._typing import ( + str_view_arg_handler, + string_view, + udf_string, + ) from . import strings_typing # isort: skip from . import strings_lowering # isort: skip @@ -41,7 +45,7 @@ masked_lowering.masked_constructor ) utils.JIT_SUPPORTED_TYPES |= STRING_TYPES - _supported_masked_types |= {string_view} + _supported_masked_types |= {string_view, udf_string} utils.launch_arg_getters[cudf_str_dtype] = column_to_string_view_array utils.output_col_getters[cudf_str_dtype] = column_from_udf_string_array @@ -49,6 +53,9 @@ row_function.itemsizes[cudf_str_dtype] = string_view.size_bytes utils.arg_handlers.append(str_view_arg_handler) + + masked_typing.MASKED_INIT_MAP[udf_string] = udf_string + _STRING_UDFS_ENABLED = True except ImportError as e: diff --git a/python/cudf/cudf/core/udf/strings_lowering.py b/python/cudf/cudf/core/udf/strings_lowering.py index 59041977f87..fdfd013bad7 100644 --- a/python/cudf/cudf/core/udf/strings_lowering.py +++ b/python/cudf/cudf/core/udf/strings_lowering.py @@ -7,7 +7,7 @@ from numba.core.typing import signature as nb_signature from numba.cuda.cudaimpl import lower as cuda_lower -from strings_udf._typing import size_type, string_view +from strings_udf._typing import size_type, string_view, udf_string from strings_udf.lowering import ( contains_impl, count_impl, @@ -22,8 +22,11 @@ istitle_impl, isupper_impl, len_impl, + lstrip_impl, rfind_impl, + rstrip_impl, startswith_impl, + strip_impl, ) from cudf.core.udf.masked_typing import MaskedType @@ -79,6 +82,13 @@ def masked_binary_func_impl(context, builder, sig, args): ) +create_binary_string_func("MaskedType.strip", 
strip_impl, udf_string) + +create_binary_string_func("MaskedType.lstrip", lstrip_impl, udf_string) + +create_binary_string_func("MaskedType.rstrip", rstrip_impl, udf_string) + + create_binary_string_func( "MaskedType.startswith", startswith_impl, diff --git a/python/cudf/cudf/core/udf/strings_typing.py b/python/cudf/cudf/core/udf/strings_typing.py index 1179688651f..f8f50600b12 100644 --- a/python/cudf/cudf/core/udf/strings_typing.py +++ b/python/cudf/cudf/core/udf/strings_typing.py @@ -13,7 +13,9 @@ id_unary_funcs, int_binary_funcs, size_type, + string_return_attrs, string_view, + udf_string, ) from cudf.core.udf import masked_typing @@ -172,6 +174,13 @@ def resolve_valid(self, mod): create_masked_binary_attr(f"MaskedType.{func}", size_type), ) +for func in string_return_attrs: + setattr( + MaskedStringViewAttrs, + f"resolve_{func}", + create_masked_binary_attr(f"MaskedType.{func}", udf_string), + ) + for func in id_unary_funcs: setattr( MaskedStringViewAttrs, diff --git a/python/cudf/cudf/tests/test_udf_masked_ops.py b/python/cudf/cudf/tests/test_udf_masked_ops.py index b4c7cef3a4c..7af47f981d6 100644 --- a/python/cudf/cudf/tests/test_udf_masked_ops.py +++ b/python/cudf/cudf/tests/test_udf_masked_ops.py @@ -876,6 +876,33 @@ def func(row): run_masked_udf_test(func, str_udf_data, check_dtype=False) +@string_udf_test +@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", "@"]) +def test_string_udf_strip(str_udf_data, strip_char): + def func(row): + return row["str_col"].strip(strip_char) + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", "@"]) +def test_string_udf_lstrip(str_udf_data, strip_char): + def func(row): + return row["str_col"].lstrip(strip_char) + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", "@"]) +def 
test_string_udf_rstrip(str_udf_data, strip_char): + def func(row): + return row["str_col"].rstrip(strip_char) + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + @pytest.mark.parametrize( "data", [[1.0, 0.0, 1.5], [1, 0, 2], [True, False, True]] ) diff --git a/python/strings_udf/cpp/src/strings/udf/shim.cu b/python/strings_udf/cpp/src/strings/udf/shim.cu index b284d58fe58..63e740c5226 100644 --- a/python/strings_udf/cpp/src/strings/udf/shim.cu +++ b/python/strings_udf/cpp/src/strings/udf/shim.cu @@ -17,6 +17,7 @@ #include <cudf/strings/udf/char_types.cuh> #include <cudf/strings/udf/search.cuh> #include <cudf/strings/udf/starts_with.cuh> +#include <cudf/strings/udf/strip.cuh> #include <cudf/strings/udf/udf_string.cuh> using namespace cudf::strings::udf; @@ -227,3 +228,45 @@ extern "C" __device__ int udf_string_from_string_view(int* nb_retbal, return 0; } + +extern "C" __device__ int strip(int* nb_retval, + void* udf_str, + void* const* to_strip, + void* const* strip_str) +{ + auto to_strip_ptr = reinterpret_cast<cudf::string_view const*>(to_strip); + auto strip_str_ptr = reinterpret_cast<cudf::string_view const*>(strip_str); + auto udf_str_ptr = reinterpret_cast<udf_string*>(udf_str); + + *udf_str_ptr = strip(*to_strip_ptr, *strip_str_ptr); + + return 0; +} + +extern "C" __device__ int lstrip(int* nb_retval, + void* udf_str, + void* const* to_strip, + void* const* strip_str) +{ + auto to_strip_ptr = reinterpret_cast<cudf::string_view const*>(to_strip); + auto strip_str_ptr = reinterpret_cast<cudf::string_view const*>(strip_str); + auto udf_str_ptr = reinterpret_cast<udf_string*>(udf_str); + + *udf_str_ptr = strip(*to_strip_ptr, *strip_str_ptr, cudf::strings::side_type::LEFT); + + return 0; +} + +extern "C" __device__ int rstrip(int* nb_retval, + void* udf_str, + void* const* to_strip, + void* const* strip_str) +{ + auto to_strip_ptr = reinterpret_cast<cudf::string_view const*>(to_strip); + auto strip_str_ptr = reinterpret_cast<cudf::string_view const*>(strip_str); + auto udf_str_ptr = reinterpret_cast<udf_string*>(udf_str); + + *udf_str_ptr = strip(*to_strip_ptr, *strip_str_ptr, cudf::strings::side_type::RIGHT); + + return 0; +} diff --git a/python/strings_udf/strings_udf/_typing.py b/python/strings_udf/strings_udf/_typing.py index 320958960cd..a309a9cb93c 100644 --- 
a/python/strings_udf/strings_udf/_typing.py +++ b/python/strings_udf/strings_udf/_typing.py @@ -181,7 +181,7 @@ def attr(self, mod): return attr -def create_identifier_attr(attrname): +def create_identifier_attr(attrname, retty): """ Helper function wrapping numba's low level extension API. Provides the boilerplate needed to register a unary function of a string @@ -192,7 +192,7 @@ class StringViewIdentifierAttr(AbstractTemplate): key = f"StringView.{attrname}" def generic(self, args, kws): - return nb_signature(types.boolean, recvr=self.this) + return nb_signature(retty, recvr=self.this) def attr(self, mod): return types.BoundFunction(StringViewIdentifierAttr, string_view) @@ -229,6 +229,7 @@ def resolve_count(self, mod): "isnumeric", "istitle", ] +string_return_attrs = ["strip", "lstrip", "rstrip"] for func in bool_binary_funcs: setattr( @@ -237,12 +238,24 @@ def resolve_count(self, mod): create_binary_attr(func, types.boolean), ) +for func in string_return_attrs: + setattr( + StringViewAttrs, + f"resolve_{func}", + create_binary_attr(func, udf_string), + ) + + for func in int_binary_funcs: setattr( StringViewAttrs, f"resolve_{func}", create_binary_attr(func, size_type) ) for func in id_unary_funcs: - setattr(StringViewAttrs, f"resolve_{func}", create_identifier_attr(func)) + setattr( + StringViewAttrs, + f"resolve_{func}", + create_identifier_attr(func, types.boolean), + ) cuda_decl_registry.register_attr(StringViewAttrs) diff --git a/python/strings_udf/strings_udf/lowering.py b/python/strings_udf/strings_udf/lowering.py index 909b0e56187..17a1869e881 100644 --- a/python/strings_udf/strings_udf/lowering.py +++ b/python/strings_udf/strings_udf/lowering.py @@ -19,6 +19,7 @@ character_flags_table_ptr = get_character_flags_table_ptr() _STR_VIEW_PTR = types.CPointer(string_view) +_UDF_STRING_PTR = types.CPointer(udf_string) # CUDA function declarations @@ -34,6 +35,12 @@ def _declare_binary_func(lhs, rhs, out, name): ) +def _declare_strip_func(name): + return 
cuda.declare_device( + name, size_type(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR) + ) + + # A binary function of the form f(string, string) -> bool _declare_bool_str_str_func = partial( _declare_binary_func, _STR_VIEW_PTR, _STR_VIEW_PTR, types.boolean @@ -55,6 +62,9 @@ def _declare_binary_func(lhs, rhs, out, name): _string_view_find = _declare_size_type_str_str_func("find") _string_view_rfind = _declare_size_type_str_str_func("rfind") _string_view_contains = _declare_bool_str_str_func("contains") +_string_view_strip = _declare_strip_func("strip") +_string_view_lstrip = _declare_strip_func("lstrip") +_string_view_rstrip = _declare_strip_func("rstrip") # A binary function of the form f(string, int) -> bool @@ -162,17 +172,44 @@ def deco(cuda_func): def binary_func_impl(context, builder, sig, args): lhs_ptr = builder.alloca(args[0].type) rhs_ptr = builder.alloca(args[1].type) - builder.store(args[0], lhs_ptr) builder.store(args[1], rhs_ptr) - result = context.compile_internal( - builder, - cuda_func, - nb_signature(retty, _STR_VIEW_PTR, _STR_VIEW_PTR), - (lhs_ptr, rhs_ptr), - ) - return result + # these conditional statements should compile out + if retty != udf_string: + # binary function of two strings yielding a fixed-width type + # example: str.startswith(other) -> bool + # shim functions can return the value through nb_retval + result = context.compile_internal( + builder, + cuda_func, + nb_signature(retty, _STR_VIEW_PTR, _STR_VIEW_PTR), + (lhs_ptr, rhs_ptr), + ) + return result + else: + # binary function of two strings yielding a new string + # example: str.strip(other) -> str + # shim functions can not return a struct due to C linkage + # so we create a new udf_string and pass a pointer to it + # for the shim function to write the output to. The return + # value of compile_internal is therefore discarded (although + # this may change in the future if we need to return error + # codes, for instance). 
+ udf_str_ptr = builder.alloca( + default_manager[udf_string].get_value_type() + ) + + _ = context.compile_internal( + builder, + cuda_func, + size_type(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR), + (udf_str_ptr, lhs_ptr, rhs_ptr), + ) + result = cgutils.create_struct_proxy(udf_string)( + context, builder, value=builder.load(udf_str_ptr) + ) + return result._getvalue() return binary_func_impl @@ -214,6 +251,21 @@ def lt_impl(st, rhs): return _string_view_lt(st, rhs) +@create_binary_string_func("StringView.strip", udf_string) +def strip_impl(result, to_strip, strip_char): + return _string_view_strip(result, to_strip, strip_char) + + +@create_binary_string_func("StringView.lstrip", udf_string) +def lstrip_impl(result, to_strip, strip_char): + return _string_view_lstrip(result, to_strip, strip_char) + + +@create_binary_string_func("StringView.rstrip", udf_string) +def rstrip_impl(result, to_strip, strip_char): + return _string_view_rstrip(result, to_strip, strip_char) + + @create_binary_string_func("StringView.startswith", types.boolean) def startswith_impl(sv, substr): return _string_view_startswith(sv, substr) diff --git a/python/strings_udf/strings_udf/tests/test_string_udfs.py b/python/strings_udf/strings_udf/tests/test_string_udfs.py index ca3fbda4eb1..522433d404f 100644 --- a/python/strings_udf/strings_udf/tests/test_string_udfs.py +++ b/python/strings_udf/strings_udf/tests/test_string_udfs.py @@ -278,3 +278,27 @@ def func(st): return st run_udf_test(data, func, "str") + + +@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", "@"]) +def test_string_udf_strip(data, strip_char): + def func(st): + return st.strip(strip_char) + + run_udf_test(data, func, "str") + + +@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", "@"]) +def test_string_udf_lstrip(data, strip_char): + def func(st): + return st.lstrip(strip_char) + + run_udf_test(data, func, "str") + + +@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", 
"@"]) +def test_string_udf_rstrip(data, strip_char): + def func(st): + return st.rstrip(strip_char) + + run_udf_test(data, func, "str")