Skip to content

Commit

Permalink
Enable scattering integers into decimal columns (#8225)
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-b-miller authored May 13, 2021
1 parent fb7cdcd commit 6b92e25
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/cudf/cudf/_lib/scalar.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ cdef _get_np_scalar_from_timedelta64(unique_ptr[scalar]& s):

def as_device_scalar(val, dtype=None):
if dtype:
if isinstance(val, (cudf.Scalar, DeviceScalar)):
if isinstance(val, (cudf.Scalar, DeviceScalar)) and dtype != val.dtype:
raise TypeError("Can't update dtype of existing GPU scalar")
else:
return cudf.Scalar(value=val, dtype=dtype).device_value
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ def __setitem__(self, key: Any, value: Any):
nelem = len(key)

if is_scalar(value):
value = self.dtype.type(value) if value is not None else value
value = cudf.Scalar(value, dtype=self.dtype)
else:
if len(value) != nelem:
msg = (
Expand Down
5 changes: 5 additions & 0 deletions python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ class DecimalColumn(ColumnBase):
def __truediv__(self, other):
return self.binary_operator("div", other)

def __setitem__(self, key, value):
if isinstance(value, np.integer):
value = int(value)
super().__setitem__(key, value)

@classmethod
def from_arrow(cls, data: pa.Array):
dtype = Decimal64Dtype.from_arrow(data.type)
Expand Down
6 changes: 4 additions & 2 deletions python/cudf/cudf/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,11 @@ def __setitem__(self, key, value):
value = to_cudf_compatible_scalar(value)
else:
value = column.as_column(value)

if (
not is_categorical_dtype(self._sr._column.dtype)
not isinstance(
self._sr._column.dtype,
(cudf.Decimal64Dtype, cudf.CategoricalDtype),
)
and hasattr(value, "dtype")
and pd.api.types.is_numeric_dtype(value.dtype)
):
Expand Down
4 changes: 4 additions & 0 deletions python/cudf/cudf/tests/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ def test_typecast_from_decimal(data, from_dtype, to_dtype):
1,
pa.lib.ArrowInvalid,
),
# We will allow for setting scalars into decimal columns
(["1", "2", "3"], Decimal64Dtype(1, 0), 5, 1, ["1", "5", "3"]),
# But not if it has too many digits to fit the precision
(["1", "2", "3"], Decimal64Dtype(1, 0), 50, 1, pa.lib.ArrowInvalid),
],
)
def test_series_setitem_decimal(args):
Expand Down

0 comments on commit 6b92e25

Please sign in to comment.