Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable scattering scalars into decimal columns #7899

Merged
merged 10 commits into from
Apr 16, 2021
53 changes: 23 additions & 30 deletions python/cudf/cudf/core/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import decimal

import numpy as np
import pyarrow as pa

from cudf._lib.scalar import DeviceScalar, _is_null_host_scalar
from cudf.core.column.column import ColumnBase
Expand Down Expand Up @@ -114,44 +115,36 @@ def _device_value_to_host(self):
self._host_value = self._device_value._to_host_scalar()

def _preprocess_host_value(self, value, dtype):
if isinstance(dtype, Decimal64Dtype):
# TODO: Support coercion from decimal.Decimal to different dtype
# TODO: Support coercion from integer to Decimal64Dtype
raise NotImplementedError(
"dtype as cudf.Decimal64Dtype is not supported. Pass a "
"decimal.Decimal to construct a DecimalScalar."
)
if isinstance(value, decimal.Decimal) and dtype is not None:
raise TypeError(f"Can not coerce decimal to {dtype}")

value = to_cudf_compatible_scalar(value, dtype=dtype)
valid = not _is_null_host_scalar(value)

if isinstance(value, decimal.Decimal):
# 0.0042 -> Decimal64Dtype(2, 4)
if isinstance(dtype, Decimal64Dtype):
value = pa.scalar(
value, type=pa.decimal128(dtype.precision, dtype.scale)
).as_py()
if isinstance(value, decimal.Decimal) and dtype is None:
dtype = Decimal64Dtype._from_decimal(value)

else:
if dtype is None:
if not valid:
if isinstance(value, (np.datetime64, np.timedelta64)):
unit, _ = np.datetime_data(value)
if unit == "generic":
raise TypeError(
"Cant convert generic NaT to null scalar"
)
else:
dtype = value.dtype
else:
value = to_cudf_compatible_scalar(value, dtype=dtype)

if dtype is None:
if not valid:
if isinstance(value, (np.datetime64, np.timedelta64)):
unit, _ = np.datetime_data(value)
if unit == "generic":
raise TypeError(
"dtype required when constructing a null scalar"
"Cant convert generic NaT to null scalar"
)
else:
dtype = value.dtype
else:
dtype = value.dtype
dtype = np.dtype(dtype)
raise TypeError(
"dtype required when constructing a null scalar"
)
else:
dtype = value.dtype

# temporary
dtype = np.dtype("object") if dtype.char == "U" else dtype
if not isinstance(dtype, Decimal64Dtype):
dtype = np.dtype(dtype)

if not valid:
value = NA
Expand Down
86 changes: 86 additions & 0 deletions python/cudf/cudf/tests/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,89 @@ def test_typecast_from_decimal(data, from_dtype, to_dtype):
expected = cudf.Series(NumericalColumn.from_arrow(pa_arr))

assert_eq(got, expected)
assert_eq(got.dtype, expected.dtype)


def _decimal_series(input, dtype):
    """Build a ``cudf.Series`` of ``dtype`` from the entries of ``input``.

    Every entry that is not ``None`` is converted with ``Decimal(...)``;
    ``None`` entries are passed through unchanged.
    """
    values = [None if x is None else Decimal(x) for x in input]
    return cudf.Series(values, dtype=dtype)


@pytest.mark.parametrize(
    "args",
    [
        # Each case: (initial values, dtype, item to assign, key, expected)
        # --- assignment at a single integer position ---
        (
            ["1", "2", "3"],
            Decimal64Dtype(1, 0),
            Decimal(5),
            1,
            ["1", "5", "3"],
        ),
        (
            ["1.5", "2.5", "3.5"],
            Decimal64Dtype(2, 1),
            Decimal("5.5"),
            1,
            ["1.5", "5.5", "3.5"],
        ),
        (
            ["1.0042", "2.0042", "3.0042"],
            Decimal64Dtype(5, 4),
            Decimal("5.0042"),
            1,
            ["1.0042", "5.0042", "3.0042"],
        ),
        # --- assignment through a boolean mask ---
        (
            ["1", "2", "3"],
            Decimal64Dtype(1, 0),
            Decimal(5),
            cudf.Series([True, False, True]),
            ["5", "2", "5"],
        ),
        (
            ["1.5", "2.5", "3.5"],
            Decimal64Dtype(2, 1),
            Decimal("5.5"),
            cudf.Series([True, True, True]),
            ["5.5", "5.5", "5.5"],
        ),
        (
            ["1.0042", "2.0042", "3.0042"],
            Decimal64Dtype(5, 4),
            Decimal("5.0042"),
            cudf.Series([False, False, True]),
            ["1.0042", "2.0042", "5.0042"],
        ),
        # A decimal with less precision than the column is accepted...
        (
            ["1.00", "2.00", "3.00"],
            Decimal64Dtype(3, 2),
            Decimal(5),
            1,
            ["1.00", "5.00", "3.00"],
        ),
        # ...but a value that would need truncation is expected to raise.
        (
            ["1", "2", "3"],
            Decimal64Dtype(1, 0),
            Decimal("5.5"),
            1,
            pa.lib.ArrowInvalid,
        ),
    ],
)
def test_series_setitem_decimal(args):
    """Assign a Decimal scalar into a decimal Series and check the result.

    ``args`` is a 5-tuple ``(data, dtype, item, to, expect)``. When
    ``expect`` is ``pa.lib.ArrowInvalid`` the assignment itself must raise
    that exception; otherwise ``expect`` lists the values the Series should
    hold after ``series[to] = item``.
    """
    data, dtype, item, to, expect = args
    series = _decimal_series(data, dtype)

    # Error case: only the assignment is exercised, nothing is compared.
    if expect is pa.lib.ArrowInvalid:
        with pytest.raises(expect):
            series[to] = item
        return

    expected = _decimal_series(expect, dtype)
    series[to] = item
    assert_eq(series, expected)
45 changes: 28 additions & 17 deletions python/cudf/cudf/tests/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

import cudf
Expand Down Expand Up @@ -194,14 +195,13 @@ def test_scalar_roundtrip(value):
+ TEST_DECIMAL_TYPES,
)
def test_null_scalar(dtype):
if isinstance(dtype, cudf.Decimal64Dtype):
with pytest.raises(NotImplementedError):
s = cudf.Scalar(None, dtype=dtype)
return

s = cudf.Scalar(None, dtype=dtype)
assert s.value is cudf.NA
assert s.dtype == np.dtype(dtype)
assert s.dtype == (
np.dtype(dtype)
if not isinstance(dtype, cudf.Decimal64Dtype)
else dtype
)
assert s.is_valid() is False


Expand Down Expand Up @@ -234,25 +234,36 @@ def test_generic_null_scalar_construction_fails(value):


# Builds a cudf.Scalar from the integer 1 with an explicit dtype and checks
# that the scalar reports that dtype and is valid.
# NOTE(review): this span looks like a rendered diff in which removed and
# added lines are fused without +/- markers (two "dtype" parameter lists, and
# a Decimal64Dtype branch that expects NotImplementedError and returns early).
# Confirm against the real test_scalar.py before relying on it.
@pytest.mark.parametrize(
"dtype",
NUMERIC_TYPES
+ DATETIME_TYPES
+ TIMEDELTA_TYPES
+ ["object"]
+ TEST_DECIMAL_TYPES,
"dtype", NUMERIC_TYPES + DATETIME_TYPES + TIMEDELTA_TYPES + ["object"]
)
def test_scalar_dtype_and_validity(dtype):
if isinstance(dtype, cudf.Decimal64Dtype):
with pytest.raises(NotImplementedError):
s = cudf.Scalar(None, dtype=dtype)
return

s = cudf.Scalar(1, dtype=dtype)

assert s.dtype == np.dtype(dtype)
assert s.is_valid() is True


@pytest.mark.parametrize(
    "slr,dtype,expect",
    [
        (1, cudf.Decimal64Dtype(1, 0), Decimal("1")),
        (Decimal(1), cudf.Decimal64Dtype(1, 0), Decimal("1")),
        (Decimal("1.1"), cudf.Decimal64Dtype(2, 1), Decimal("1.1")),
        (Decimal("1.1"), cudf.Decimal64Dtype(4, 3), Decimal("1.100")),
        (Decimal("1.11"), cudf.Decimal64Dtype(2, 2), pa.lib.ArrowInvalid),
    ],
)
def test_scalar_dtype_and_validity_decimal(slr, dtype, expect):
    """Construct a scalar with a ``Decimal64Dtype`` and check dtype/validity.

    When ``expect`` is ``pa.lib.ArrowInvalid`` the construction itself must
    raise that exception. Otherwise the resulting scalar must report
    ``dtype`` and be valid.
    """
    if expect is pa.lib.ArrowInvalid:
        with pytest.raises(expect):
            cudf.Scalar(slr, dtype=dtype)
        return

    result = cudf.Scalar(slr, dtype=dtype)
    assert result.dtype == dtype
    # ``is_valid`` is a method: asserting on the bare attribute only checks
    # that a bound method object is truthy, which is always the case.
    assert result.is_valid() is True
    # NOTE(review): the non-error ``expect`` values are not compared with the
    # scalar's value here -- consider asserting on it once confirmed.


@pytest.mark.parametrize(
"value",
[
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def get_appropriate_dispatched_func(
def _cast_to_appropriate_cudf_type(val, index=None):
# Handle scalar
if val.ndim == 0:
return cudf.Scalar(val).value
return to_cudf_compatible_scalar(val)
# 1D array
elif (val.ndim == 1) or (val.ndim == 2 and val.shape[1] == 1):
# if index is not None and is of a different length
Expand Down