Skip to content

Commit

Permalink
Support replace in strings_udf (#12207)
Browse files Browse the repository at this point in the history
This PR adds support for the following function in `strings_udf`:

- `str.replace`

Part of #9639

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - Ashwin Srinath (https://github.com/shwina)

URL: #12207
  • Loading branch information
brandon-b-miller authored Nov 30, 2022
1 parent f4bb574 commit 5f83a84
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 14 deletions.
35 changes: 35 additions & 0 deletions python/cudf/cudf/core/udf/strings_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
len_impl,
lower_impl,
lstrip_impl,
replace_impl,
rfind_impl,
rstrip_impl,
startswith_impl,
Expand All @@ -50,6 +51,40 @@ def masked_len_impl(context, builder, sig, args):
return ret._getvalue()


def _masked_proxies(context, builder, maskedty, *args):
return tuple(
cgutils.create_struct_proxy(maskedty)(context, builder, value=arg)
for arg in args
)


@cuda_lower(
"MaskedType.replace",
MaskedType(string_view),
MaskedType(string_view),
MaskedType(string_view),
)
def masked_string_view_replace_impl(context, builder, sig, args):
ret = cgutils.create_struct_proxy(sig.return_type)(context, builder)
src_masked, to_replace_masked, replacement_masked = _masked_proxies(
context, builder, MaskedType(string_view), *args
)
result = replace_impl(
context,
builder,
nb_signature(udf_string, string_view, string_view, string_view),
(src_masked.value, to_replace_masked.value, replacement_masked.value),
)

ret.value = result
ret.valid = builder.and_(
builder.and_(src_masked.valid, to_replace_masked.valid),
replacement_masked.valid,
)

return ret._getvalue()


def create_binary_string_func(op, cuda_func, retty):
"""
Provide a wrapper around numba's low-level extension API which
Expand Down
17 changes: 17 additions & 0 deletions python/cudf/cudf/core/udf/strings_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,26 @@ def generic(self, args, kws):
)


class MaskedStringViewReplace(AbstractTemplate):
key = "MaskedType.replace"

def generic(self, args, kws):
return nb_signature(
MaskedType(udf_string),
MaskedType(string_view),
MaskedType(string_view),
recvr=self.this,
)


class MaskedStringViewAttrs(AttributeTemplate):
key = MaskedType(string_view)

def resolve_replace(self, mod):
return types.BoundFunction(
MaskedStringViewReplace, MaskedType(string_view)
)

def resolve_count(self, mod):
return types.BoundFunction(
MaskedStringViewCount, MaskedType(string_view)
Expand Down
10 changes: 10 additions & 0 deletions python/cudf/cudf/tests/test_udf_masked_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,16 @@ def func(row):
run_masked_udf_test(func, str_udf_data, check_dtype=False)


@string_udf_test
@pytest.mark.parametrize("to_replace", ["a", "1", "", "@"])
@pytest.mark.parametrize("replacement", ["a", "1", "", "@"])
def test_string_udf_replace(str_udf_data, to_replace, replacement):
def func(row):
return row["str_col"].replace(to_replace, replacement)

run_masked_udf_test(func, str_udf_data, check_dtype=False)


@pytest.mark.parametrize(
"data", [[1.0, 0.0, 1.5], [1, 0, 2], [True, False, True]]
)
Expand Down
14 changes: 14 additions & 0 deletions python/strings_udf/cpp/src/strings/udf/shim.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <cudf/strings/udf/case.cuh>
#include <cudf/strings/udf/char_types.cuh>
#include <cudf/strings/udf/replace.cuh>
#include <cudf/strings/udf/search.cuh>
#include <cudf/strings/udf/starts_with.cuh>
#include <cudf/strings/udf/strip.cuh>
Expand Down Expand Up @@ -329,3 +330,16 @@ extern "C" __device__ int concat(int* nb_retval, void* udf_str, void* const* lhs
*udf_str_ptr = result;
return 0;
}

extern "C" __device__ int replace(
int* nb_retval, void* udf_str, void* const src, void* const to_replace, void* const replacement)
{
auto src_ptr = reinterpret_cast<cudf::string_view const*>(src);
auto to_replace_ptr = reinterpret_cast<cudf::string_view const*>(to_replace);
auto replacement_ptr = reinterpret_cast<cudf::string_view const*>(replacement);

auto udf_str_ptr = new (udf_str) udf_string;
*udf_str_ptr = replace(*src_ptr, *to_replace_ptr, *replacement_ptr);

return 0;
}
39 changes: 25 additions & 14 deletions python/strings_udf/strings_udf/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,20 +153,6 @@ def generic(self, args, kws):
cuda_decl_registry.register_global(op)(StringViewBinaryOp)


register_stringview_binaryop(operator.eq, types.boolean)
register_stringview_binaryop(operator.ne, types.boolean)
register_stringview_binaryop(operator.lt, types.boolean)
register_stringview_binaryop(operator.gt, types.boolean)
register_stringview_binaryop(operator.le, types.boolean)
register_stringview_binaryop(operator.ge, types.boolean)

# st in other
register_stringview_binaryop(operator.contains, types.boolean)

# st + other
register_stringview_binaryop(operator.add, udf_string)


def create_binary_attr(attrname, retty):
"""
Helper function wrapping numba's low level extension API. Provides
Expand Down Expand Up @@ -212,13 +198,25 @@ def generic(self, args, kws):
return nb_signature(size_type, string_view, recvr=self.this)


class StringViewReplace(AbstractTemplate):
key = "StringView.replace"

def generic(self, args, kws):
return nb_signature(
udf_string, string_view, string_view, recvr=self.this
)


@cuda_decl_registry.register_attr
class StringViewAttrs(AttributeTemplate):
key = string_view

def resolve_count(self, mod):
return types.BoundFunction(StringViewCount, string_view)

def resolve_replace(self, mod):
return types.BoundFunction(StringViewReplace, string_view)


# Build attributes for `MaskedType(string_view)`
bool_binary_funcs = ["startswith", "endswith"]
Expand Down Expand Up @@ -272,3 +270,16 @@ def resolve_count(self, mod):
)

cuda_decl_registry.register_attr(StringViewAttrs)

register_stringview_binaryop(operator.eq, types.boolean)
register_stringview_binaryop(operator.ne, types.boolean)
register_stringview_binaryop(operator.lt, types.boolean)
register_stringview_binaryop(operator.gt, types.boolean)
register_stringview_binaryop(operator.le, types.boolean)
register_stringview_binaryop(operator.ge, types.boolean)

# st in other
register_stringview_binaryop(operator.contains, types.boolean)

# st + other
register_stringview_binaryop(operator.add, udf_string)
36 changes: 36 additions & 0 deletions python/strings_udf/strings_udf/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
"concat", types.void(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR)
)

_string_view_replace = cuda.declare_device(
"replace",
types.void(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR),
)


def _declare_binary_func(lhs, rhs, out, name):
# Declare a binary function
Expand Down Expand Up @@ -209,6 +214,37 @@ def concat_impl(context, builder, sig, args):
return result._getvalue()


def call_string_view_replace(result, src, to_replace, replacement):
return _string_view_replace(result, src, to_replace, replacement)


@cuda_lower("StringView.replace", string_view, string_view, string_view)
def replace_impl(context, builder, sig, args):
src_ptr = builder.alloca(args[0].type)
to_replace_ptr = builder.alloca(args[1].type)
replacement_ptr = builder.alloca(args[2].type)

builder.store(args[0], src_ptr)
builder.store(args[1], to_replace_ptr),
builder.store(args[2], replacement_ptr)

udf_str_ptr = builder.alloca(default_manager[udf_string].get_value_type())

_ = context.compile_internal(
builder,
call_string_view_replace,
types.void(
_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR
),
(udf_str_ptr, src_ptr, to_replace_ptr, replacement_ptr),
)

result = cgutils.create_struct_proxy(udf_string)(
context, builder, value=builder.load(udf_str_ptr)
)
return result._getvalue()


def create_binary_string_func(binary_func, retty):
"""
Provide a wrapper around numba's low-level extension API which
Expand Down
9 changes: 9 additions & 0 deletions python/strings_udf/strings_udf/tests/test_string_udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,12 @@ def func(st):
return concat_char + st

run_udf_test(data, func, "str")


@pytest.mark.parametrize("to_replace", ["a", "1", "", "@"])
@pytest.mark.parametrize("replacement", ["a", "1", "", "@"])
def test_string_udf_replace(data, to_replace, replacement):
def func(st):
return st.replace(to_replace, replacement)

run_udf_test(data, func, "str")

0 comments on commit 5f83a84

Please sign in to comment.