Skip to content
/ cudf Public
forked from rapidsai/cudf

Commit

Permalink
Fix type casting in Series.__setitem__
Browse files Browse the repository at this point in the history
To mimic pandas, we must upcast a column to the numpy result_type of
the column itself and the input value dtype. This previously occurred
in all relevant cases except when the index provided to __setitem__
was a single integer (originally introduced in rapidsai#2442). Closes rapidsai#11901.
  • Loading branch information
wence- committed Oct 11, 2022
1 parent 9ba6142 commit cdbce51
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
11 changes: 6 additions & 5 deletions python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,12 @@ def __setitem__(self, key, value):
and _is_non_decimal_numeric_dtype(value.dtype)
):
# normalize types if necessary:
if not is_integer(key):
to_dtype = np.result_type(
value.dtype, self._frame._column.dtype
)
value = value.astype(to_dtype)
# In contrast to Column.__setitem__ (which downcasts the value to
# the dtype of the column) here we upcast the series to the
# larger data type mimicing pandas
to_dtype = np.result_type(value.dtype, self._frame._column.dtype)
value = value.astype(to_dtype)
if to_dtype != self._frame._column.dtype:
self._frame._column._mimic_inplace(
self._frame._column.astype(to_dtype), inplace=True
)
Expand Down
16 changes: 16 additions & 0 deletions python/cudf/cudf/tests/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,19 @@ def test_series_slice_setitem_struct():
actual[0:3] = cudf.Scalar({"a": {"b": 5050}, "b": 101})

assert_eq(actual, expected)


@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64])
@pytest.mark.parametrize("indices", [0, [1, 2]])
def test_series_setitem_upcasting(dtype, indices):
sr = pd.Series([0, 0, 0], dtype=dtype)
cr = cudf.from_pandas(sr)
assert_eq(sr.values, cr.values)
new_value = np.float64(10.5)
col_ref = cr._column
sr[indices] = new_value
cr[indices] = new_value
assert_eq(sr.values, cr.values)
if dtype == np.float64:
# no-op type cast should not modify backing column
assert col_ref == cr._column

0 comments on commit cdbce51

Please sign in to comment.