diff --git a/python/cudf/cudf/core/udf/strings_lowering.py b/python/cudf/cudf/core/udf/strings_lowering.py index 465866cdd55..ec956cdd65d 100644 --- a/python/cudf/cudf/core/udf/strings_lowering.py +++ b/python/cudf/cudf/core/udf/strings_lowering.py @@ -24,6 +24,7 @@ len_impl, lower_impl, lstrip_impl, + replace_impl, rfind_impl, rstrip_impl, startswith_impl, @@ -50,6 +51,40 @@ def masked_len_impl(context, builder, sig, args): return ret._getvalue() +def _masked_proxies(context, builder, maskedty, *args): + return tuple( + cgutils.create_struct_proxy(maskedty)(context, builder, value=arg) + for arg in args + ) + + +@cuda_lower( + "MaskedType.replace", + MaskedType(string_view), + MaskedType(string_view), + MaskedType(string_view), +) +def masked_string_view_replace_impl(context, builder, sig, args): + ret = cgutils.create_struct_proxy(sig.return_type)(context, builder) + src_masked, to_replace_masked, replacement_masked = _masked_proxies( + context, builder, MaskedType(string_view), *args + ) + result = replace_impl( + context, + builder, + nb_signature(udf_string, string_view, string_view, string_view), + (src_masked.value, to_replace_masked.value, replacement_masked.value), + ) + + ret.value = result + ret.valid = builder.and_( + builder.and_(src_masked.valid, to_replace_masked.valid), + replacement_masked.valid, + ) + + return ret._getvalue() + + def create_binary_string_func(op, cuda_func, retty): """ Provide a wrapper around numba's low-level extension API which diff --git a/python/cudf/cudf/core/udf/strings_typing.py b/python/cudf/cudf/core/udf/strings_typing.py index 87500cba564..e373b8b018d 100644 --- a/python/cudf/cudf/core/udf/strings_typing.py +++ b/python/cudf/cudf/core/udf/strings_typing.py @@ -155,9 +155,26 @@ def generic(self, args, kws): ) +class MaskedStringViewReplace(AbstractTemplate): + key = "MaskedType.replace" + + def generic(self, args, kws): + return nb_signature( + MaskedType(udf_string), + MaskedType(string_view), + MaskedType(string_view), + recvr=self.this, + ) + + class MaskedStringViewAttrs(AttributeTemplate): key = MaskedType(string_view) + def resolve_replace(self, mod): + return types.BoundFunction( + MaskedStringViewReplace, MaskedType(string_view) + ) + def resolve_count(self, mod): return types.BoundFunction( MaskedStringViewCount, MaskedType(string_view) diff --git a/python/cudf/cudf/tests/test_udf_masked_ops.py b/python/cudf/cudf/tests/test_udf_masked_ops.py index 72abc8e9f87..e3b7e62433e 100644 --- a/python/cudf/cudf/tests/test_udf_masked_ops.py +++ b/python/cudf/cudf/tests/test_udf_masked_ops.py @@ -928,6 +928,16 @@ def func(row): run_masked_udf_test(func, str_udf_data, check_dtype=False) +@string_udf_test +@pytest.mark.parametrize("to_replace", ["a", "1", "", "@"]) +@pytest.mark.parametrize("replacement", ["a", "1", "", "@"]) +def test_string_udf_replace(str_udf_data, to_replace, replacement): + def func(row): + return row["str_col"].replace(to_replace, replacement) + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + @pytest.mark.parametrize( "data", [[1.0, 0.0, 1.5], [1, 0, 2], [True, False, True]] ) diff --git a/python/strings_udf/cpp/src/strings/udf/shim.cu b/python/strings_udf/cpp/src/strings/udf/shim.cu index c5a446c9518..d10cc635209 100644 --- a/python/strings_udf/cpp/src/strings/udf/shim.cu +++ b/python/strings_udf/cpp/src/strings/udf/shim.cu @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -329,3 +330,16 @@ extern "C" __device__ int concat(int* nb_retval, void* udf_str, void* const* lhs *udf_str_ptr = result; return 0; } + +extern "C" __device__ int replace( + int* nb_retval, void* udf_str, void* const src, void* const to_replace, void* const replacement) +{ + auto src_ptr = reinterpret_cast(src); + auto to_replace_ptr = reinterpret_cast(to_replace); + auto replacement_ptr = reinterpret_cast(replacement); + + auto udf_str_ptr = new (udf_str) udf_string; + *udf_str_ptr = replace(*src_ptr, *to_replace_ptr, *replacement_ptr); + + return 0; +} diff --git a/python/strings_udf/strings_udf/_typing.py b/python/strings_udf/strings_udf/_typing.py index 3fadf030ce9..99e4046b0b3 100644 --- a/python/strings_udf/strings_udf/_typing.py +++ b/python/strings_udf/strings_udf/_typing.py @@ -153,20 +153,6 @@ def generic(self, args, kws): cuda_decl_registry.register_global(op)(StringViewBinaryOp) -register_stringview_binaryop(operator.eq, types.boolean) -register_stringview_binaryop(operator.ne, types.boolean) -register_stringview_binaryop(operator.lt, types.boolean) -register_stringview_binaryop(operator.gt, types.boolean) -register_stringview_binaryop(operator.le, types.boolean) -register_stringview_binaryop(operator.ge, types.boolean) - -# st in other -register_stringview_binaryop(operator.contains, types.boolean) - -# st + other -register_stringview_binaryop(operator.add, udf_string) - - def create_binary_attr(attrname, retty): """ Helper function wrapping numba's low level extension API. Provides @@ -212,6 +198,15 @@ def generic(self, args, kws): return nb_signature(size_type, string_view, recvr=self.this) +class StringViewReplace(AbstractTemplate): + key = "StringView.replace" + + def generic(self, args, kws): + return nb_signature( + udf_string, string_view, string_view, recvr=self.this + ) + + @cuda_decl_registry.register_attr class StringViewAttrs(AttributeTemplate): key = string_view @@ -219,6 +214,9 @@ class StringViewAttrs(AttributeTemplate): def resolve_count(self, mod): return types.BoundFunction(StringViewCount, string_view) + def resolve_replace(self, mod): + return types.BoundFunction(StringViewReplace, string_view) + # Build attributes for `MaskedType(string_view)` bool_binary_funcs = ["startswith", "endswith"] @@ -272,3 +270,16 @@ def resolve_count(self, mod): ) cuda_decl_registry.register_attr(StringViewAttrs) + +register_stringview_binaryop(operator.eq, types.boolean) +register_stringview_binaryop(operator.ne, types.boolean) +register_stringview_binaryop(operator.lt, types.boolean) +register_stringview_binaryop(operator.gt, types.boolean) +register_stringview_binaryop(operator.le, types.boolean) +register_stringview_binaryop(operator.ge, types.boolean) + +# st in other +register_stringview_binaryop(operator.contains, types.boolean) + +# st + other +register_stringview_binaryop(operator.add, udf_string) diff --git a/python/strings_udf/strings_udf/lowering.py b/python/strings_udf/strings_udf/lowering.py index cca3066a844..7294d06c05b 100644 --- a/python/strings_udf/strings_udf/lowering.py +++ b/python/strings_udf/strings_udf/lowering.py @@ -35,6 +35,11 @@ "concat", types.void(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR) ) +_string_view_replace = cuda.declare_device( + "replace", + types.void(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR), +) + def _declare_binary_func(lhs, rhs, out, name): # Declare a binary function @@ -209,6 +214,37 @@ def concat_impl(context, builder, sig, args): return result._getvalue() +def call_string_view_replace(result, src, to_replace, replacement): + return _string_view_replace(result, src, to_replace, replacement) + + +@cuda_lower("StringView.replace", string_view, string_view, string_view) +def replace_impl(context, builder, sig, args): + src_ptr = builder.alloca(args[0].type) + to_replace_ptr = builder.alloca(args[1].type) + replacement_ptr = builder.alloca(args[2].type) + + builder.store(args[0], src_ptr) + builder.store(args[1], to_replace_ptr), + builder.store(args[2], replacement_ptr) + + udf_str_ptr = builder.alloca(default_manager[udf_string].get_value_type()) + + _ = context.compile_internal( + builder, + call_string_view_replace, + types.void( + _UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR + ), + (udf_str_ptr, src_ptr, to_replace_ptr, replacement_ptr), + ) + + result = cgutils.create_struct_proxy(udf_string)( + context, builder, value=builder.load(udf_str_ptr) + ) + return result._getvalue() + + def create_binary_string_func(binary_func, retty): """ Provide a wrapper around numba's low-level extension API which diff --git a/python/strings_udf/strings_udf/tests/test_string_udfs.py b/python/strings_udf/strings_udf/tests/test_string_udfs.py index 02c3a8b8c12..b8de821e101 100644 --- a/python/strings_udf/strings_udf/tests/test_string_udfs.py +++ b/python/strings_udf/strings_udf/tests/test_string_udfs.py @@ -332,3 +332,12 @@ def func(st): return concat_char + st run_udf_test(data, func, "str") + + +@pytest.mark.parametrize("to_replace", ["a", "1", "", "@"]) +@pytest.mark.parametrize("replacement", ["a", "1", "", "@"]) +def test_string_udf_replace(data, to_replace, replacement): + def func(st): + return st.replace(to_replace, replacement) + + run_udf_test(data, func, "str")