Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow casting from UDFString back to StringView to call methods in strings_udf #12363

Merged
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
477155d
initially working up through count, work to do
brandon-b-miller Dec 12, 2022
08e190e
replace standalone function works in strings_udf
brandon-b-miller Dec 13, 2022
c1bd41c
strings_udf tests pass
brandon-b-miller Dec 15, 2022
3f043d6
all tests pass
brandon-b-miller Dec 16, 2022
c9fa690
merge latest and resolve conflicts
brandon-b-miller Jan 5, 2023
70109dc
Merge branch 'branch-23.02' into fix-stringudf-chained-ops
brandon-b-miller Jan 5, 2023
8c8eb85
update copyright years
brandon-b-miller Jan 5, 2023
6049d53
Merge branch 'branch-23.02' into fix-stringudf-chained-ops
brandon-b-miller Jan 9, 2023
eb44719
make things much easier
brandon-b-miller Jan 9, 2023
59b42ea
refactor
brandon-b-miller Jan 9, 2023
599033f
continue reverting changes
brandon-b-miller Jan 9, 2023
74d6c61
Merge branch 'branch-23.02' into fix-stringudf-chained-ops
brandon-b-miller Jan 17, 2023
d0f4e3b
Merge branch 'branch-23.02' into fix-stringudf-chained-ops
galipremsagar Jan 25, 2023
6b3b085
Merge branch 'branch-23.02' into fix-stringudf-chained-ops
brandon-b-miller Jan 30, 2023
b163c1b
cast in len
brandon-b-miller Jan 30, 2023
21a2121
small typing bug
brandon-b-miller Jan 30, 2023
8a55a26
remove duplicate code
brandon-b-miller Jan 30, 2023
0d01928
Merge branch 'branch-23.04' into fix-stringudf-chained-ops
brandon-b-miller Feb 2, 2023
4043691
merge latest, resolve conflicts, pass tests, refactor
brandon-b-miller Feb 22, 2023
26a6f88
remove old file
brandon-b-miller Feb 22, 2023
888f4dd
Merge branch 'branch-23.04' into fix-stringudf-chained-ops
brandon-b-miller Feb 28, 2023
254a583
address reviews
brandon-b-miller Mar 1, 2023
0840cdc
add docs to sv_to_udf_str
brandon-b-miller Mar 1, 2023
c9ff968
Merge branch 'branch-23.04' into fix-stringudf-chained-ops
vyasr Mar 6, 2023
e6ae995
Merge branch 'branch-23.04' into fix-stringudf-chained-ops
vyasr Mar 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/cudf/cudf/core/udf/strings_lowering.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Copyright (c) 2022-2023, NVIDIA CORPORATION.

import operator

Expand Down Expand Up @@ -36,6 +36,7 @@


@cuda_lower(len, MaskedType(string_view))
@cuda_lower(len, MaskedType(udf_string))
vyasr marked this conversation as resolved.
Show resolved Hide resolved
def masked_len_impl(context, builder, sig, args):
ret = cgutils.create_struct_proxy(sig.return_type)(context, builder)
masked_sv_ty = sig.args[0]
Expand Down
18 changes: 15 additions & 3 deletions python/cudf/cudf/core/udf/strings_typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Copyright (c) 2022-2023, NVIDIA CORPORATION.

import operator

Expand All @@ -9,6 +9,7 @@

from strings_udf._typing import (
StringView,
UDFString,
bool_binary_funcs,
id_unary_funcs,
int_binary_funcs,
Expand All @@ -25,11 +26,13 @@

masked_typing.MASKED_INIT_MAP[types.pyobject] = string_view
masked_typing.MASKED_INIT_MAP[string_view] = string_view
masked_typing.MASKED_INIT_MAP[udf_string] = udf_string


def _is_valid_string_arg(ty):
return (
isinstance(ty, MaskedType) and isinstance(ty.value_type, StringView)
isinstance(ty, MaskedType)
and isinstance(ty.value_type, (StringView, UDFString))
) or isinstance(ty, types.StringLiteral)


Expand All @@ -53,7 +56,7 @@ class MaskedStringFunction(AbstractTemplate):
@register_string_function(len)
def len_typing(self, args, kws):
if isinstance(args[0], MaskedType) and isinstance(
args[0].value_type, StringView
args[0].value_type, (StringView, UDFString)
):
return nb_signature(MaskedType(size_type), args[0])
elif isinstance(args[0], types.StringLiteral) and len(args) == 1:
Expand Down Expand Up @@ -223,4 +226,13 @@ def resolve_valid(self, mod):
create_masked_unary_attr(f"MaskedType.{func}", udf_string),
)


class MaskedUDFStringAttrs(MaskedStringViewAttrs):
key = MaskedType(udf_string)

def resolve_value(self, mod):
return udf_string


cuda_decl_registry.register_attr(MaskedStringViewAttrs)
cuda_decl_registry.register_attr(MaskedUDFStringAttrs)
35 changes: 34 additions & 1 deletion python/cudf/cudf/tests/test_udf_masked_ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
import math
import operator

Expand All @@ -15,6 +15,7 @@
comparison_ops,
unary_ops,
)
from cudf.core.udf.api import Masked
from cudf.core.udf.utils import precompiled
from cudf.testing._utils import (
_decimal_series,
Expand Down Expand Up @@ -72,6 +73,38 @@ def run_masked_udf_test(func, data, args=(), **kwargs):
assert_eq(expect, obtain, **kwargs)


def run_masked_string_udf_test(func, data, args=(), **kwargs):
from strings_udf._typing import sv_to_udf_str

gdf = data
pdf = data.to_pandas(nullable=True)

def row_wrapper(row):
st = row["str_col"]
return func(st)

expect = pdf.apply(row_wrapper, args=args, axis=1)

func = cuda.jit(device=True)(func)
obtain = gdf.apply(row_wrapper, args=args, axis=1)
assert_eq(expect, obtain, **kwargs)

# strings that come directly from input columns are backed by
# MaskedType(string_view) types. But new strings that are returned
# from functions or operators are backed by MaskedType(udf_string)
# types. We need to make sure all of our methods work on both kind
# of MaskedType. This function promotes the former to the latter
# prior to running the input function
def udf_string_wrapper(row):
masked_udf_str = Masked(
sv_to_udf_str(row["str_col"].value), row["str_col"].valid
)
return func(masked_udf_str)

obtain = gdf.apply(udf_string_wrapper, args=args, axis=1)
assert_eq(expect, obtain, **kwargs)


def run_masked_udf_series(func, data, args=(), **kwargs):
gsr = data
psr = data.to_pandas(nullable=True)
Expand Down
13 changes: 12 additions & 1 deletion python/strings_udf/cpp/src/strings/udf/shim.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -231,6 +231,17 @@ extern "C" __device__ int udf_string_from_string_view(int* nb_retbal,
return 0;
}

extern "C" __device__ int string_view_from_udf_string(int* nb_retval,
void const* udf_str,
void* str)
{
auto udf_str_ptr = reinterpret_cast<udf_string const*>(udf_str);
auto sv_ptr = new (str) cudf::string_view;
*sv_ptr = cudf::string_view(*udf_str_ptr);

return 0;
}

extern "C" __device__ int strip(int* nb_retval,
void* udf_str,
void* const* to_strip,
Expand Down
28 changes: 25 additions & 3 deletions python/strings_udf/strings_udf/_typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Copyright (c) 2022-2023, NVIDIA CORPORATION.

import operator

Expand Down Expand Up @@ -39,7 +39,6 @@ def return_type(self):


class StringView(types.Type):

np_dtype = np.dtype("object")

def __init__(self):
Expand Down Expand Up @@ -120,6 +119,18 @@ def prepare_args(self, ty, val, **kwargs):
str_view_arg_handler = StrViewArgHandler()


# for use in testing only
def sv_to_udf_str(sv):
vyasr marked this conversation as resolved.
Show resolved Hide resolved
pass


@cuda_decl_registry.register_global(sv_to_udf_str)
class StringViewToUDFStringDecl(AbstractTemplate):
def generic(self, args, kws):
if isinstance(args[0], StringView) and len(args) == 1:
return nb_signature(udf_string, string_view)


# String functions
@cuda_decl_registry.register_global(len)
class StringLength(AbstractTemplate):
Expand Down Expand Up @@ -242,6 +253,7 @@ def resolve_replace(self, mod):
create_binary_attr(func, types.boolean),
)


for func in string_return_attrs:
setattr(
StringViewAttrs,
Expand All @@ -252,9 +264,12 @@ def resolve_replace(self, mod):

for func in int_binary_funcs:
setattr(
StringViewAttrs, f"resolve_{func}", create_binary_attr(func, size_type)
StringViewAttrs,
f"resolve_{func}",
create_binary_attr(func, size_type),
)


for func in id_unary_funcs:
setattr(
StringViewAttrs,
Expand All @@ -269,7 +284,14 @@ def resolve_replace(self, mod):
create_identifier_attr(func, udf_string),
)


@cuda_decl_registry.register_attr
class UDFStringAttrs(StringViewAttrs):
key = udf_string


cuda_decl_registry.register_attr(StringViewAttrs)
cuda_decl_registry.register_attr(UDFStringAttrs)

register_stringview_binaryop(operator.eq, types.boolean)
register_stringview_binaryop(operator.ne, types.boolean)
Expand Down
Loading