Skip to content

Commit

Permalink
Fix __setitem__ on string columns when the scalar value ends in a n…
Browse files Browse the repository at this point in the history
…ull byte (#12991)

Since numpy strings are fixed width and use a null byte as an
indicator of the end of the string, there is no way to distinguish
between numpy.str_("abc\x00").item() and numpy.str_("abc").item().
This has consequences for scalar preprocessing we do when constructing
a cudf.Scalar, since that usually goes through
numpy.astype(...).item(). So, when preprocessing a scalar, if we
notice it is a string with trailing null bytes, we keep it as is.

Closes #12990.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #12991
  • Loading branch information
wence- authored Mar 23, 2023
1 parent 6966fd5 commit 3a2609b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
15 changes: 15 additions & 0 deletions python/cudf/cudf/tests/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,3 +353,18 @@ def test_scatter_by_slice_with_start_and_step():
target[1::2] = source
ctarget[1::2] = csource
assert_eq(target, ctarget)


@pytest.mark.parametrize("n", [1, 3])
def test_setitem_str_trailing_null(n):
    """Check that trailing NUL characters survive ``Series.__setitem__``.

    A string series is built whose last element ends in ``n`` NUL
    characters, then scalars that end in (or consist solely of) NUL
    characters are assigned into it. Every value read back must compare
    equal to the exact string that was stored.

    Parameters
    ----------
    n : int
        Number of trailing ``"\\x00"`` characters to append.
    """
    nul_suffix = n * "\x00"
    padded_a = "a" + nul_suffix
    padded_c = "c" + nul_suffix

    series = cudf.Series(["a", "b", padded_c])
    # Construction alone must not drop the trailing NULs.
    assert series[2] == padded_c

    # Assigning a scalar that ends in NULs keeps all of them.
    series[0] = padded_a
    assert series[0] == padded_a

    # A scalar made up of nothing but NULs is stored as-is.
    series[1] = nul_suffix
    assert series[1] == nul_suffix

    # The empty string and a single NUL must remain distinguishable.
    series[0] = ""
    assert series[0] == ""
    series[0] = "\x00"
    assert series[0] == "\x00"
9 changes: 9 additions & 0 deletions python/cudf/cudf/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,15 @@ def to_cudf_compatible_scalar(val, dtype=None):
) or cudf.api.types.is_string_dtype(dtype):
dtype = "str"

if isinstance(val, str) and val.endswith("\x00"):
# Numpy string dtypes are fixed width and use NULL to
# indicate the end of the string, so they cannot
# distinguish between "abc\x00" and "abc".
# https://github.com/numpy/numpy/issues/20118
# In this case, don't try going through numpy and just use
# the string value directly (cudf.DeviceScalar will DTRT)
return val

if isinstance(val, datetime.datetime):
val = np.datetime64(val)
elif isinstance(val, datetime.timedelta):
Expand Down

0 comments on commit 3a2609b

Please sign in to comment.