Skip to content

Commit

Permalink
Support upper and lower in strings_udf (#12099)
Browse files Browse the repository at this point in the history
This PR adds support for the following two functions in `strings_udf`:

- `str.upper()`
- `str.lower()`

Part of #9639

Authors:
  - https://github.com/brandon-b-miller
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Lawrence Mitchell (https://github.com/wence-)
  - David Wendt (https://github.com/davidwendt)

URL: #12099
  • Loading branch information
brandon-b-miller authored Nov 17, 2022
1 parent 6de2c4e commit aa13b95
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 38 deletions.
58 changes: 39 additions & 19 deletions python/cudf/cudf/core/udf/strings_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
istitle_impl,
isupper_impl,
len_impl,
lower_impl,
lstrip_impl,
rfind_impl,
rstrip_impl,
startswith_impl,
strip_impl,
upper_impl,
)

from cudf.core.udf.masked_typing import MaskedType
Expand Down Expand Up @@ -82,25 +84,6 @@ def masked_binary_func_impl(context, builder, sig, args):
)


create_binary_string_func("MaskedType.strip", strip_impl, udf_string)

create_binary_string_func("MaskedType.lstrip", lstrip_impl, udf_string)

create_binary_string_func("MaskedType.rstrip", rstrip_impl, udf_string)


create_binary_string_func(
"MaskedType.startswith",
startswith_impl,
types.boolean,
)
create_binary_string_func("MaskedType.endswith", endswith_impl, types.boolean)
create_binary_string_func("MaskedType.find", find_impl, size_type)
create_binary_string_func("MaskedType.rfind", rfind_impl, size_type)
create_binary_string_func("MaskedType.count", count_impl, size_type)
create_binary_string_func(operator.contains, contains_impl, types.boolean)


def create_masked_unary_identifier_func(op, cuda_func):
"""
Provide a wrapper around numba's low-level extension API which
Expand All @@ -127,6 +110,41 @@ def masked_unary_func_impl(context, builder, sig, args):
cuda_lower(op, MaskedType(string_view))(masked_unary_func_impl)


def create_masked_upper_or_lower(op, cuda_func):
    """
    Register a lowering, keyed on `op`, for a method of a masked
    string_view whose result is a masked udf_string.

    `cuda_func` is invoked on the unmasked string value with the
    signature `udf_string(string_view)`; the validity of the input
    is carried over to the result unchanged.
    """

    def upper_or_lower_impl(context, builder, sig, args):
        # Allocate the masked return struct, then unpack the masked input.
        masked_out = cgutils.create_struct_proxy(sig.return_type)(
            context, builder
        )
        masked_in = cgutils.create_struct_proxy(sig.args[0])(
            context, builder, value=args[0]
        )

        # Delegate to the implementation that works on the bare value.
        unmasked_sig = udf_string(string_view)
        masked_out.value = cuda_func(
            context, builder, unmasked_sig, (masked_in.value,)
        )

        # The output is valid exactly when the input is.
        masked_out.valid = masked_in.valid
        return masked_out._getvalue()

    cuda_lower(op, MaskedType(string_view))(upper_or_lower_impl)


# Register each operation through create_binary_string_func, pairing its key
# with its implementation and a third argument (udf_string, types.boolean or
# size_type -- presumably the return type; confirm against the helper).
create_binary_string_func("MaskedType.strip", strip_impl, udf_string)
create_binary_string_func("MaskedType.lstrip", lstrip_impl, udf_string)
create_binary_string_func("MaskedType.rstrip", rstrip_impl, udf_string)
create_binary_string_func(
    "MaskedType.startswith",
    startswith_impl,
    types.boolean,
)
create_binary_string_func("MaskedType.endswith", endswith_impl, types.boolean)
create_binary_string_func("MaskedType.find", find_impl, size_type)
create_binary_string_func("MaskedType.rfind", rfind_impl, size_type)
create_binary_string_func("MaskedType.count", count_impl, size_type)
# `in` is keyed on operator.contains rather than a "MaskedType.<name>" string.
create_binary_string_func(operator.contains, contains_impl, types.boolean)


create_masked_unary_identifier_func("MaskedType.isalnum", isalnum_impl)
create_masked_unary_identifier_func("MaskedType.isalpha", isalpha_impl)
create_masked_unary_identifier_func("MaskedType.isdigit", isdigit_impl)
Expand All @@ -135,3 +153,5 @@ def masked_unary_func_impl(context, builder, sig, args):
create_masked_unary_identifier_func("MaskedType.isspace", isspace_impl)
create_masked_unary_identifier_func("MaskedType.isdecimal", isdecimal_impl)
create_masked_unary_identifier_func("MaskedType.istitle", istitle_impl)
# upper/lower go through create_masked_upper_or_lower, which passes
# udf_string(string_view) as the signature to the implementation.
create_masked_upper_or_lower("MaskedType.upper", upper_impl)
create_masked_upper_or_lower("MaskedType.lower", lower_impl)
14 changes: 11 additions & 3 deletions python/cudf/cudf/core/udf/strings_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
int_binary_funcs,
size_type,
string_return_attrs,
string_unary_funcs,
string_view,
udf_string,
)
Expand Down Expand Up @@ -123,7 +124,7 @@ def attr(self, mod):
return attr


def create_masked_identifier_attr(attrname):
def create_masked_unary_attr(attrname, retty):
"""
Helper function wrapping numba's low level extension API. Provides
the boilerplate needed to register a unary function of a masked
Expand All @@ -134,7 +135,7 @@ class MaskedStringViewIdentifierAttr(AbstractTemplate):
key = attrname

def generic(self, args, kws):
return nb_signature(MaskedType(types.boolean), recvr=self.this)
return nb_signature(MaskedType(retty), recvr=self.this)

def attr(self, mod):
return types.BoundFunction(
Expand Down Expand Up @@ -195,7 +196,14 @@ def resolve_valid(self, mod):
setattr(
MaskedStringViewAttrs,
f"resolve_{func}",
create_masked_identifier_attr(f"MaskedType.{func}"),
create_masked_unary_attr(f"MaskedType.{func}", types.boolean),
)

# For every name in string_unary_funcs, attach a resolve_<name> attribute to
# MaskedStringViewAttrs built with udf_string as the `retty` argument.
for func in string_unary_funcs:
    setattr(
        MaskedStringViewAttrs,
        f"resolve_{func}",
        create_masked_unary_attr(f"MaskedType.{func}", udf_string),
    )

cuda_decl_registry.register_attr(MaskedStringViewAttrs)
16 changes: 16 additions & 0 deletions python/cudf/cudf/tests/test_udf_masked_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,22 @@ def func(row):
run_masked_udf_test(func, str_udf_data, check_dtype=False)


@string_udf_test
def test_string_udf_upper(str_udf_data):
    """Run a UDF calling ``upper()`` on ``str_col`` via run_masked_udf_test."""

    def func(row):
        return row["str_col"].upper()

    # dtype checking is disabled for this comparison
    run_masked_udf_test(func, str_udf_data, check_dtype=False)


@string_udf_test
def test_string_udf_lower(str_udf_data):
    """Run a UDF calling ``lower()`` on ``str_col`` via run_masked_udf_test."""

    def func(row):
        return row["str_col"].lower()

    # dtype checking is disabled for this comparison
    run_masked_udf_test(func, str_udf_data, check_dtype=False)


@string_udf_test
@pytest.mark.parametrize("concat_char", ["1", "a", "12", " ", "", ".", "@"])
def test_string_udf_concat(str_udf_data, concat_char):
Expand Down
64 changes: 55 additions & 9 deletions python/strings_udf/cpp/src/strings/udf/shim.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <cudf/strings/udf/case.cuh>
#include <cudf/strings/udf/char_types.cuh>
#include <cudf/strings/udf/search.cuh>
#include <cudf/strings/udf/starts_with.cuh>
Expand Down Expand Up @@ -128,7 +129,7 @@ extern "C" __device__ int lt(bool* nb_retval, void const* str, void const* rhs)
return 0;
}

extern "C" __device__ int pyislower(bool* nb_retval, void const* str, std::int64_t chars_table)
extern "C" __device__ int pyislower(bool* nb_retval, void const* str, std::uintptr_t chars_table)
{
auto str_view = reinterpret_cast<cudf::string_view const*>(str);

Expand All @@ -137,7 +138,7 @@ extern "C" __device__ int pyislower(bool* nb_retval, void const* str, std::int64
return 0;
}

extern "C" __device__ int pyisupper(bool* nb_retval, void const* str, std::int64_t chars_table)
extern "C" __device__ int pyisupper(bool* nb_retval, void const* str, std::uintptr_t chars_table)
{
auto str_view = reinterpret_cast<cudf::string_view const*>(str);

Expand All @@ -146,7 +147,7 @@ extern "C" __device__ int pyisupper(bool* nb_retval, void const* str, std::int64
return 0;
}

extern "C" __device__ int pyisspace(bool* nb_retval, void const* str, std::int64_t chars_table)
extern "C" __device__ int pyisspace(bool* nb_retval, void const* str, std::uintptr_t chars_table)
{
auto str_view = reinterpret_cast<cudf::string_view const*>(str);

Expand All @@ -155,7 +156,7 @@ extern "C" __device__ int pyisspace(bool* nb_retval, void const* str, std::int64
return 0;
}

extern "C" __device__ int pyisdecimal(bool* nb_retval, void const* str, std::int64_t chars_table)
extern "C" __device__ int pyisdecimal(bool* nb_retval, void const* str, std::uintptr_t chars_table)
{
auto str_view = reinterpret_cast<cudf::string_view const*>(str);

Expand All @@ -164,7 +165,7 @@ extern "C" __device__ int pyisdecimal(bool* nb_retval, void const* str, std::int
return 0;
}

extern "C" __device__ int pyisnumeric(bool* nb_retval, void const* str, std::int64_t chars_table)
extern "C" __device__ int pyisnumeric(bool* nb_retval, void const* str, std::uintptr_t chars_table)
{
auto str_view = reinterpret_cast<cudf::string_view const*>(str);

Expand All @@ -173,7 +174,7 @@ extern "C" __device__ int pyisnumeric(bool* nb_retval, void const* str, std::int
return 0;
}

extern "C" __device__ int pyisdigit(bool* nb_retval, void const* str, std::int64_t chars_table)
extern "C" __device__ int pyisdigit(bool* nb_retval, void const* str, std::uintptr_t chars_table)
{
auto str_view = reinterpret_cast<cudf::string_view const*>(str);

Expand All @@ -182,7 +183,7 @@ extern "C" __device__ int pyisdigit(bool* nb_retval, void const* str, std::int64
return 0;
}

extern "C" __device__ int pyisalnum(bool* nb_retval, void const* str, std::int64_t chars_table)
extern "C" __device__ int pyisalnum(bool* nb_retval, void const* str, std::uintptr_t chars_table)
{
auto str_view = reinterpret_cast<cudf::string_view const*>(str);

Expand All @@ -191,7 +192,7 @@ extern "C" __device__ int pyisalnum(bool* nb_retval, void const* str, std::int64
return 0;
}

extern "C" __device__ int pyisalpha(bool* nb_retval, void const* str, std::int64_t chars_table)
extern "C" __device__ int pyisalpha(bool* nb_retval, void const* str, std::uintptr_t chars_table)
{
auto str_view = reinterpret_cast<cudf::string_view const*>(str);

Expand All @@ -200,7 +201,7 @@ extern "C" __device__ int pyisalpha(bool* nb_retval, void const* str, std::int64
return 0;
}

extern "C" __device__ int pyistitle(bool* nb_retval, void const* str, std::int64_t chars_table)
extern "C" __device__ int pyistitle(bool* nb_retval, void const* str, std::uintptr_t chars_table)
{
auto str_view = reinterpret_cast<cudf::string_view const*>(str);

Expand Down Expand Up @@ -270,6 +271,51 @@ extern "C" __device__ int rstrip(int* nb_retval,

return 0;
}
/**
 * Shim for `upper`: constructs a udf_string in caller-provided storage and
 * assigns it the result of `to_upper` applied to the input string_view.
 *
 * @param nb_retval     Not written by this function
 * @param udf_str       Storage in which the output udf_string is constructed
 * @param st            Pointer to the input cudf::string_view
 * @param flags_table   Address of the character flags table
 * @param cases_table   Address of the character cases table
 * @param special_table Address of the special case mapping table
 * @return Always 0
 */
extern "C" __device__ int upper(int* nb_retval,
                                void* udf_str,
                                void const* st,
                                std::uintptr_t flags_table,
                                std::uintptr_t cases_table,
                                std::uintptr_t special_table)
{
  // Placement-new the output object before anything is assigned to it.
  auto output = new (udf_str) udf_string;
  auto input  = reinterpret_cast<cudf::string_view const*>(st);

  // The tables arrive as integer addresses; recover the typed pointers.
  cudf::strings::udf::chars_tables tables{
    reinterpret_cast<cudf::strings::detail::character_flags_table_type*>(flags_table),
    reinterpret_cast<cudf::strings::detail::character_cases_table_type*>(cases_table),
    reinterpret_cast<cudf::strings::detail::special_case_mapping*>(special_table)};

  *output = to_upper(tables, *input);
  return 0;
}

/**
 * Shim for `lower`: constructs a udf_string in caller-provided storage and
 * assigns it the result of `to_lower` applied to the input string_view.
 *
 * @param nb_retval     Not written by this function
 * @param udf_str       Storage in which the output udf_string is constructed
 * @param st            Pointer to the input cudf::string_view
 * @param flags_table   Address of the character flags table
 * @param cases_table   Address of the character cases table
 * @param special_table Address of the special case mapping table
 * @return Always 0
 */
extern "C" __device__ int lower(int* nb_retval,
                                void* udf_str,
                                void const* st,
                                std::uintptr_t flags_table,
                                std::uintptr_t cases_table,
                                std::uintptr_t special_table)
{
  // Placement-new the output object before anything is assigned to it.
  auto output = new (udf_str) udf_string;
  auto input  = reinterpret_cast<cudf::string_view const*>(st);

  // The tables arrive as integer addresses; recover the typed pointers.
  cudf::strings::udf::chars_tables tables{
    reinterpret_cast<cudf::strings::detail::character_flags_table_type*>(flags_table),
    reinterpret_cast<cudf::strings::detail::character_cases_table_type*>(cases_table),
    reinterpret_cast<cudf::strings::detail::special_case_mapping*>(special_table)};

  *output = to_lower(tables, *input);
  return 0;
}

extern "C" __device__ int concat(int* nb_retval, void* udf_str, void* const* lhs, void* const* rhs)
{
Expand Down
3 changes: 2 additions & 1 deletion python/strings_udf/cpp/src/strings/udf/udf_apis.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ namespace {
struct udf_string_to_string_view_transform_fn {
__device__ cudf::string_view operator()(cudf::strings::udf::udf_string const& dstr)
{
return cudf::string_view{dstr.data(), dstr.size_bytes()};
return dstr.data() == nullptr ? cudf::string_view{}
: cudf::string_view{dstr.data(), dstr.size_bytes()};
}
};

Expand Down
4 changes: 3 additions & 1 deletion python/strings_udf/strings_udf/_lib/cpp/strings_udf.pxd
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2022, NVIDIA CORPORATION.

from libc.stdint cimport uint8_t
from libc.stdint cimport uint8_t, uint16_t
from libcpp.memory cimport unique_ptr
from libcpp.string cimport string
from libcpp.vector cimport vector
Expand Down Expand Up @@ -28,3 +28,5 @@ cdef extern from "cudf/strings/udf/udf_apis.hpp" namespace \
cdef extern from "cudf/strings/detail/char_tables.hpp" namespace \
"cudf::strings::detail" nogil:
cdef const uint8_t* get_character_flags_table() except +
cdef const uint16_t* get_character_cases_table() except +
cdef const void* get_special_case_mapping_table() except +
16 changes: 14 additions & 2 deletions python/strings_udf/strings_udf/_lib/tables.pyx
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
# Copyright (c) 2022, NVIDIA CORPORATION.

from libc.stdint cimport uint8_t, uintptr_t
from libc.stdint cimport uint8_t, uint16_t, uintptr_t

from strings_udf._lib.cpp.strings_udf cimport (
get_character_cases_table as cpp_get_character_cases_table,
get_character_flags_table as cpp_get_character_flags_table,
get_special_case_mapping_table as cpp_get_special_case_mapping_table,
)

import numpy as np


def get_character_flags_table_ptr():
cdef const uint8_t* tbl_ptr = cpp_get_character_flags_table()
return np.int64(<uintptr_t>tbl_ptr)
return np.uintp(<uintptr_t>tbl_ptr)


def get_character_cases_table_ptr():
    """Return the address of the character cases table as a ``np.uintp``."""
    cdef const uint16_t* cases_tbl = cpp_get_character_cases_table()
    cdef uintptr_t address = <uintptr_t>cases_tbl
    return np.uintp(address)


def get_special_case_mapping_table_ptr():
    """Return the address of the special case mapping table as a ``np.uintp``."""
    cdef const void* mapping_tbl = cpp_get_special_case_mapping_table()
    cdef uintptr_t address = <uintptr_t>mapping_tbl
    return np.uintp(address)
8 changes: 8 additions & 0 deletions python/strings_udf/strings_udf/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def resolve_count(self, mod):
"isnumeric",
"istitle",
]
string_unary_funcs = ["upper", "lower"]
string_return_attrs = ["strip", "lstrip", "rstrip"]

for func in bool_binary_funcs:
Expand Down Expand Up @@ -263,4 +264,11 @@ def resolve_count(self, mod):
create_identifier_attr(func, types.boolean),
)

# For every name in string_unary_funcs, attach a resolve_<name> attribute to
# StringViewAttrs, passing udf_string to create_identifier_attr (presumably
# the return type -- confirm against create_identifier_attr).
for func in string_unary_funcs:
    setattr(
        StringViewAttrs,
        f"resolve_{func}",
        create_identifier_attr(func, udf_string),
    )

cuda_decl_registry.register_attr(StringViewAttrs)
Loading

0 comments on commit aa13b95

Please sign in to comment.