From 254a583a58a38ef574a3d2114a1586c7acb6cc3f Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Wed, 1 Mar 2023 07:12:18 -0800 Subject: [PATCH] address reviews --- python/cudf/cudf/core/udf/strings_lowering.py | 3 +- python/cudf/cudf/core/udf/strings_typing.py | 89 +++++++++++-------- python/cudf/cudf/tests/test_string_udfs.py | 6 +- 3 files changed, 55 insertions(+), 43 deletions(-) diff --git a/python/cudf/cudf/core/udf/strings_lowering.py b/python/cudf/cudf/core/udf/strings_lowering.py index 6086f3a2069..a53722f505d 100644 --- a/python/cudf/cudf/core/udf/strings_lowering.py +++ b/python/cudf/cudf/core/udf/strings_lowering.py @@ -162,7 +162,7 @@ def cast_udf_string_to_string_view(context, builder, fromty, toty, val): sv_ptr = builder.alloca(default_manager[toty].get_value_type()) builder.store(val, udf_str_ptr) - _ = context.compile_internal( + context.compile_internal( builder, call_create_string_view_from_udf_string, nb_signature(types.void, _UDF_STRING_PTR, _STR_VIEW_PTR), @@ -201,7 +201,6 @@ def call_len_string_view(st): @cuda_lower(len, string_view) -@cuda_lower(len, udf_string) def len_impl(context, builder, sig, args): sv_ptr = builder.alloca(args[0].type) builder.store(args[0], sv_ptr) diff --git a/python/cudf/cudf/core/udf/strings_typing.py b/python/cudf/cudf/core/udf/strings_typing.py index 8b239f4b806..50d34be40a0 100644 --- a/python/cudf/cudf/core/udf/strings_typing.py +++ b/python/cudf/cudf/core/udf/strings_typing.py @@ -15,23 +15,6 @@ size_type = types.int32 -bool_binary_funcs = ["startswith", "endswith"] -int_binary_funcs = ["find", "rfind"] -id_unary_funcs = [ - "isalpha", - "isalnum", - "isdecimal", - "isdigit", - "isupper", - "islower", - "isspace", - "isnumeric", - "istitle", -] -string_unary_funcs = ["upper", "lower"] -string_return_attrs = ["strip", "lstrip", "rstrip"] - - # String object definitions class UDFString(types.Type): @@ -217,34 +200,62 @@ def generic(self, args, kws): class StringViewAttrs(AttributeTemplate): key = string_view - resolve_startswith = create_binary_attr("startswith", types.boolean) - resolve_endswith = create_binary_attr("endswith", types.boolean) + def resolve_count(self, mod): + return types.BoundFunction(StringViewCount, string_view) + + def resolve_replace(self, mod): + return types.BoundFunction(StringViewReplace, string_view) + + +bool_binary_funcs = ["startswith", "endswith"] +int_binary_funcs = ["find", "rfind"] +id_unary_funcs = [ + "isalpha", + "isalnum", + "isdecimal", + "isdigit", + "isupper", + "islower", + "isspace", + "isnumeric", + "istitle", +] +string_unary_funcs = ["upper", "lower"] +string_return_attrs = ["strip", "lstrip", "rstrip"] - resolve_strip = create_binary_attr("strip", udf_string) - resolve_lstrip = create_binary_attr("lstrip", udf_string) - resolve_rstrip = create_binary_attr("rstrip", udf_string) +for func in bool_binary_funcs: + setattr( + StringViewAttrs, + f"resolve_{func}", + create_binary_attr(func, types.boolean), + ) - resolve_find = create_binary_attr("find", size_type) - resolve_rfind = create_binary_attr("rfind", size_type) +for func in string_return_attrs: + setattr( + StringViewAttrs, + f"resolve_{func}", + create_binary_attr(func, udf_string), + ) - resolve_isalpha = create_identifier_attr("isalpha", types.boolean) - resolve_isalnum = create_identifier_attr("isalnum", types.boolean) - resolve_isdecimal = create_identifier_attr("isdecimal", types.boolean) - resolve_isdigit = create_identifier_attr("isdigit", types.boolean) - resolve_isupper = create_identifier_attr("isupper", types.boolean) - resolve_islower = create_identifier_attr("islower", types.boolean) - resolve_isspace = create_identifier_attr("isspace", types.boolean) - resolve_isnumeric = create_identifier_attr("isnumeric", types.boolean) - resolve_istitle = create_identifier_attr("istitle", types.boolean) - resolve_upper = create_identifier_attr("upper", udf_string) - resolve_lower = create_identifier_attr("lower", udf_string) +for func in int_binary_funcs: + setattr( + StringViewAttrs, f"resolve_{func}", create_binary_attr(func, size_type) + ) - def resolve_count(self, mod): - return types.BoundFunction(StringViewCount, string_view) +for func in id_unary_funcs: + setattr( + StringViewAttrs, + f"resolve_{func}", + create_identifier_attr(func, types.boolean), + ) - def resolve_replace(self, mod): - return types.BoundFunction(StringViewReplace, string_view) +for func in string_unary_funcs: + setattr( + StringViewAttrs, + f"resolve_{func}", + create_identifier_attr(func, udf_string), + ) @cuda_decl_registry.register_attr diff --git a/python/cudf/cudf/tests/test_string_udfs.py b/python/cudf/cudf/tests/test_string_udfs.py index c03c52f4520..049dfdc8e30 100644 --- a/python/cudf/cudf/tests/test_string_udfs.py +++ b/python/cudf/cudf/tests/test_string_udfs.py @@ -26,9 +26,11 @@ def get_kernels(func, dtype, size): """ - Create a kernel for testing a single scalar string function + Create two kernels for testing a single scalar string function. + The first tests the function's action on a string_view object and + the second tests the same except using a udf_string object. Allocates an output vector with a dtype specified by the caller - The returned kernel executes the input function on each data + The returned kernels execute the input function on each data element of the input and returns the output into the output vector """