Skip to content

Commit

Permalink
Enable scattering scalars into decimal columns (#7899)
Browse files Browse the repository at this point in the history
Closes #7879. Adds the ability to coerce an `int` or `Decimal` to a different `Decimal64Dtype` where possible, and begins to plumb `pa.scalar` into some useful places within `cudf.Scalar`.

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Keith Kraus (https://github.com/kkraus14)
  - Paul Taylor (https://github.com/trxcllnt)

URL: #7899
  • Loading branch information
brandon-b-miller authored Apr 16, 2021
1 parent 98711da commit bc422fc
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 48 deletions.
53 changes: 23 additions & 30 deletions python/cudf/cudf/core/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import decimal

import numpy as np
import pyarrow as pa

from cudf._lib.scalar import DeviceScalar, _is_null_host_scalar
from cudf.core.column.column import ColumnBase
Expand Down Expand Up @@ -114,44 +115,36 @@ def _device_value_to_host(self):
self._host_value = self._device_value._to_host_scalar()

def _preprocess_host_value(self, value, dtype):
if isinstance(dtype, Decimal64Dtype):
# TODO: Support coercion from decimal.Decimal to different dtype
# TODO: Support coercion from integer to Decimal64Dtype
raise NotImplementedError(
"dtype as cudf.Decimal64Dtype is not supported. Pass a "
"decimal.Decimal to construct a DecimalScalar."
)
if isinstance(value, decimal.Decimal) and dtype is not None:
raise TypeError(f"Can not coerce decimal to {dtype}")

value = to_cudf_compatible_scalar(value, dtype=dtype)
valid = not _is_null_host_scalar(value)

if isinstance(value, decimal.Decimal):
# 0.0042 -> Decimal64Dtype(2, 4)
if isinstance(dtype, Decimal64Dtype):
value = pa.scalar(
value, type=pa.decimal128(dtype.precision, dtype.scale)
).as_py()
if isinstance(value, decimal.Decimal) and dtype is None:
dtype = Decimal64Dtype._from_decimal(value)

else:
if dtype is None:
if not valid:
if isinstance(value, (np.datetime64, np.timedelta64)):
unit, _ = np.datetime_data(value)
if unit == "generic":
raise TypeError(
"Cant convert generic NaT to null scalar"
)
else:
dtype = value.dtype
else:
value = to_cudf_compatible_scalar(value, dtype=dtype)

if dtype is None:
if not valid:
if isinstance(value, (np.datetime64, np.timedelta64)):
unit, _ = np.datetime_data(value)
if unit == "generic":
raise TypeError(
"dtype required when constructing a null scalar"
"Cant convert generic NaT to null scalar"
)
else:
dtype = value.dtype
else:
dtype = value.dtype
dtype = np.dtype(dtype)
raise TypeError(
"dtype required when constructing a null scalar"
)
else:
dtype = value.dtype

# temporary
dtype = np.dtype("object") if dtype.char == "U" else dtype
if not isinstance(dtype, Decimal64Dtype):
dtype = np.dtype(dtype)

if not valid:
value = NA
Expand Down
86 changes: 86 additions & 0 deletions python/cudf/cudf/tests/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,89 @@ def test_typecast_from_decimal(data, from_dtype, to_dtype):
expected = cudf.Series(NumericalColumn.from_arrow(pa_arr))

assert_eq(got, expected)
assert_eq(got.dtype, expected.dtype)


def _decimal_series(input, dtype):
    """Build a ``cudf.Series`` of ``dtype`` from an iterable of values.

    Every entry that is not ``None`` is passed through ``Decimal``;
    ``None`` entries are forwarded to the Series untouched.
    """
    converted = []
    for entry in input:
        converted.append(entry if entry is None else Decimal(entry))
    return cudf.Series(converted, dtype=dtype)


@pytest.mark.parametrize(
    "args",
    [
        # scatter to a single index
        (
            ["1", "2", "3"],
            Decimal64Dtype(1, 0),
            Decimal(5),
            1,
            ["1", "5", "3"],
        ),
        (
            ["1.5", "2.5", "3.5"],
            Decimal64Dtype(2, 1),
            Decimal("5.5"),
            1,
            ["1.5", "5.5", "3.5"],
        ),
        (
            ["1.0042", "2.0042", "3.0042"],
            Decimal64Dtype(5, 4),
            Decimal("5.0042"),
            1,
            ["1.0042", "5.0042", "3.0042"],
        ),
        # scatter via boolmask
        (
            ["1", "2", "3"],
            Decimal64Dtype(1, 0),
            Decimal(5),
            cudf.Series([True, False, True]),
            ["5", "2", "5"],
        ),
        (
            ["1.5", "2.5", "3.5"],
            Decimal64Dtype(2, 1),
            Decimal("5.5"),
            cudf.Series([True, True, True]),
            ["5.5", "5.5", "5.5"],
        ),
        (
            ["1.0042", "2.0042", "3.0042"],
            Decimal64Dtype(5, 4),
            Decimal("5.0042"),
            cudf.Series([False, False, True]),
            ["1.0042", "2.0042", "5.0042"],
        ),
        # We will allow assigning a decimal with less precision
        (
            ["1.00", "2.00", "3.00"],
            Decimal64Dtype(3, 2),
            Decimal(5),
            1,
            ["1.00", "5.00", "3.00"],
        ),
        # But not truncation
        (
            ["1", "2", "3"],
            Decimal64Dtype(1, 0),
            Decimal("5.5"),
            1,
            pa.lib.ArrowInvalid,
        ),
    ],
)
def test_series_setitem_decimal(args):
    """Check ``series[key] = Decimal`` on a decimal Series.

    Each case is ``(initial values, dtype, item, key, expected)``. The key
    is either a single integer position or a boolean-mask Series. When
    ``expected`` is ``pa.lib.ArrowInvalid`` the assignment itself must raise
    that error; otherwise the mutated Series must equal the expected values
    built with the same dtype.
    """
    values, dtype, item, key, expected = args
    series = _decimal_series(values, dtype)

    # Error case: the assignment is the operation expected to fail.
    if expected is pa.lib.ArrowInvalid:
        with pytest.raises(expected):
            series[key] = item
        return

    expected_series = _decimal_series(expected, dtype)
    series[key] = item
    assert_eq(series, expected_series)
45 changes: 28 additions & 17 deletions python/cudf/cudf/tests/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

import cudf
Expand Down Expand Up @@ -194,14 +195,13 @@ def test_scalar_roundtrip(value):
+ TEST_DECIMAL_TYPES,
)
def test_null_scalar(dtype):
if isinstance(dtype, cudf.Decimal64Dtype):
with pytest.raises(NotImplementedError):
s = cudf.Scalar(None, dtype=dtype)
return

s = cudf.Scalar(None, dtype=dtype)
assert s.value is cudf.NA
assert s.dtype == np.dtype(dtype)
assert s.dtype == (
np.dtype(dtype)
if not isinstance(dtype, cudf.Decimal64Dtype)
else dtype
)
assert s.is_valid() is False


Expand Down Expand Up @@ -234,25 +234,36 @@ def test_generic_null_scalar_construction_fails(value):


@pytest.mark.parametrize(
"dtype",
NUMERIC_TYPES
+ DATETIME_TYPES
+ TIMEDELTA_TYPES
+ ["object"]
+ TEST_DECIMAL_TYPES,
"dtype", NUMERIC_TYPES + DATETIME_TYPES + TIMEDELTA_TYPES + ["object"]
)
def test_scalar_dtype_and_validity(dtype):
if isinstance(dtype, cudf.Decimal64Dtype):
with pytest.raises(NotImplementedError):
s = cudf.Scalar(None, dtype=dtype)
return

s = cudf.Scalar(1, dtype=dtype)

assert s.dtype == np.dtype(dtype)
assert s.is_valid() is True


@pytest.mark.parametrize(
    "slr,dtype,expect",
    [
        (1, cudf.Decimal64Dtype(1, 0), Decimal("1")),
        (Decimal(1), cudf.Decimal64Dtype(1, 0), Decimal("1")),
        (Decimal("1.1"), cudf.Decimal64Dtype(2, 1), Decimal("1.1")),
        (Decimal("1.1"), cudf.Decimal64Dtype(4, 3), Decimal("1.100")),
        (Decimal("1.11"), cudf.Decimal64Dtype(2, 2), pa.lib.ArrowInvalid),
    ],
)
def test_scalar_dtype_and_validity_decimal(slr, dtype, expect):
    """Check constructing a ``cudf.Scalar`` with an explicit decimal dtype.

    When ``expect`` is ``pa.lib.ArrowInvalid`` the constructor must raise
    that error. Otherwise the resulting scalar must report the requested
    dtype and be valid.
    """
    if expect is pa.lib.ArrowInvalid:
        with pytest.raises(expect):
            cudf.Scalar(slr, dtype=dtype)
        return

    result = cudf.Scalar(slr, dtype=dtype)
    assert result.dtype == dtype
    # ``is_valid`` is a method; asserting the bare attribute would always
    # pass because a bound method object is truthy.
    assert result.is_valid() is True
    # NOTE(review): the Decimal values in ``expect`` only select the
    # success path here; they are not compared against ``result.value``.
    # Consider adding that comparison once the host value is confirmed.


@pytest.mark.parametrize(
"value",
[
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def get_appropriate_dispatched_func(
def _cast_to_appropriate_cudf_type(val, index=None):
# Handle scalar
if val.ndim == 0:
return cudf.Scalar(val).value
return to_cudf_compatible_scalar(val)
# 1D array
elif (val.ndim == 1) or (val.ndim == 2 and val.shape[1] == 1):
# if index is not None and is of a different length
Expand Down

0 comments on commit bc422fc

Please sign in to comment.