From 6b92e2534d729d18af399856b2b845b6584b8ee2 Mon Sep 17 00:00:00 2001 From: brandon-b-miller <53796099+brandon-b-miller@users.noreply.github.com> Date: Thu, 13 May 2021 10:38:05 -0500 Subject: [PATCH] Enable scattering integers into decimal columns (#8225) Closes https://github.com/rapidsai/cudf/issues/8157 Authors: - https://github.com/brandon-b-miller Approvers: - Christopher Harris (https://github.com/cwharris) - Ashwin Srinath (https://github.com/shwina) - Keith Kraus (https://github.com/kkraus14) URL: https://github.com/rapidsai/cudf/pull/8225 --- python/cudf/cudf/_lib/scalar.pyx | 2 +- python/cudf/cudf/core/column/column.py | 2 +- python/cudf/cudf/core/column/decimal.py | 5 +++++ python/cudf/cudf/core/indexing.py | 6 ++++-- python/cudf/cudf/tests/test_decimal.py | 4 ++++ 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/python/cudf/cudf/_lib/scalar.pyx b/python/cudf/cudf/_lib/scalar.pyx index b31f0675422..9f8a8ee6b1e 100644 --- a/python/cudf/cudf/_lib/scalar.pyx +++ b/python/cudf/cudf/_lib/scalar.pyx @@ -405,7 +405,7 @@ cdef _get_np_scalar_from_timedelta64(unique_ptr[scalar]& s): def as_device_scalar(val, dtype=None): if dtype: - if isinstance(val, (cudf.Scalar, DeviceScalar)): + if isinstance(val, (cudf.Scalar, DeviceScalar)) and dtype != val.dtype: raise TypeError("Can't update dtype of existing GPU scalar") else: return cudf.Scalar(value=val, dtype=dtype).device_value diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 0f039b137bc..42bfce0408c 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -681,7 +681,7 @@ def __setitem__(self, key: Any, value: Any): nelem = len(key) if is_scalar(value): - value = self.dtype.type(value) if value is not None else value + value = cudf.Scalar(value, dtype=self.dtype) else: if len(value) != nelem: msg = ( diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index 868fde17d87..e3d88424b8a 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -29,6 +29,11 @@ class DecimalColumn(ColumnBase): def __truediv__(self, other): return self.binary_operator("div", other) + def __setitem__(self, key, value): + if isinstance(value, np.integer): + value = int(value) + super().__setitem__(key, value) + @classmethod def from_arrow(cls, data: pa.Array): dtype = Decimal64Dtype.from_arrow(data.type) diff --git a/python/cudf/cudf/core/indexing.py b/python/cudf/cudf/core/indexing.py index a732abc0705..7de1aaf9726 100755 --- a/python/cudf/cudf/core/indexing.py +++ b/python/cudf/cudf/core/indexing.py @@ -101,9 +101,11 @@ def __setitem__(self, key, value): value = to_cudf_compatible_scalar(value) else: value = column.as_column(value) - if ( - not is_categorical_dtype(self._sr._column.dtype) + not isinstance( + self._sr._column.dtype, + (cudf.Decimal64Dtype, cudf.CategoricalDtype), + ) and hasattr(value, "dtype") and pd.api.types.is_numeric_dtype(value.dtype) ): diff --git a/python/cudf/cudf/tests/test_decimal.py b/python/cudf/cudf/tests/test_decimal.py index 111f973a78b..073a8e443c7 100644 --- a/python/cudf/cudf/tests/test_decimal.py +++ b/python/cudf/cudf/tests/test_decimal.py @@ -268,6 +268,10 @@ def test_typecast_from_decimal(data, from_dtype, to_dtype): 1, pa.lib.ArrowInvalid, ), + # We will allow for setting scalars into decimal columns + (["1", "2", "3"], Decimal64Dtype(1, 0), 5, 1, ["1", "5", "3"]), + # But not if it has too many digits to fit the precision + (["1", "2", "3"], Decimal64Dtype(1, 0), 50, 1, pa.lib.ArrowInvalid), ], ) def test_series_setitem_decimal(args):