Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow casting from UDFString back to StringView to call methods in strings_udf #12363

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
477155d
initially working up through count, work to do
brandon-b-miller Dec 12, 2022
08e190e
replace standalone function works in strings_udf
brandon-b-miller Dec 13, 2022
c1bd41c
strings_udf tests pass
brandon-b-miller Dec 15, 2022
3f043d6
all tests pass
brandon-b-miller Dec 16, 2022
c9fa690
merge latest and resolve conflicts
brandon-b-miller Jan 5, 2023
70109dc
Merge branch 'branch-23.02' into fix-stringudf-chained-ops
brandon-b-miller Jan 5, 2023
8c8eb85
update copyright years
brandon-b-miller Jan 5, 2023
6049d53
Merge branch 'branch-23.02' into fix-stringudf-chained-ops
brandon-b-miller Jan 9, 2023
eb44719
make things much easier
brandon-b-miller Jan 9, 2023
59b42ea
refactor
brandon-b-miller Jan 9, 2023
599033f
continue reverting changes
brandon-b-miller Jan 9, 2023
74d6c61
Merge branch 'branch-23.02' into fix-stringudf-chained-ops
brandon-b-miller Jan 17, 2023
d0f4e3b
Merge branch 'branch-23.02' into fix-stringudf-chained-ops
galipremsagar Jan 25, 2023
6b3b085
Merge branch 'branch-23.02' into fix-stringudf-chained-ops
brandon-b-miller Jan 30, 2023
b163c1b
cast in len
brandon-b-miller Jan 30, 2023
21a2121
small typing bug
brandon-b-miller Jan 30, 2023
8a55a26
remove duplicate code
brandon-b-miller Jan 30, 2023
0d01928
Merge branch 'branch-23.04' into fix-stringudf-chained-ops
brandon-b-miller Feb 2, 2023
4043691
merge latest, resolve conflicts, pass tests, refactor
brandon-b-miller Feb 22, 2023
26a6f88
remove old file
brandon-b-miller Feb 22, 2023
888f4dd
Merge branch 'branch-23.04' into fix-stringudf-chained-ops
brandon-b-miller Feb 28, 2023
254a583
address reviews
brandon-b-miller Mar 1, 2023
0840cdc
add docs to sv_to_udf_str
brandon-b-miller Mar 1, 2023
c9ff968
Merge branch 'branch-23.04' into fix-stringudf-chained-ops
vyasr Mar 6, 2023
e6ae995
Merge branch 'branch-23.04' into fix-stringudf-chained-ops
vyasr Mar 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions python/cudf/cudf/core/udf/masked_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,8 @@ def generic(self, args, kws):
# Strings functions and utilities
def _is_valid_string_arg(ty):
return (
isinstance(ty, MaskedType) and isinstance(ty.value_type, StringView)
isinstance(ty, MaskedType)
and isinstance(ty.value_type, (StringView, UDFString))
) or isinstance(ty, types.StringLiteral)


Expand All @@ -465,9 +466,9 @@ class MaskedStringFunction(AbstractTemplate):
@register_masked_string_function(len)
def len_typing(self, args, kws):
if isinstance(args[0], MaskedType) and isinstance(
args[0].value_type, StringView
args[0].value_type, (StringView, UDFString)
):
return nb_signature(MaskedType(size_type), args[0])
return nb_signature(MaskedType(size_type), MaskedType(string_view))
elif isinstance(args[0], types.StringLiteral) and len(args) == 1:
return nb_signature(size_type, args[0])

Expand Down Expand Up @@ -635,4 +636,13 @@ def resolve_valid(self, mod):
create_masked_unary_attr(f"MaskedType.{func}", udf_string),
)


class MaskedUDFStringAttrs(MaskedStringViewAttrs):
key = MaskedType(udf_string)

def resolve_value(self, mod):
return udf_string


cuda_decl_registry.register_attr(MaskedStringViewAttrs)
cuda_decl_registry.register_attr(MaskedUDFStringAttrs)
90 changes: 68 additions & 22 deletions python/cudf/cudf/core/udf/strings_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,17 +156,45 @@ def cast_string_view_to_udf_string(context, builder, fromty, toty, val):
return result._getvalue()


@cuda_lowering_registry.lower_cast(udf_string, string_view)
def cast_udf_string_to_string_view(context, builder, fromty, toty, val):
udf_str_ptr = builder.alloca(default_manager[fromty].get_value_type())
sv_ptr = builder.alloca(default_manager[toty].get_value_type())
builder.store(val, udf_str_ptr)

context.compile_internal(
builder,
call_create_string_view_from_udf_string,
nb_signature(types.void, _UDF_STRING_PTR, _STR_VIEW_PTR),
(udf_str_ptr, sv_ptr),
)

result = cgutils.create_struct_proxy(string_view)(
context, builder, value=builder.load(sv_ptr)
)

return result._getvalue()


# utilities
_create_udf_string_from_string_view = cuda.declare_device(
"udf_string_from_string_view",
types.void(types.CPointer(string_view), types.CPointer(udf_string)),
types.void(_STR_VIEW_PTR, _UDF_STRING_PTR),
)
_create_string_view_from_udf_string = cuda.declare_device(
"string_view_from_udf_string",
types.void(_UDF_STRING_PTR, _STR_VIEW_PTR),
)


def call_create_udf_string_from_string_view(sv, udf_str):
_create_udf_string_from_string_view(sv, udf_str)


def call_create_string_view_from_udf_string(udf_str, sv):
_create_string_view_from_udf_string(udf_str, sv)


# String function implementations
def call_len_string_view(st):
return _string_view_len(st)
Expand Down Expand Up @@ -216,6 +244,7 @@ def call_string_view_replace(result, src, to_replace, replacement):


@cuda_lower("StringView.replace", string_view, string_view, string_view)
@cuda_lower("UDFString.replace", string_view, string_view, string_view)
def replace_impl(context, builder, sig, args):
src_ptr = builder.alloca(args[0].type)
to_replace_ptr = builder.alloca(args[1].type)
Expand Down Expand Up @@ -292,6 +321,20 @@ def binary_func_impl(context, builder, sig, args):
)
return result._getvalue()

# binary_func can be attribute-like: str.binary_func
# or operator-like: binary_func(str, other)
if isinstance(binary_func, str):
binary_func_impl = cuda_lower(
f"StringView.{binary_func}", string_view, string_view
)(binary_func_impl)
binary_func_impl = cuda_lower(
f"UDFString.{binary_func}", string_view, string_view
)(binary_func_impl)
else:
binary_func_impl = cuda_lower(
binary_func, string_view, string_view
)(binary_func_impl)

return binary_func_impl

return deco
Expand Down Expand Up @@ -332,42 +375,42 @@ def lt_impl(st, rhs):
return _string_view_lt(st, rhs)


@create_binary_string_func("StringView.strip", udf_string)
@create_binary_string_func("strip", udf_string)
def strip_impl(result, to_strip, strip_char):
return _string_view_strip(result, to_strip, strip_char)


@create_binary_string_func("StringView.lstrip", udf_string)
@create_binary_string_func("lstrip", udf_string)
def lstrip_impl(result, to_strip, strip_char):
return _string_view_lstrip(result, to_strip, strip_char)


@create_binary_string_func("StringView.rstrip", udf_string)
@create_binary_string_func("rstrip", udf_string)
def rstrip_impl(result, to_strip, strip_char):
return _string_view_rstrip(result, to_strip, strip_char)


@create_binary_string_func("StringView.startswith", types.boolean)
@create_binary_string_func("startswith", types.boolean)
def startswith_impl(sv, substr):
return _string_view_startswith(sv, substr)


@create_binary_string_func("StringView.endswith", types.boolean)
@create_binary_string_func("endswith", types.boolean)
def endswith_impl(sv, substr):
return _string_view_endswith(sv, substr)


@create_binary_string_func("StringView.count", size_type)
@create_binary_string_func("count", size_type)
def count_impl(st, substr):
return _string_view_count(st, substr)


@create_binary_string_func("StringView.find", size_type)
@create_binary_string_func("find", size_type)
def find_impl(sv, substr):
return _string_view_find(sv, substr)


@create_binary_string_func("StringView.rfind", size_type)
@create_binary_string_func("rfind", size_type)
def rfind_impl(sv, substr):
return _string_view_rfind(sv, substr)

Expand All @@ -380,7 +423,8 @@ def create_unary_identifier_func(id_func):
"""

def deco(cuda_func):
@cuda_lower(id_func, string_view)
@cuda_lower(f"StringView.{id_func}", string_view)
@cuda_lower(f"UDFString.{id_func}", string_view)
def id_func_impl(context, builder, sig, args):
str_ptr = builder.alloca(args[0].type)
builder.store(args[0], str_ptr)
Expand Down Expand Up @@ -413,7 +457,8 @@ def create_upper_or_lower(id_func):
"""

def deco(cuda_func):
@cuda_lower(id_func, string_view)
@cuda_lower(f"StringView.{id_func}", string_view)
@cuda_lower(f"UDFString.{id_func}", string_view)
def id_func_impl(context, builder, sig, args):
str_ptr = builder.alloca(args[0].type)
builder.store(args[0], str_ptr)
Expand Down Expand Up @@ -463,62 +508,63 @@ def id_func_impl(context, builder, sig, args):
return deco


@create_upper_or_lower("StringView.upper")
@create_upper_or_lower("upper")
def upper_impl(result, st, flags, cases, special):
return _string_view_upper(result, st, flags, cases, special)


@create_upper_or_lower("StringView.lower")
@create_upper_or_lower("lower")
def lower_impl(result, st, flags, cases, special):
return _string_view_lower(result, st, flags, cases, special)


@create_unary_identifier_func("StringView.isdigit")
@create_unary_identifier_func("isdigit")
def isdigit_impl(st, tbl):
return _string_view_isdigit(st, tbl)


@create_unary_identifier_func("StringView.isalnum")
@create_unary_identifier_func("isalnum")
def isalnum_impl(st, tbl):
return _string_view_isalnum(st, tbl)


@create_unary_identifier_func("StringView.isalpha")
@create_unary_identifier_func("isalpha")
def isalpha_impl(st, tbl):
return _string_view_isalpha(st, tbl)


@create_unary_identifier_func("StringView.isnumeric")
@create_unary_identifier_func("isnumeric")
def isnumeric_impl(st, tbl):
return _string_view_isnumeric(st, tbl)


@create_unary_identifier_func("StringView.isdecimal")
@create_unary_identifier_func("isdecimal")
def isdecimal_impl(st, tbl):
return _string_view_isdecimal(st, tbl)


@create_unary_identifier_func("StringView.isspace")
@create_unary_identifier_func("isspace")
def isspace_impl(st, tbl):
return _string_view_isspace(st, tbl)


@create_unary_identifier_func("StringView.isupper")
@create_unary_identifier_func("isupper")
def isupper_impl(st, tbl):
return _string_view_isupper(st, tbl)


@create_unary_identifier_func("StringView.islower")
@create_unary_identifier_func("islower")
def islower_impl(st, tbl):
return _string_view_islower(st, tbl)


@create_unary_identifier_func("StringView.istitle")
@create_unary_identifier_func("istitle")
def istitle_impl(st, tbl):
return _string_view_istitle(st, tbl)


@cuda_lower(len, MaskedType(string_view))
@cuda_lower(len, MaskedType(udf_string))
vyasr marked this conversation as resolved.
Show resolved Hide resolved
def masked_len_impl(context, builder, sig, args):
ret = cgutils.create_struct_proxy(sig.return_type)(context, builder)
masked_sv_ty = sig.args[0]
Expand Down
9 changes: 8 additions & 1 deletion python/cudf/cudf/core/udf/strings_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def generic(self, args, kws):
# string_view -> int32
# udf_string -> int32
# literal -> int32
return nb_signature(size_type, args[0])
return nb_signature(size_type, string_view)


def register_stringview_binaryop(op, retty):
Expand Down Expand Up @@ -257,7 +257,14 @@ def resolve_replace(self, mod):
create_identifier_attr(func, udf_string),
)


@cuda_decl_registry.register_attr
class UDFStringAttrs(StringViewAttrs):
key = udf_string


cuda_decl_registry.register_attr(StringViewAttrs)
cuda_decl_registry.register_attr(UDFStringAttrs)

register_stringview_binaryop(operator.eq, types.boolean)
register_stringview_binaryop(operator.ne, types.boolean)
Expand Down
39 changes: 39 additions & 0 deletions python/cudf/cudf/testing/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,17 @@
import numpy as np
import pandas as pd
import pytest
from numba.core.typing import signature as nb_signature
from numba.core.typing.templates import AbstractTemplate
from numba.cuda.cudadecl import registry as cuda_decl_registry
from numba.cuda.cudaimpl import lower as cuda_lower
from pandas import testing as tm

import cudf
from cudf._lib.null_mask import bitmask_allocation_size_bytes
from cudf.core.column.timedelta import _unit_to_nanoseconds_conversion
from cudf.core.udf.strings_lowering import cast_string_view_to_udf_string
from cudf.core.udf.strings_typing import StringView, string_view, udf_string
from cudf.utils import dtypes as dtypeutils

supported_numpy_dtypes = [
Expand Down Expand Up @@ -387,3 +393,36 @@ def expect_warning_if(condition, warning=FutureWarning, *args, **kwargs):
yield
else:
yield


def sv_to_udf_str(sv):
"""
Cast a string_view object to a udf_string object

This placeholder function never runs in python
It exists only for numba to have something to replace
with the typing and lowering code below

This is similar conceptually to needing a translation
engine to emit an expression in target language "B" when
there is no equivalent in the source language "A" to
translate from. This function effectively defines the
expression in language "A" and the associated typing
and lowering describe the translation process, despite
the expression having no meaning in language "A"
"""
pass


@cuda_decl_registry.register_global(sv_to_udf_str)
class StringViewToUDFStringDecl(AbstractTemplate):
def generic(self, args, kws):
if isinstance(args[0], StringView) and len(args) == 1:
return nb_signature(udf_string, string_view)


@cuda_lower(sv_to_udf_str, string_view)
def sv_to_udf_str_testing_lowering(context, builder, sig, args):
return cast_string_view_to_udf_string(
context, builder, sig.args[0], sig.return_type, args[0]
)
Loading