Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add istitle to string UDFs #11738

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/cudf/cudf/core/udf/strings_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
isdigit_impl,
islower_impl,
isspace_impl,
istitle_impl,
isupper_impl,
len_impl,
rfind_impl,
Expand Down Expand Up @@ -123,3 +124,4 @@ def masked_unary_func_impl(context, builder, sig, args):
# Each call below pairs a "MaskedType.<method>" identifier string with the
# *_impl function of the same method name.
# NOTE(review): create_masked_unary_identifier_func is defined outside this
# excerpt; presumably it registers the masked lowering for that method - confirm.
create_masked_unary_identifier_func("MaskedType.islower", islower_impl)
create_masked_unary_identifier_func("MaskedType.isspace", isspace_impl)
create_masked_unary_identifier_func("MaskedType.isdecimal", isdecimal_impl)
create_masked_unary_identifier_func("MaskedType.istitle", istitle_impl)
12 changes: 12 additions & 0 deletions python/cudf/cudf/tests/test_udf_masked_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def str_udf_data():
"cudf",
"cuda",
"gpu",
"This Is A Title",
"This is Not a Title",
"Neither is This a Title",
"NoT a TiTlE",
]
}
)
Expand Down Expand Up @@ -839,6 +843,14 @@ def func(row):
run_masked_udf_test(func, str_udf_data, check_dtype=False)


@string_udf_test
def test_string_udf_istitle(str_udf_data):
    """Run ``istitle`` on the ``str_col`` field of each row via the masked UDF test helper."""

    def func(row):
        # Pull the string field out first, then ask whether it is title-cased.
        value = row["str_col"]
        return value.istitle()

    run_masked_udf_test(func, str_udf_data, check_dtype=False)


@string_udf_test
def test_string_udf_count(str_udf_data, substr):
def func(row):
Expand Down
24 changes: 24 additions & 0 deletions python/strings_udf/cpp/include/cudf/strings/udf/char_types.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,30 @@ __device__ inline bool is_lower(cudf::strings::detail::character_flags_table_typ
flags_table, d_str, string_character_types::LOWER, string_character_types::CASE_TYPES);
}

/**
* @brief Returns true if string is in title case
*
* Walks the characters of `d_str` in order. A cased character must be
* upper-case when it starts the string or follows an uncased character, and
* must not be upper-case when it follows another cased character. Uncased
* characters never fail the check on their own.
*
* @param flags_table The character flags table indexed by code point
* @param d_str Input string to check
* @return True if string is in title case; false if the rule above is
*         violated or if the string has no cased characters at all
*/
__device__ inline bool is_title(cudf::strings::detail::character_flags_table_type* flags_table,
string_view d_str)
{
bool valid = false; // requires one or more cased characters
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test case that covers the branch where we start with valid and never enter the if statement below? I assume a numeric like "123" or a symbolic string like "^#(" would do the trick.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure whether it's best to test in the string UDF tests, the cudf tests, or both; up to you.

bool should_be_capitalized = true; // current character should be upper-case
vyasr marked this conversation as resolved.
Show resolved Hide resolved
for (auto const chr : d_str) {
auto const code_point = cudf::strings::detail::utf8_to_codepoint(chr);
// Code points above 0xFFFF are not looked up in the table and get a flag of 0.
// NOTE(review): presumably a zero flag is classified as uncased below - confirm.
auto const flag = code_point <= 0x00FFFF ? flags_table[code_point] : 0;
if (cudf::strings::detail::IS_UPPER_OR_LOWER(flag)) {
// Reject when the case of this character differs from what its position requires.
if (should_be_capitalized == !cudf::strings::detail::IS_UPPER(flag)) return false;
valid = true;
}
// The next character is expected to be upper-case only if this one was not cased.
should_be_capitalized = !cudf::strings::detail::IS_UPPER_OR_LOWER(flag);
vyasr marked this conversation as resolved.
Show resolved Hide resolved
}
return valid;
}

} // namespace udf
} // namespace strings
} // namespace cudf
9 changes: 9 additions & 0 deletions python/strings_udf/cpp/src/strings/udf/shim.cu
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ extern "C" __device__ int pyisalpha(bool* nb_retval, void const* str, std::int64
return 0;
}

/**
 * @brief C-linkage device shim that forwards to is_title
 *
 * @param nb_retval Output location that receives the is_title result
 * @param str Type-erased pointer that is read as a cudf::string_view
 * @param chars_table Integer holding the address of the character flags table
 * @return Always 0
 */
extern "C" __device__ int pyistitle(bool* nb_retval, void const* str, std::int64_t chars_table)
{
  // Recover the typed views of both type-erased arguments before the call.
  auto const* input = reinterpret_cast<cudf::string_view const*>(str);
  auto* flags = reinterpret_cast<cudf::strings::detail::character_flags_table_type*>(chars_table);

  *nb_retval = is_title(flags, *input);
  return 0;
}

extern "C" __device__ int pycount(int* nb_retval, void const* str, void const* substr)
{
auto str_view = reinterpret_cast<cudf::string_view const*>(str);
Expand Down
1 change: 1 addition & 0 deletions python/strings_udf/strings_udf/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def resolve_count(self, mod):
"islower",
"isspace",
"isnumeric",
"istitle",
]

for func in bool_binary_funcs:
Expand Down
6 changes: 6 additions & 0 deletions python/strings_udf/strings_udf/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def _declare_binary_func(lhs, rhs, out, name):
# Module-level handles, one per character-type predicate. Each is created from
# the name of a device function (for example "pyistitle").
# NOTE(review): _declare_bool_str_int_func is defined outside this excerpt;
# presumably it declares a device function returning bool from a string and
# an integer table pointer - confirm against its definition.
_string_view_isspace = _declare_bool_str_int_func("pyisspace")
_string_view_isupper = _declare_bool_str_int_func("pyisupper")
_string_view_islower = _declare_bool_str_int_func("pyislower")
_string_view_istitle = _declare_bool_str_int_func("pyistitle")


_string_view_count = cuda.declare_device(
Expand Down Expand Up @@ -285,3 +286,8 @@ def isupper_impl(st, tbl):
@create_unary_identifier_func("StringView.islower")
def islower_impl(st, tbl):
    # Forward both arguments unchanged to the declared device function.
    is_lower_case = _string_view_islower(st, tbl)
    return is_lower_case


@create_unary_identifier_func("StringView.istitle")
def istitle_impl(st, tbl):
    # Forward both arguments unchanged to the declared device function.
    is_title_case = _string_view_istitle(st, tbl)
    return is_title_case
11 changes: 11 additions & 0 deletions python/strings_udf/strings_udf/tests/test_string_udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def data():
"cudf",
"cuda",
"gpu",
"This Is A Title",
"This is Not a Title",
"Neither is This a Title",
"NoT a TiTlE",
]


Expand Down Expand Up @@ -228,6 +232,13 @@ def func(st):
run_udf_test(data, func, "bool")


def test_string_udf_istitle(data):
    """Run ``istitle`` over the shared string fixture and expect a ``"bool"`` result type."""

    def func(st):
        # Evaluate the predicate, then hand the result back to the caller.
        outcome = st.istitle()
        return outcome

    run_udf_test(data, func, "bool")


def test_string_udf_len(data):
def func(st):
return len(st)
Expand Down