Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable scattering scalars into decimal columns #7899

Merged
merged 10 commits into from
Apr 16, 2021
Merged
2 changes: 1 addition & 1 deletion python/cudf/cudf/_lib/scalar.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ cdef class DeviceScalar:
dtype : dtype
A NumPy dtype.
"""
self._dtype = dtype if dtype.kind != 'U' else np.dtype('object')
self._dtype = dtype if np.dtype(dtype).kind != 'U' else np.dtype('object')
self._set_value(value, self._dtype)

def _set_value(self, value, dtype):
Expand Down
68 changes: 33 additions & 35 deletions python/cudf/cudf/core/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import decimal

import numpy as np
import pyarrow as pa

from cudf._lib.scalar import DeviceScalar, _is_null_host_scalar
from cudf.core.column.column import ColumnBase
Expand All @@ -10,6 +11,7 @@
from cudf.core.series import Series
from cudf.utils.dtypes import (
get_allowed_combinations_for_operator,
is_decimal_dtype,
to_cudf_compatible_scalar,
)

Expand Down Expand Up @@ -114,49 +116,45 @@ def _device_value_to_host(self):
self._host_value = self._device_value._to_host_scalar()

def _preprocess_host_value(self, value, dtype):
if isinstance(dtype, Decimal64Dtype):
# TODO: Support coercion from decimal.Decimal to different dtype
# TODO: Support coercion from integer to Decimal64Dtype
raise NotImplementedError(
"dtype as cudf.Decimal64Dtype is not supported. Pass a "
"decimal.Decimal to construct a DecimalScalar."
)
if isinstance(value, decimal.Decimal) and dtype is not None:
raise TypeError(f"Can not coerce decimal to {dtype}")

value = to_cudf_compatible_scalar(value, dtype=dtype)
valid = not _is_null_host_scalar(value)

if isinstance(value, decimal.Decimal):
# 0.0042 -> Decimal64Dtype(2, 4)
dtype = Decimal64Dtype._from_decimal(value)

if value is NA and dtype is not None:
return NA, dtype
elif type(value) is np.bool_:
value = bool(value)
elif isinstance(value, np.datetime64) and dtype is not None:
value = value.astype(dtype)
pa_scalar = pa.scalar(value)
valid = pa_scalar.is_valid

# decimal handling
if isinstance(value, decimal.Decimal) or is_decimal_dtype(dtype):
if dtype is None:
# value must be a decimal, derive the dtype
dtype = Decimal64Dtype._from_decimal(value)

# arrow coerces a decimal to a difference precision/scale
# or an int to a decimal with certain precision/scale
# and errors if the incoming object cannot be coerced
# it also cleanly handles None
value = pa.scalar(
value, type=pa.decimal128(dtype.precision, dtype.scale)
).as_py()
else:
value = to_cudf_compatible_scalar(value, dtype=dtype)
if dtype is None:
if not valid:
if isinstance(value, (np.datetime64, np.timedelta64)):
unit, _ = np.datetime_data(value)
if unit == "generic":
raise TypeError(
"Cant convert generic NaT to null scalar"
)
else:
dtype = value.dtype
else:
raise TypeError(
"dtype required when constructing a null scalar"
)
if not valid and not isinstance(
pa_scalar.type, (pa.TimestampType, pa.DurationType)
):
raise TypeError(
"dtype required when constructing a null scalar"
)
else:
dtype = value.dtype
dtype = np.dtype(dtype)

dtype = np.dtype(dtype)
# temporary
dtype = np.dtype("object") if dtype.char == "U" else dtype

if not valid:
value = NA

return value, dtype
return value if valid else NA, dtype

def _sync(self):
"""
Expand Down
86 changes: 86 additions & 0 deletions python/cudf/cudf/tests/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,89 @@ def test_typecast_from_decimal(data, from_dtype, to_dtype):
expected = cudf.Series(NumericalColumn.from_arrow(pa_arr))

assert_eq(got, expected)
assert_eq(got.dtype, expected.dtype)


def _decimal_series(input, dtype):
return cudf.Series(
[x if x is None else Decimal(x) for x in input], dtype=dtype,
)


@pytest.mark.parametrize(
"args",
[
# scatter to a single index
(
["1", "2", "3"],
Decimal64Dtype(1, 0),
Decimal(5),
1,
["1", "5", "3"],
),
(
["1.5", "2.5", "3.5"],
Decimal64Dtype(2, 1),
Decimal("5.5"),
1,
["1.5", "5.5", "3.5"],
),
(
["1.0042", "2.0042", "3.0042"],
Decimal64Dtype(5, 4),
Decimal("5.0042"),
1,
["1.0042", "5.0042", "3.0042"],
),
# scatter via boolmask
(
["1", "2", "3"],
Decimal64Dtype(1, 0),
Decimal(5),
cudf.Series([True, False, True]),
["5", "2", "5"],
),
(
["1.5", "2.5", "3.5"],
Decimal64Dtype(2, 1),
Decimal("5.5"),
cudf.Series([True, True, True]),
["5.5", "5.5", "5.5"],
),
(
["1.0042", "2.0042", "3.0042"],
Decimal64Dtype(5, 4),
Decimal("5.0042"),
cudf.Series([False, False, True]),
["1.0042", "2.0042", "5.0042"],
),
# We will allow assigning a decimal with less precision
(
["1.00", "2.00", "3.00"],
Decimal64Dtype(3, 2),
Decimal(5),
1,
["1.00", "5.00", "3.00"],
),
# But not truncation
(
["1", "2", "3"],
Decimal64Dtype(1, 0),
Decimal("5.5"),
1,
pa.lib.ArrowInvalid,
),
],
)
def test_series_setitem_decimal(args):
data, dtype, item, to, expect = args
data = _decimal_series(data, dtype)

if expect is pa.lib.ArrowInvalid:
with pytest.raises(expect):
data[to] = item
return
else:
expect = _decimal_series(expect, dtype)
data[to] = item
assert_eq(data, expect)
29 changes: 16 additions & 13 deletions python/cudf/cudf/tests/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

import cudf
Expand Down Expand Up @@ -194,14 +195,13 @@ def test_scalar_roundtrip(value):
+ TEST_DECIMAL_TYPES,
)
def test_null_scalar(dtype):
if isinstance(dtype, cudf.Decimal64Dtype):
with pytest.raises(NotImplementedError):
s = cudf.Scalar(None, dtype=dtype)
return

s = cudf.Scalar(None, dtype=dtype)
assert s.value is cudf.NA
assert s.dtype == np.dtype(dtype)
assert s.dtype == (
np.dtype(dtype)
if not isinstance(dtype, cudf.Decimal64Dtype)
else dtype
)
assert s.is_valid() is False


Expand Down Expand Up @@ -229,7 +229,11 @@ def test_nat_to_null_scalar_succeeds(value):
"value", [None, np.datetime64("NaT"), np.timedelta64("NaT")]
)
def test_generic_null_scalar_construction_fails(value):
with pytest.raises(TypeError):
if value is None:
error = TypeError
else:
error = pa.lib.ArrowNotImplementedError
with pytest.raises(error):
cudf.Scalar(value)


Expand All @@ -242,14 +246,13 @@ def test_generic_null_scalar_construction_fails(value):
+ TEST_DECIMAL_TYPES,
)
def test_scalar_dtype_and_validity(dtype):
if isinstance(dtype, cudf.Decimal64Dtype):
with pytest.raises(NotImplementedError):
s = cudf.Scalar(None, dtype=dtype)
return

s = cudf.Scalar(1, dtype=dtype)

assert s.dtype == np.dtype(dtype)
assert s.dtype == (
np.dtype(dtype)
if not isinstance(dtype, cudf.Decimal64Dtype)
else dtype
)
assert s.is_valid() is True


Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def get_appropriate_dispatched_func(
def _cast_to_appropriate_cudf_type(val, index=None):
# Handle scalar
if val.ndim == 0:
return cudf.Scalar(val).value
return to_cudf_compatible_scalar(val)
# 1D array
elif (val.ndim == 1) or (val.ndim == 2 and val.shape[1] == 1):
# if index is not None and is of a different length
Expand Down