From 3a2609ba68d7511bc53451ef155ff87a0948ff9a Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 23 Mar 2023 16:18:56 +0000 Subject: [PATCH] Fix `__setitem__` on string columns when the scalar value ends in a null byte (#12991) Since numpy strings are fixed width and use a null byte as an indicator of the end of the string, there is no way to distinguish between numpy.str_("abc\x00").item() and numpy.str_("abc").item(). This has consequences for scalar preprocessing we do when constructing a cudf.Scalar, since that usually goes through numpy.astype(...).item(). So, when preprocessing as scalar, if we notice it is a string with trailing null bytes, keep it as is. Closes #12990. Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/cudf/pull/12991 --- python/cudf/cudf/tests/test_setitem.py | 15 +++++++++++++++ python/cudf/cudf/utils/dtypes.py | 9 +++++++++ 2 files changed, 24 insertions(+) diff --git a/python/cudf/cudf/tests/test_setitem.py b/python/cudf/cudf/tests/test_setitem.py index 4d9ffc7cd81..dd82a9244b6 100644 --- a/python/cudf/cudf/tests/test_setitem.py +++ b/python/cudf/cudf/tests/test_setitem.py @@ -353,3 +353,18 @@ def test_scatter_by_slice_with_start_and_step(): target[1::2] = source ctarget[1::2] = csource assert_eq(target, ctarget) + + +@pytest.mark.parametrize("n", [1, 3]) +def test_setitem_str_trailing_null(n): + trailing_nulls = "\x00" * n + s = cudf.Series(["a", "b", "c" + trailing_nulls]) + assert s[2] == "c" + trailing_nulls + s[0] = "a" + trailing_nulls + assert s[0] == "a" + trailing_nulls + s[1] = trailing_nulls + assert s[1] == trailing_nulls + s[0] = "" + assert s[0] == "" + s[0] = "\x00" + assert s[0] == "\x00" diff --git a/python/cudf/cudf/utils/dtypes.py b/python/cudf/cudf/utils/dtypes.py index acf00b3a3d5..2484003bd38 100644 --- a/python/cudf/cudf/utils/dtypes.py +++ b/python/cudf/cudf/utils/dtypes.py @@ -260,6 +260,15 @@ def to_cudf_compatible_scalar(val, dtype=None): ) or cudf.api.types.is_string_dtype(dtype): dtype = "str" + if isinstance(val, str) and val.endswith("\x00"): + # Numpy string dtypes are fixed width and use NULL to + # indicate the end of the string, so they cannot + # distinguish between "abc\x00" and "abc". + # https://github.com/numpy/numpy/issues/20118 + # In this case, don't try going through numpy and just use + # the string value directly (cudf.DeviceScalar will DTRT) + return val + if isinstance(val, datetime.datetime): val = np.datetime64(val) elif isinstance(val, datetime.timedelta):