Skip to content

Commit

Permalink
Allow casting from UDFString back to StringView to call methods i…
Browse files Browse the repository at this point in the history
…n `strings_udf` (#12363)

This PR adds some code to cast a `UDFString` to a `StringView` which unblocks UDFs that end up calling further transformations on strings that have already been returned by other functions. It works by registering a set of attributes to `UDFString` instances that mirror the ones attached to `StringView`, and introducing lowering that allows a cast. The cast ultimately calls a shim function which wraps the `cudf::string_view` casting operator of `udf_string`.

Authors:
  - https://github.com/brandon-b-miller
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Lawrence Mitchell (https://github.com/wence-)

URL: #12363
  • Loading branch information
brandon-b-miller authored Mar 9, 2023
1 parent 3048791 commit 52c675a
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 37 deletions.
16 changes: 13 additions & 3 deletions python/cudf/cudf/core/udf/masked_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,8 @@ def generic(self, args, kws):
# Strings functions and utilities
def _is_valid_string_arg(ty):
return (
isinstance(ty, MaskedType) and isinstance(ty.value_type, StringView)
isinstance(ty, MaskedType)
and isinstance(ty.value_type, (StringView, UDFString))
) or isinstance(ty, types.StringLiteral)


Expand All @@ -465,9 +466,9 @@ class MaskedStringFunction(AbstractTemplate):
@register_masked_string_function(len)
def len_typing(self, args, kws):
if isinstance(args[0], MaskedType) and isinstance(
args[0].value_type, StringView
args[0].value_type, (StringView, UDFString)
):
return nb_signature(MaskedType(size_type), args[0])
return nb_signature(MaskedType(size_type), MaskedType(string_view))
elif isinstance(args[0], types.StringLiteral) and len(args) == 1:
return nb_signature(size_type, args[0])

Expand Down Expand Up @@ -635,4 +636,13 @@ def resolve_valid(self, mod):
create_masked_unary_attr(f"MaskedType.{func}", udf_string),
)


class MaskedUDFStringAttrs(MaskedStringViewAttrs):
key = MaskedType(udf_string)

def resolve_value(self, mod):
return udf_string


cuda_decl_registry.register_attr(MaskedStringViewAttrs)
cuda_decl_registry.register_attr(MaskedUDFStringAttrs)
90 changes: 68 additions & 22 deletions python/cudf/cudf/core/udf/strings_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,17 +156,45 @@ def cast_string_view_to_udf_string(context, builder, fromty, toty, val):
return result._getvalue()


@cuda_lowering_registry.lower_cast(udf_string, string_view)
def cast_udf_string_to_string_view(context, builder, fromty, toty, val):
udf_str_ptr = builder.alloca(default_manager[fromty].get_value_type())
sv_ptr = builder.alloca(default_manager[toty].get_value_type())
builder.store(val, udf_str_ptr)

context.compile_internal(
builder,
call_create_string_view_from_udf_string,
nb_signature(types.void, _UDF_STRING_PTR, _STR_VIEW_PTR),
(udf_str_ptr, sv_ptr),
)

result = cgutils.create_struct_proxy(string_view)(
context, builder, value=builder.load(sv_ptr)
)

return result._getvalue()


# utilities
_create_udf_string_from_string_view = cuda.declare_device(
"udf_string_from_string_view",
types.void(types.CPointer(string_view), types.CPointer(udf_string)),
types.void(_STR_VIEW_PTR, _UDF_STRING_PTR),
)
_create_string_view_from_udf_string = cuda.declare_device(
"string_view_from_udf_string",
types.void(_UDF_STRING_PTR, _STR_VIEW_PTR),
)


def call_create_udf_string_from_string_view(sv, udf_str):
_create_udf_string_from_string_view(sv, udf_str)


def call_create_string_view_from_udf_string(udf_str, sv):
_create_string_view_from_udf_string(udf_str, sv)


# String function implementations
def call_len_string_view(st):
return _string_view_len(st)
Expand Down Expand Up @@ -216,6 +244,7 @@ def call_string_view_replace(result, src, to_replace, replacement):


@cuda_lower("StringView.replace", string_view, string_view, string_view)
@cuda_lower("UDFString.replace", string_view, string_view, string_view)
def replace_impl(context, builder, sig, args):
src_ptr = builder.alloca(args[0].type)
to_replace_ptr = builder.alloca(args[1].type)
Expand Down Expand Up @@ -292,6 +321,20 @@ def binary_func_impl(context, builder, sig, args):
)
return result._getvalue()

# binary_func can be attribute-like: str.binary_func
# or operator-like: binary_func(str, other)
if isinstance(binary_func, str):
binary_func_impl = cuda_lower(
f"StringView.{binary_func}", string_view, string_view
)(binary_func_impl)
binary_func_impl = cuda_lower(
f"UDFString.{binary_func}", string_view, string_view
)(binary_func_impl)
else:
binary_func_impl = cuda_lower(
binary_func, string_view, string_view
)(binary_func_impl)

return binary_func_impl

return deco
Expand Down Expand Up @@ -332,42 +375,42 @@ def lt_impl(st, rhs):
return _string_view_lt(st, rhs)


@create_binary_string_func("StringView.strip", udf_string)
@create_binary_string_func("strip", udf_string)
def strip_impl(result, to_strip, strip_char):
return _string_view_strip(result, to_strip, strip_char)


@create_binary_string_func("StringView.lstrip", udf_string)
@create_binary_string_func("lstrip", udf_string)
def lstrip_impl(result, to_strip, strip_char):
return _string_view_lstrip(result, to_strip, strip_char)


@create_binary_string_func("StringView.rstrip", udf_string)
@create_binary_string_func("rstrip", udf_string)
def rstrip_impl(result, to_strip, strip_char):
return _string_view_rstrip(result, to_strip, strip_char)


@create_binary_string_func("StringView.startswith", types.boolean)
@create_binary_string_func("startswith", types.boolean)
def startswith_impl(sv, substr):
return _string_view_startswith(sv, substr)


@create_binary_string_func("StringView.endswith", types.boolean)
@create_binary_string_func("endswith", types.boolean)
def endswith_impl(sv, substr):
return _string_view_endswith(sv, substr)


@create_binary_string_func("StringView.count", size_type)
@create_binary_string_func("count", size_type)
def count_impl(st, substr):
return _string_view_count(st, substr)


@create_binary_string_func("StringView.find", size_type)
@create_binary_string_func("find", size_type)
def find_impl(sv, substr):
return _string_view_find(sv, substr)


@create_binary_string_func("StringView.rfind", size_type)
@create_binary_string_func("rfind", size_type)
def rfind_impl(sv, substr):
return _string_view_rfind(sv, substr)

Expand All @@ -380,7 +423,8 @@ def create_unary_identifier_func(id_func):
"""

def deco(cuda_func):
@cuda_lower(id_func, string_view)
@cuda_lower(f"StringView.{id_func}", string_view)
@cuda_lower(f"UDFString.{id_func}", string_view)
def id_func_impl(context, builder, sig, args):
str_ptr = builder.alloca(args[0].type)
builder.store(args[0], str_ptr)
Expand Down Expand Up @@ -413,7 +457,8 @@ def create_upper_or_lower(id_func):
"""

def deco(cuda_func):
@cuda_lower(id_func, string_view)
@cuda_lower(f"StringView.{id_func}", string_view)
@cuda_lower(f"UDFString.{id_func}", string_view)
def id_func_impl(context, builder, sig, args):
str_ptr = builder.alloca(args[0].type)
builder.store(args[0], str_ptr)
Expand Down Expand Up @@ -463,62 +508,63 @@ def id_func_impl(context, builder, sig, args):
return deco


@create_upper_or_lower("StringView.upper")
@create_upper_or_lower("upper")
def upper_impl(result, st, flags, cases, special):
return _string_view_upper(result, st, flags, cases, special)


@create_upper_or_lower("StringView.lower")
@create_upper_or_lower("lower")
def lower_impl(result, st, flags, cases, special):
return _string_view_lower(result, st, flags, cases, special)


@create_unary_identifier_func("StringView.isdigit")
@create_unary_identifier_func("isdigit")
def isdigit_impl(st, tbl):
return _string_view_isdigit(st, tbl)


@create_unary_identifier_func("StringView.isalnum")
@create_unary_identifier_func("isalnum")
def isalnum_impl(st, tbl):
return _string_view_isalnum(st, tbl)


@create_unary_identifier_func("StringView.isalpha")
@create_unary_identifier_func("isalpha")
def isalpha_impl(st, tbl):
return _string_view_isalpha(st, tbl)


@create_unary_identifier_func("StringView.isnumeric")
@create_unary_identifier_func("isnumeric")
def isnumeric_impl(st, tbl):
return _string_view_isnumeric(st, tbl)


@create_unary_identifier_func("StringView.isdecimal")
@create_unary_identifier_func("isdecimal")
def isdecimal_impl(st, tbl):
return _string_view_isdecimal(st, tbl)


@create_unary_identifier_func("StringView.isspace")
@create_unary_identifier_func("isspace")
def isspace_impl(st, tbl):
return _string_view_isspace(st, tbl)


@create_unary_identifier_func("StringView.isupper")
@create_unary_identifier_func("isupper")
def isupper_impl(st, tbl):
return _string_view_isupper(st, tbl)


@create_unary_identifier_func("StringView.islower")
@create_unary_identifier_func("islower")
def islower_impl(st, tbl):
return _string_view_islower(st, tbl)


@create_unary_identifier_func("StringView.istitle")
@create_unary_identifier_func("istitle")
def istitle_impl(st, tbl):
return _string_view_istitle(st, tbl)


@cuda_lower(len, MaskedType(string_view))
@cuda_lower(len, MaskedType(udf_string))
def masked_len_impl(context, builder, sig, args):
ret = cgutils.create_struct_proxy(sig.return_type)(context, builder)
masked_sv_ty = sig.args[0]
Expand Down
9 changes: 8 additions & 1 deletion python/cudf/cudf/core/udf/strings_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def generic(self, args, kws):
# string_view -> int32
# udf_string -> int32
# literal -> int32
return nb_signature(size_type, args[0])
return nb_signature(size_type, string_view)


def register_stringview_binaryop(op, retty):
Expand Down Expand Up @@ -257,7 +257,14 @@ def resolve_replace(self, mod):
create_identifier_attr(func, udf_string),
)


@cuda_decl_registry.register_attr
class UDFStringAttrs(StringViewAttrs):
key = udf_string


cuda_decl_registry.register_attr(StringViewAttrs)
cuda_decl_registry.register_attr(UDFStringAttrs)

register_stringview_binaryop(operator.eq, types.boolean)
register_stringview_binaryop(operator.ne, types.boolean)
Expand Down
39 changes: 39 additions & 0 deletions python/cudf/cudf/testing/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,17 @@
import numpy as np
import pandas as pd
import pytest
from numba.core.typing import signature as nb_signature
from numba.core.typing.templates import AbstractTemplate
from numba.cuda.cudadecl import registry as cuda_decl_registry
from numba.cuda.cudaimpl import lower as cuda_lower
from pandas import testing as tm

import cudf
from cudf._lib.null_mask import bitmask_allocation_size_bytes
from cudf.core.column.timedelta import _unit_to_nanoseconds_conversion
from cudf.core.udf.strings_lowering import cast_string_view_to_udf_string
from cudf.core.udf.strings_typing import StringView, string_view, udf_string
from cudf.utils import dtypes as dtypeutils

supported_numpy_dtypes = [
Expand Down Expand Up @@ -387,3 +393,36 @@ def expect_warning_if(condition, warning=FutureWarning, *args, **kwargs):
yield
else:
yield


def sv_to_udf_str(sv):
"""
Cast a string_view object to a udf_string object
This placeholder function never runs in python
It exists only for numba to have something to replace
with the typing and lowering code below
This is similar conceptually to needing a translation
engine to emit an expression in target language "B" when
there is no equivalent in the source language "A" to
translate from. This function effectively defines the
expression in language "A" and the associated typing
and lowering describe the translation process, despite
the expression having no meaning in language "A"
"""
pass


@cuda_decl_registry.register_global(sv_to_udf_str)
class StringViewToUDFStringDecl(AbstractTemplate):
def generic(self, args, kws):
if isinstance(args[0], StringView) and len(args) == 1:
return nb_signature(udf_string, string_view)


@cuda_lower(sv_to_udf_str, string_view)
def sv_to_udf_str_testing_lowering(context, builder, sig, args):
return cast_string_view_to_udf_string(
context, builder, sig.args[0], sig.return_type, args[0]
)
Loading

0 comments on commit 52c675a

Please sign in to comment.