Skip to content

Commit

Permalink
address reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-b-miller committed Nov 9, 2022
1 parent 837a49c commit 302fe60
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/udf/strings_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
id_unary_funcs,
int_binary_funcs,
size_type,
string_binary_funcs,
string_return_attrs,
string_view,
udf_string,
)
Expand Down Expand Up @@ -174,7 +174,7 @@ def resolve_valid(self, mod):
create_masked_binary_attr(f"MaskedType.{func}", size_type),
)

for func in string_binary_funcs:
for func in string_return_attrs:
setattr(
MaskedStringViewAttrs,
f"resolve_{func}",
Expand Down
3 changes: 3 additions & 0 deletions python/cudf/cudf/tests/test_udf_masked_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,7 @@ def func(row):
run_masked_udf_test(func, str_udf_data, check_dtype=False)


@string_udf_test
@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", "@"])
def test_string_udf_strip(str_udf_data, strip_char):
def func(row):
Expand All @@ -884,6 +885,7 @@ def func(row):
run_masked_udf_test(func, str_udf_data, check_dtype=False)


@string_udf_test
@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", "@"])
def test_string_udf_lstrip(str_udf_data, strip_char):
def func(row):
Expand All @@ -892,6 +894,7 @@ def func(row):
run_masked_udf_test(func, str_udf_data, check_dtype=False)


@string_udf_test
@pytest.mark.parametrize("strip_char", ["1", "a", "12", " ", "", ".", "@"])
def test_string_udf_rstrip(str_udf_data, strip_char):
def func(row):
Expand Down
4 changes: 2 additions & 2 deletions python/strings_udf/strings_udf/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def resolve_count(self, mod):
"isnumeric",
"istitle",
]
string_binary_funcs = ["strip", "lstrip", "rstrip"]
string_return_attrs = ["strip", "lstrip", "rstrip"]

for func in bool_binary_funcs:
setattr(
Expand All @@ -238,7 +238,7 @@ def resolve_count(self, mod):
create_binary_attr(func, types.boolean),
)

for func in string_binary_funcs:
for func in string_return_attrs:
setattr(
StringViewAttrs,
f"resolve_{func}",
Expand Down

0 comments on commit 302fe60

Please sign in to comment.