From aa13b955fa079dc1f1d526bb25a11bd3cb1576d8 Mon Sep 17 00:00:00 2001 From: brandon-b-miller <53796099+brandon-b-miller@users.noreply.github.com> Date: Thu, 17 Nov 2022 00:39:08 -0600 Subject: [PATCH] Support `upper` and `lower` in `strings_udf` (#12099) This PR adds support for the following two functions in `strings_udf`: - `str.upper()` - `str.lower()` Part of https://github.com/rapidsai/cudf/issues/9639 Authors: - https://github.com/brandon-b-miller - David Wendt (https://github.com/davidwendt) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) - Lawrence Mitchell (https://github.com/wence-) - David Wendt (https://github.com/davidwendt) URL: https://github.com/rapidsai/cudf/pull/12099 --- python/cudf/cudf/core/udf/strings_lowering.py | 58 +++++++---- python/cudf/cudf/core/udf/strings_typing.py | 14 ++- python/cudf/cudf/tests/test_udf_masked_ops.py | 16 ++++ .../strings_udf/cpp/src/strings/udf/shim.cu | 64 +++++++++++-- .../cpp/src/strings/udf/udf_apis.cu | 3 +- .../strings_udf/_lib/cpp/strings_udf.pxd | 4 +- .../strings_udf/strings_udf/_lib/tables.pyx | 16 +++- python/strings_udf/strings_udf/_typing.py | 8 ++ python/strings_udf/strings_udf/lowering.py | 95 ++++++++++++++++++- .../strings_udf/tests/test_string_udfs.py | 14 +++ 10 files changed, 254 insertions(+), 38 deletions(-) diff --git a/python/cudf/cudf/core/udf/strings_lowering.py b/python/cudf/cudf/core/udf/strings_lowering.py index fdfd013bad7..465866cdd55 100644 --- a/python/cudf/cudf/core/udf/strings_lowering.py +++ b/python/cudf/cudf/core/udf/strings_lowering.py @@ -22,11 +22,13 @@ istitle_impl, isupper_impl, len_impl, + lower_impl, lstrip_impl, rfind_impl, rstrip_impl, startswith_impl, strip_impl, + upper_impl, ) from cudf.core.udf.masked_typing import MaskedType @@ -82,25 +84,6 @@ def masked_binary_func_impl(context, builder, sig, args): ) -create_binary_string_func("MaskedType.strip", strip_impl, udf_string) - -create_binary_string_func("MaskedType.lstrip", lstrip_impl, udf_string) - -create_binary_string_func("MaskedType.rstrip", rstrip_impl, udf_string) - - -create_binary_string_func( - "MaskedType.startswith", - startswith_impl, - types.boolean, -) -create_binary_string_func("MaskedType.endswith", endswith_impl, types.boolean) -create_binary_string_func("MaskedType.find", find_impl, size_type) -create_binary_string_func("MaskedType.rfind", rfind_impl, size_type) -create_binary_string_func("MaskedType.count", count_impl, size_type) -create_binary_string_func(operator.contains, contains_impl, types.boolean) - - def create_masked_unary_identifier_func(op, cuda_func): """ Provide a wrapper around numba's low-level extension API which @@ -127,6 +110,41 @@ def masked_unary_func_impl(context, builder, sig, args): cuda_lower(op, MaskedType(string_view))(masked_unary_func_impl) +def create_masked_upper_or_lower(op, cuda_func): + def upper_or_lower_impl(context, builder, sig, args): + ret = cgutils.create_struct_proxy(sig.return_type)(context, builder) + masked_str = cgutils.create_struct_proxy(sig.args[0])( + context, builder, value=args[0] + ) + + result = cuda_func( + context, + builder, + udf_string(string_view), + (masked_str.value,), + ) + ret.value = result + ret.valid = masked_str.valid + return ret._getvalue() + + cuda_lower(op, MaskedType(string_view))(upper_or_lower_impl) + + +create_binary_string_func("MaskedType.strip", strip_impl, udf_string) +create_binary_string_func("MaskedType.lstrip", lstrip_impl, udf_string) +create_binary_string_func("MaskedType.rstrip", rstrip_impl, udf_string) +create_binary_string_func( + "MaskedType.startswith", + startswith_impl, + types.boolean, +) +create_binary_string_func("MaskedType.endswith", endswith_impl, types.boolean) +create_binary_string_func("MaskedType.find", find_impl, size_type) +create_binary_string_func("MaskedType.rfind", rfind_impl, size_type) +create_binary_string_func("MaskedType.count", count_impl, size_type) +create_binary_string_func(operator.contains, contains_impl, types.boolean) + + create_masked_unary_identifier_func("MaskedType.isalnum", isalnum_impl) create_masked_unary_identifier_func("MaskedType.isalpha", isalpha_impl) create_masked_unary_identifier_func("MaskedType.isdigit", isdigit_impl) @@ -135,3 +153,5 @@ def masked_unary_func_impl(context, builder, sig, args): create_masked_unary_identifier_func("MaskedType.isspace", isspace_impl) create_masked_unary_identifier_func("MaskedType.isdecimal", isdecimal_impl) create_masked_unary_identifier_func("MaskedType.istitle", istitle_impl) +create_masked_upper_or_lower("MaskedType.upper", upper_impl) +create_masked_upper_or_lower("MaskedType.lower", lower_impl) diff --git a/python/cudf/cudf/core/udf/strings_typing.py b/python/cudf/cudf/core/udf/strings_typing.py index e8a35c12f71..87500cba564 100644 --- a/python/cudf/cudf/core/udf/strings_typing.py +++ b/python/cudf/cudf/core/udf/strings_typing.py @@ -14,6 +14,7 @@ int_binary_funcs, size_type, string_return_attrs, + string_unary_funcs, string_view, udf_string, ) @@ -123,7 +124,7 @@ def attr(self, mod): return attr -def create_masked_identifier_attr(attrname): +def create_masked_unary_attr(attrname, retty): """ Helper function wrapping numba's low level extension API. Provides the boilerplate needed to register a unary function of a masked @@ -134,7 +135,7 @@ class MaskedStringViewIdentifierAttr(AbstractTemplate): key = attrname def generic(self, args, kws): - return nb_signature(MaskedType(types.boolean), recvr=self.this) + return nb_signature(MaskedType(retty), recvr=self.this) def attr(self, mod): return types.BoundFunction( @@ -195,7 +196,14 @@ def resolve_valid(self, mod): setattr( MaskedStringViewAttrs, f"resolve_{func}", - create_masked_identifier_attr(f"MaskedType.{func}"), + create_masked_unary_attr(f"MaskedType.{func}", types.boolean), + ) + +for func in string_unary_funcs: + setattr( + MaskedStringViewAttrs, + f"resolve_{func}", + create_masked_unary_attr(f"MaskedType.{func}", udf_string), ) cuda_decl_registry.register_attr(MaskedStringViewAttrs) diff --git a/python/cudf/cudf/tests/test_udf_masked_ops.py b/python/cudf/cudf/tests/test_udf_masked_ops.py index fbe6b3f8888..72abc8e9f87 100644 --- a/python/cudf/cudf/tests/test_udf_masked_ops.py +++ b/python/cudf/cudf/tests/test_udf_masked_ops.py @@ -903,6 +903,22 @@ def func(row): run_masked_udf_test(func, str_udf_data, check_dtype=False) +@string_udf_test +def test_string_udf_upper(str_udf_data): + def func(row): + return row["str_col"].upper() + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + +@string_udf_test +def test_string_udf_lower(str_udf_data): + def func(row): + return row["str_col"].lower() + + run_masked_udf_test(func, str_udf_data, check_dtype=False) + + @string_udf_test @pytest.mark.parametrize("concat_char", ["1", "a", "12", " ", "", ".", "@"]) def test_string_udf_concat(str_udf_data, concat_char): diff --git a/python/strings_udf/cpp/src/strings/udf/shim.cu b/python/strings_udf/cpp/src/strings/udf/shim.cu index 8fc158d7eb7..c5a446c9518 100644 --- a/python/strings_udf/cpp/src/strings/udf/shim.cu +++ b/python/strings_udf/cpp/src/strings/udf/shim.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include #include @@ -128,7 +129,7 @@ extern "C" __device__ int lt(bool* nb_retval, void const* str, void const* rhs) return 0; } -extern "C" __device__ int pyislower(bool* nb_retval, void const* str, std::int64_t chars_table) +extern "C" __device__ int pyislower(bool* nb_retval, void const* str, std::uintptr_t chars_table) { auto str_view = reinterpret_cast(str); @@ -137,7 +138,7 @@ extern "C" __device__ int pyislower(bool* nb_retval, void const* str, std::int64 return 0; } -extern "C" __device__ int pyisupper(bool* nb_retval, void const* str, std::int64_t chars_table) +extern "C" __device__ int pyisupper(bool* nb_retval, void const* str, std::uintptr_t chars_table) { auto str_view = reinterpret_cast(str); @@ -146,7 +147,7 @@ extern "C" __device__ int pyisupper(bool* nb_retval, void const* str, std::int64 return 0; } -extern "C" __device__ int pyisspace(bool* nb_retval, void const* str, std::int64_t chars_table) +extern "C" __device__ int pyisspace(bool* nb_retval, void const* str, std::uintptr_t chars_table) { auto str_view = reinterpret_cast(str); @@ -155,7 +156,7 @@ extern "C" __device__ int pyisspace(bool* nb_retval, void const* str, std::int64 return 0; } -extern "C" __device__ int pyisdecimal(bool* nb_retval, void const* str, std::int64_t chars_table) +extern "C" __device__ int pyisdecimal(bool* nb_retval, void const* str, std::uintptr_t chars_table) { auto str_view = reinterpret_cast(str); @@ -164,7 +165,7 @@ extern "C" __device__ int pyisdecimal(bool* nb_retval, void const* str, std::int return 0; } -extern "C" __device__ int pyisnumeric(bool* nb_retval, void const* str, std::int64_t chars_table) +extern "C" __device__ int pyisnumeric(bool* nb_retval, void const* str, std::uintptr_t chars_table) { auto str_view = reinterpret_cast(str); @@ -173,7 +174,7 @@ extern "C" __device__ int pyisnumeric(bool* nb_retval, void const* str, std::int return 0; } -extern "C" __device__ int pyisdigit(bool* nb_retval, void const* str, std::int64_t chars_table) +extern "C" __device__ int pyisdigit(bool* nb_retval, void const* str, std::uintptr_t chars_table) { auto str_view = reinterpret_cast(str); @@ -182,7 +183,7 @@ extern "C" __device__ int pyisdigit(bool* nb_retval, void const* str, std::int64 return 0; } -extern "C" __device__ int pyisalnum(bool* nb_retval, void const* str, std::int64_t chars_table) +extern "C" __device__ int pyisalnum(bool* nb_retval, void const* str, std::uintptr_t chars_table) { auto str_view = reinterpret_cast(str); @@ -191,7 +192,7 @@ extern "C" __device__ int pyisalnum(bool* nb_retval, void const* str, std::int64 return 0; } -extern "C" __device__ int pyisalpha(bool* nb_retval, void const* str, std::int64_t chars_table) +extern "C" __device__ int pyisalpha(bool* nb_retval, void const* str, std::uintptr_t chars_table) { auto str_view = reinterpret_cast(str); @@ -200,7 +201,7 @@ extern "C" __device__ int pyisalpha(bool* nb_retval, void const* str, std::int64 return 0; } -extern "C" __device__ int pyistitle(bool* nb_retval, void const* str, std::int64_t chars_table) +extern "C" __device__ int pyistitle(bool* nb_retval, void const* str, std::uintptr_t chars_table) { auto str_view = reinterpret_cast(str); @@ -270,6 +271,51 @@ extern "C" __device__ int rstrip(int* nb_retval, return 0; } +extern "C" __device__ int upper(int* nb_retval, + void* udf_str, + void const* st, + std::uintptr_t flags_table, + std::uintptr_t cases_table, + std::uintptr_t special_table) +{ + auto udf_str_ptr = new (udf_str) udf_string; + auto st_ptr = reinterpret_cast(st); + + auto flags_table_ptr = + reinterpret_cast(flags_table); + auto cases_table_ptr = + reinterpret_cast(cases_table); + auto special_table_ptr = + reinterpret_cast(special_table); + + cudf::strings::udf::chars_tables tables{flags_table_ptr, cases_table_ptr, special_table_ptr}; + + *udf_str_ptr = to_upper(tables, *st_ptr); + + return 0; +} + +extern "C" __device__ int lower(int* nb_retval, + void* udf_str, + void const* st, + std::uintptr_t flags_table, + std::uintptr_t cases_table, + std::uintptr_t special_table) +{ + auto udf_str_ptr = new (udf_str) udf_string; + auto st_ptr = reinterpret_cast(st); + + auto flags_table_ptr = + reinterpret_cast(flags_table); + auto cases_table_ptr = + reinterpret_cast(cases_table); + auto special_table_ptr = + reinterpret_cast(special_table); + + cudf::strings::udf::chars_tables tables{flags_table_ptr, cases_table_ptr, special_table_ptr}; + *udf_str_ptr = to_lower(tables, *st_ptr); + return 0; +} extern "C" __device__ int concat(int* nb_retval, void* udf_str, void* const* lhs, void* const* rhs) { diff --git a/python/strings_udf/cpp/src/strings/udf/udf_apis.cu b/python/strings_udf/cpp/src/strings/udf/udf_apis.cu index b4d5014d9e0..3e6491e32e7 100644 --- a/python/strings_udf/cpp/src/strings/udf/udf_apis.cu +++ b/python/strings_udf/cpp/src/strings/udf/udf_apis.cu @@ -42,7 +42,8 @@ namespace { struct udf_string_to_string_view_transform_fn { __device__ cudf::string_view operator()(cudf::strings::udf::udf_string const& dstr) { - return cudf::string_view{dstr.data(), dstr.size_bytes()}; + return dstr.data() == nullptr ? cudf::string_view{} + : cudf::string_view{dstr.data(), dstr.size_bytes()}; } }; diff --git a/python/strings_udf/strings_udf/_lib/cpp/strings_udf.pxd b/python/strings_udf/strings_udf/_lib/cpp/strings_udf.pxd index 7b90760abcc..b3bf6465db6 100644 --- a/python/strings_udf/strings_udf/_lib/cpp/strings_udf.pxd +++ b/python/strings_udf/strings_udf/_lib/cpp/strings_udf.pxd @@ -1,6 +1,6 @@ # Copyright (c) 2022, NVIDIA CORPORATION. -from libc.stdint cimport uint8_t +from libc.stdint cimport uint8_t, uint16_t from libcpp.memory cimport unique_ptr from libcpp.string cimport string from libcpp.vector cimport vector @@ -28,3 +28,5 @@ cdef extern from "cudf/strings/udf/udf_apis.hpp" namespace \ cdef extern from "cudf/strings/detail/char_tables.hpp" namespace \ "cudf::strings::detail" nogil: cdef const uint8_t* get_character_flags_table() except + + cdef const uint16_t* get_character_cases_table() except + + cdef const void* get_special_case_mapping_table() except + diff --git a/python/strings_udf/strings_udf/_lib/tables.pyx b/python/strings_udf/strings_udf/_lib/tables.pyx index 5443364a4a7..6442a34f63f 100644 --- a/python/strings_udf/strings_udf/_lib/tables.pyx +++ b/python/strings_udf/strings_udf/_lib/tables.pyx @@ -1,9 +1,11 @@ # Copyright (c) 2022, NVIDIA CORPORATION. -from libc.stdint cimport uint8_t, uintptr_t +from libc.stdint cimport uint8_t, uint16_t, uintptr_t from strings_udf._lib.cpp.strings_udf cimport ( + get_character_cases_table as cpp_get_character_cases_table, get_character_flags_table as cpp_get_character_flags_table, + get_special_case_mapping_table as cpp_get_special_case_mapping_table, ) import numpy as np @@ -11,4 +13,14 @@ import numpy as np def get_character_flags_table_ptr(): cdef const uint8_t* tbl_ptr = cpp_get_character_flags_table() - return np.int64(tbl_ptr) + return np.uintp(tbl_ptr) + + +def get_character_cases_table_ptr(): + cdef const uint16_t* tbl_ptr = cpp_get_character_cases_table() + return np.uintp(tbl_ptr) + + +def get_special_case_mapping_table_ptr(): + cdef const void* tbl_ptr = cpp_get_special_case_mapping_table() + return np.uintp(tbl_ptr) diff --git a/python/strings_udf/strings_udf/_typing.py b/python/strings_udf/strings_udf/_typing.py index b678db88b95..3fadf030ce9 100644 --- a/python/strings_udf/strings_udf/_typing.py +++ b/python/strings_udf/strings_udf/_typing.py @@ -234,6 +234,7 @@ def resolve_count(self, mod): "isnumeric", "istitle", ] +string_unary_funcs = ["upper", "lower"] string_return_attrs = ["strip", "lstrip", "rstrip"] for func in bool_binary_funcs: @@ -263,4 +264,11 @@ def resolve_count(self, mod): create_identifier_attr(func, types.boolean), ) +for func in string_unary_funcs: + setattr( + StringViewAttrs, + f"resolve_{func}", + create_identifier_attr(func, udf_string), + ) + cuda_decl_registry.register_attr(StringViewAttrs) diff --git a/python/strings_udf/strings_udf/lowering.py b/python/strings_udf/strings_udf/lowering.py index 9e34b61e6da..cca3066a844 100644 --- a/python/strings_udf/strings_udf/lowering.py +++ b/python/strings_udf/strings_udf/lowering.py @@ -13,10 +13,16 @@ registry as cuda_lowering_registry, ) -from strings_udf._lib.tables import get_character_flags_table_ptr +from strings_udf._lib.tables import ( + get_character_cases_table_ptr, + get_character_flags_table_ptr, + get_special_case_mapping_table_ptr, +) from strings_udf._typing import size_type, string_view, udf_string character_flags_table_ptr = get_character_flags_table_ptr() +character_cases_table_ptr = get_character_cases_table_ptr() +special_case_mapping_table_ptr = get_special_case_mapping_table_ptr() _STR_VIEW_PTR = types.CPointer(string_view) _UDF_STRING_PTR = types.CPointer(udf_string) @@ -76,6 +82,19 @@ def _declare_strip_func(name): ) +def _declare_upper_or_lower(func): + return cuda.declare_device( + func, + types.void( + _UDF_STRING_PTR, + _STR_VIEW_PTR, + types.uintp, + types.uintp, + types.uintp, + ), + ) + + _string_view_isdigit = _declare_bool_str_int_func("pyisdigit") _string_view_isalnum = _declare_bool_str_int_func("pyisalnum") _string_view_isalpha = _declare_bool_str_int_func("pyisalpha") @@ -85,6 +104,8 @@ def _declare_strip_func(name): _string_view_isupper = _declare_bool_str_int_func("pyisupper") _string_view_islower = _declare_bool_str_int_func("pyislower") _string_view_istitle = _declare_bool_str_int_func("pyistitle") +_string_view_upper = _declare_upper_or_lower("upper") +_string_view_lower = _declare_upper_or_lower("lower") _string_view_count = cuda.declare_device( @@ -335,12 +356,12 @@ def id_func_impl(context, builder, sig, args): # must be resolved at runtime after context initialization, # therefore cannot be a global variable tbl_ptr = context.get_constant( - types.int64, character_flags_table_ptr + types.uintp, character_flags_table_ptr ) result = context.compile_internal( builder, cuda_func, - nb_signature(types.boolean, _STR_VIEW_PTR, types.int64), + nb_signature(types.boolean, _STR_VIEW_PTR, types.uintp), (str_ptr, tbl_ptr), ) @@ -351,6 +372,74 @@ def id_func_impl(context, builder, sig, args): return deco +def create_upper_or_lower(id_func): + """ + Provide a wrapper around numba's low-level extension API which + produces the boilerplate needed to implement either the upper + or lower attrs of a string view. + """ + + def deco(cuda_func): + @cuda_lower(id_func, string_view) + def id_func_impl(context, builder, sig, args): + str_ptr = builder.alloca(args[0].type) + builder.store(args[0], str_ptr) + + # Lookup table required for conversion functions + # must be resolved at runtime after context initialization, + # therefore cannot be a global variable + flags_tbl_ptr = context.get_constant( + types.uintp, character_flags_table_ptr + ) + cases_tbl_ptr = context.get_constant( + types.uintp, character_cases_table_ptr + ) + special_tbl_ptr = context.get_constant( + types.uintp, special_case_mapping_table_ptr + ) + udf_str_ptr = builder.alloca( + default_manager[udf_string].get_value_type() + ) + + _ = context.compile_internal( + builder, + cuda_func, + types.void( + _UDF_STRING_PTR, + _STR_VIEW_PTR, + types.uintp, + types.uintp, + types.uintp, + ), + ( + udf_str_ptr, + str_ptr, + flags_tbl_ptr, + cases_tbl_ptr, + special_tbl_ptr, + ), + ) + + result = cgutils.create_struct_proxy(udf_string)( + context, builder, value=builder.load(udf_str_ptr) + ) + return result._getvalue() + + return id_func_impl + + return deco + + +@create_upper_or_lower("StringView.upper") +def upper_impl(result, st, flags, cases, special): + return _string_view_upper(result, st, flags, cases, special) + + +@create_upper_or_lower("StringView.lower") +def lower_impl(result, st, flags, cases, special): + return _string_view_lower(result, st, flags, cases, special) + + @create_unary_identifier_func("StringView.isdigit") def isdigit_impl(st, tbl): return _string_view_isdigit(st, tbl) diff --git a/python/strings_udf/strings_udf/tests/test_string_udfs.py b/python/strings_udf/strings_udf/tests/test_string_udfs.py index 49663ee02ec..02c3a8b8c12 100644 --- a/python/strings_udf/strings_udf/tests/test_string_udfs.py +++ b/python/strings_udf/strings_udf/tests/test_string_udfs.py @@ -304,6 +304,20 @@ def func(st): run_udf_test(data, func, "str") +def test_string_udf_upper(data): + def func(st): + return st.upper() + + run_udf_test(data, func, "str") + + +def test_string_udf_lower(data): + def func(st): + return st.lower() + + run_udf_test(data, func, "str") + + @pytest.mark.parametrize("concat_char", ["1", "a", "12", " ", "", ".", "@"]) def test_string_udf_concat(data, concat_char): def func(st):