diff --git a/python/cudf/cudf/__init__.py b/python/cudf/cudf/__init__.py index 4dadf6a1869..049cec77d9c 100644 --- a/python/cudf/cudf/__init__.py +++ b/python/cudf/cudf/__init__.py @@ -1,4 +1,5 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. + from cudf.utils.gpu_utils import validate_setup validate_setup() @@ -51,6 +52,7 @@ CategoricalDtype, Decimal64Dtype, Decimal32Dtype, + Decimal128Dtype, IntervalDtype, ListDtype, StructDtype, diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index 5e0ee3136b7..653e6b90b3f 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. import cupy as cp import numpy as np @@ -8,12 +8,7 @@ import rmm import cudf import cudf._lib as libcudfxx -from cudf.api.types import ( - is_categorical_dtype, - is_decimal_dtype, - is_list_dtype, - is_struct_dtype, -) +from cudf.api.types import is_categorical_dtype, is_list_dtype, is_struct_dtype from cudf.core.buffer import Buffer from cpython.buffer cimport PyObject_CheckBuffer diff --git a/python/cudf/cudf/_lib/cpp/scalar/scalar.pxd b/python/cudf/cudf/_lib/cpp/scalar/scalar.pxd index 930ebaa1bea..b5e9b0ba06b 100644 --- a/python/cudf/cudf/_lib/cpp/scalar/scalar.pxd +++ b/python/cudf/cudf/_lib/cpp/scalar/scalar.pxd @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. from libc.stdint cimport int32_t, int64_t from libcpp cimport bool @@ -59,6 +59,9 @@ cdef extern from "cudf/scalar/scalar.hpp" namespace "cudf" nogil: fixed_point_scalar(int64_t value, scale_type scale, bool is_valid) except + + fixed_point_scalar(data_type value, + scale_type scale, + bool is_valid) except + int64_t value() except + # TODO: Figure out how to add an int32 overload of value() diff --git a/python/cudf/cudf/_lib/cpp/types.pxd b/python/cudf/cudf/_lib/cpp/types.pxd index 1f2094b3958..23727a20ec2 100644 --- a/python/cudf/cudf/_lib/cpp/types.pxd +++ b/python/cudf/cudf/_lib/cpp/types.pxd @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. from libc.stdint cimport int32_t, uint32_t @@ -79,6 +79,7 @@ cdef extern from "cudf/types.hpp" namespace "cudf" nogil: DURATION_NANOSECONDS "cudf::type_id::DURATION_NANOSECONDS" DECIMAL32 "cudf::type_id::DECIMAL32" DECIMAL64 "cudf::type_id::DECIMAL64" + DECIMAL128 "cudf::type_id::DECIMAL128" ctypedef enum hash_id "cudf::hash_id": HASH_IDENTITY "cudf::hash_id::HASH_IDENTITY" @@ -102,3 +103,7 @@ cdef extern from "cudf/types.hpp" namespace "cudf" nogil: HIGHER "cudf::interpolation::HIGHER" MIDPOINT "cudf::interpolation::MIDPOINT" NEAREST "cudf::interpolation::NEAREST" + + # A Hack to let cython compile with __int128_t symbol + # https://stackoverflow.com/a/27609033 + ctypedef int int128 "__int128_t" diff --git a/python/cudf/cudf/_lib/cpp/wrappers/decimals.pxd b/python/cudf/cudf/_lib/cpp/wrappers/decimals.pxd index 628ffef433b..858569fd696 100644 --- a/python/cudf/cudf/_lib/cpp/wrappers/decimals.pxd +++ b/python/cudf/cudf/_lib/cpp/wrappers/decimals.pxd @@ -1,12 +1,17 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. + from libc.stdint cimport int32_t, int64_t +from cudf._lib.cpp.types cimport int128 + cdef extern from "cudf/fixed_point/fixed_point.hpp" namespace "numeric" nogil: # cython type stub to help resolve to numeric::decimal64 ctypedef int64_t decimal64 # cython type stub to help resolve to numeric::decimal32 ctypedef int64_t decimal32 + # cython type stub to help resolve to numeric::decimal128 + ctypedef int128 decimal128 cdef cppclass scale_type: scale_type(int32_t) diff --git a/python/cudf/cudf/_lib/orc.pyx b/python/cudf/cudf/_lib/orc.pyx index bf761c30bc8..cbba1796c26 100644 --- a/python/cudf/cudf/_lib/orc.pyx +++ b/python/cudf/cudf/_lib/orc.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. import cudf @@ -249,7 +249,6 @@ cdef orc_reader_options make_orc_reader_options( .timestamp_type(data_type(timestamp_type)) .use_index(use_index) .decimal_cols_as_float(c_decimal_cols_as_float) - .decimal128(False) .build() ) diff --git a/python/cudf/cudf/_lib/scalar.pyx b/python/cudf/cudf/_lib/scalar.pyx index 43c0198f80a..32d6cb2ea6d 100644 --- a/python/cudf/cudf/_lib/scalar.pyx +++ b/python/cudf/cudf/_lib/scalar.pyx @@ -1,4 +1,5 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. + import decimal import numpy as np @@ -45,7 +46,12 @@ from cudf._lib.cpp.scalar.scalar cimport ( struct_scalar, timestamp_scalar, ) -from cudf._lib.cpp.wrappers.decimals cimport decimal32, decimal64, scale_type +from cudf._lib.cpp.wrappers.decimals cimport ( + decimal32, + decimal64, + decimal128, + scale_type, +) from cudf._lib.cpp.wrappers.durations cimport ( duration_ms, duration_ns, @@ -88,7 +94,7 @@ cdef class DeviceScalar: # IMPORTANT: this should only ever be called from __init__ valid = not _is_null_host_scalar(value) - if isinstance(dtype, (cudf.Decimal64Dtype, cudf.Decimal32Dtype)): + if isinstance(dtype, cudf.core.dtypes.DecimalDtype): _set_decimal_from_scalar( self.c_value, value, dtype, valid) elif isinstance(dtype, cudf.ListDtype): @@ -118,7 +124,7 @@ cdef class DeviceScalar: ) def _to_host_scalar(self): - if isinstance(self.dtype, (cudf.Decimal64Dtype, cudf.Decimal32Dtype)): + if isinstance(self.dtype, cudf.core.dtypes.DecimalDtype): result = _get_py_decimal_from_fixed_point(self.c_value) elif cudf.api.types.is_struct_dtype(self.dtype): result = _get_py_dict_from_struct(self.c_value) @@ -181,6 +187,7 @@ cdef class DeviceScalar: s.c_value = move(ptr) cdtype = s.get_raw_ptr()[0].type() + if cdtype.id() == libcudf_types.DECIMAL64 and dtype is None: raise TypeError( "Must pass a dtype when constructing from a fixed-point scalar" @@ -322,6 +329,12 @@ cdef _set_decimal_from_scalar(unique_ptr[scalar]& s, np.int32(value), scale_type(-dtype.scale), valid ) ) + elif isinstance(dtype, cudf.Decimal128Dtype): + s.reset( + new fixed_point_scalar[decimal128]( + value, scale_type(-dtype.scale), valid + ) + ) else: raise ValueError(f"dtype not supported: {dtype}") @@ -463,6 +476,10 @@ cdef _get_py_decimal_from_fixed_point(unique_ptr[scalar]& s): rep_val = int((s_ptr)[0].value()) scale = int((s_ptr)[0].type().scale()) return decimal.Decimal(rep_val).scaleb(scale) + elif cdtype.id() == libcudf_types.DECIMAL128: + rep_val = int((s_ptr)[0].value()) + scale = int((s_ptr)[0].type().scale()) + return decimal.Decimal(rep_val).scaleb(scale) else: raise ValueError("Could not convert cudf::scalar to numpy scalar") diff --git a/python/cudf/cudf/_lib/strings/convert/convert_fixed_point.pyx b/python/cudf/cudf/_lib/strings/convert/convert_fixed_point.pyx index 54e85d8833f..dfc9cae915f 100644 --- a/python/cudf/cudf/_lib/strings/convert/convert_fixed_point.pyx +++ b/python/cudf/cudf/_lib/strings/convert/convert_fixed_point.pyx @@ -1,7 +1,9 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. import numpy as np +import cudf + from cudf._lib.column cimport Column from cudf._lib.types import SUPPORTED_NUMPY_TO_LIBCUDF_TYPES @@ -17,7 +19,13 @@ from cudf._lib.cpp.strings.convert.convert_fixed_point cimport ( is_fixed_point as cpp_is_fixed_point, to_fixed_point as cpp_to_fixed_point, ) -from cudf._lib.cpp.types cimport DECIMAL64, data_type, type_id +from cudf._lib.cpp.types cimport ( + DECIMAL32, + DECIMAL64, + DECIMAL128, + data_type, + type_id, +) from cudf._lib.types cimport underlying_type_t_type_id @@ -60,7 +68,15 @@ def to_decimal(Column input_col, object out_type): cdef column_view input_column_view = input_col.view() cdef unique_ptr[column] c_result cdef int scale = out_type.scale - cdef data_type c_out_type = data_type(DECIMAL64, -scale) + cdef data_type c_out_type + if isinstance(out_type, cudf.Decimal32Dtype): + c_out_type = data_type(DECIMAL32, -scale) + elif isinstance(out_type, cudf.Decimal64Dtype): + c_out_type = data_type(DECIMAL64, -scale) + elif isinstance(out_type, cudf.Decimal128Dtype): + c_out_type = data_type(DECIMAL128, -scale) + else: + raise TypeError("should be a decimal dtype") with nogil: c_result = move( cpp_to_fixed_point( diff --git a/python/cudf/cudf/_lib/types.pyx b/python/cudf/cudf/_lib/types.pyx index 1fa389f408c..0a05fd240f3 100644 --- a/python/cudf/cudf/_lib/types.pyx +++ b/python/cudf/cudf/_lib/types.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. from enum import IntEnum @@ -66,6 +66,7 @@ class TypeId(IntEnum): ) DECIMAL32 = libcudf_types.type_id.DECIMAL32 DECIMAL64 = libcudf_types.type_id.DECIMAL64 + DECIMAL128 = libcudf_types.type_id.DECIMAL128 SUPPORTED_NUMPY_TO_LIBCUDF_TYPES = { @@ -206,6 +207,11 @@ cdef dtype_from_column_view(column_view cv): precision=cudf.Decimal32Dtype.MAX_PRECISION, scale=-cv.type().scale() ) + elif tid == libcudf_types.type_id.DECIMAL128: + return cudf.Decimal128Dtype( + precision=cudf.Decimal128Dtype.MAX_PRECISION, + scale=-cv.type().scale() + ) else: return LIBCUDF_TO_SUPPORTED_NUMPY_TYPES[ (tid) @@ -216,6 +222,8 @@ cdef libcudf_types.data_type dtype_to_data_type(dtype) except *: tid = libcudf_types.type_id.LIST elif cudf.api.types.is_struct_dtype(dtype): tid = libcudf_types.type_id.STRUCT + elif cudf.api.types.is_decimal128_dtype(dtype): + tid = libcudf_types.type_id.DECIMAL128 elif cudf.api.types.is_decimal64_dtype(dtype): tid = libcudf_types.type_id.DECIMAL64 elif cudf.api.types.is_decimal32_dtype(dtype): @@ -232,6 +240,7 @@ cdef libcudf_types.data_type dtype_to_data_type(dtype) except *: cdef bool is_decimal_type_id(libcudf_types.type_id tid) except *: return tid in ( + libcudf_types.type_id.DECIMAL128, libcudf_types.type_id.DECIMAL64, - libcudf_types.type_id.DECIMAL32 + libcudf_types.type_id.DECIMAL32, ) diff --git a/python/cudf/cudf/api/types.py b/python/cudf/cudf/api/types.py index 10bbb620715..6d5387591cb 100644 --- a/python/cudf/cudf/api/types.py +++ b/python/cudf/cudf/api/types.py @@ -1,4 +1,5 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. + """Define common type operations.""" from __future__ import annotations @@ -20,6 +21,7 @@ is_categorical_dtype, is_decimal32_dtype, is_decimal64_dtype, + is_decimal128_dtype, is_decimal_dtype, is_interval_dtype, is_list_dtype, @@ -41,19 +43,23 @@ def is_numeric_dtype(obj): Whether or not the array or dtype is of a numeric dtype. """ if isclass(obj): - if issubclass(obj, (cudf.Decimal32Dtype, cudf.Decimal64Dtype)): + if issubclass(obj, cudf.core.dtypes.DecimalDtype): return True if issubclass(obj, _BaseDtype): return False else: - if isinstance(obj, cudf.Decimal32Dtype) or isinstance( - getattr(obj, "dtype", None), cudf.Decimal32Dtype + if isinstance(obj, cudf.Decimal128Dtype) or isinstance( + getattr(obj, "dtype", None), cudf.Decimal128Dtype ): return True if isinstance(obj, cudf.Decimal64Dtype) or isinstance( getattr(obj, "dtype", None), cudf.Decimal64Dtype ): return True + if isinstance(obj, cudf.Decimal32Dtype) or isinstance( + getattr(obj, "dtype", None), cudf.Decimal32Dtype + ): + return True if isinstance(obj, _BaseDtype) or isinstance( getattr(obj, "dtype", None), _BaseDtype ): diff --git a/python/cudf/cudf/core/column/__init__.py b/python/cudf/cudf/core/column/__init__.py index 5a44d7c58a6..96e2a7554cf 100644 --- a/python/cudf/cudf/core/column/__init__.py +++ b/python/cudf/cudf/core/column/__init__.py @@ -1,4 +1,5 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. + """ isort: skip_file """ @@ -31,5 +32,7 @@ from cudf.core.column.decimal import ( # noqa: F401 Decimal32Column, Decimal64Column, + Decimal128Column, + DecimalBaseColumn, ) from cudf.core.column.interval import IntervalColumn # noqa: F401 diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index a966276842f..667ce0488cd 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. from __future__ import annotations @@ -50,6 +50,7 @@ is_categorical_dtype, is_decimal32_dtype, is_decimal64_dtype, + is_decimal128_dtype, is_decimal_dtype, is_dtype_equal, is_integer_dtype, @@ -295,8 +296,6 @@ def from_arrow(cls, array: pa.Array) -> ColumnBase: array.type, pd.core.arrays._arrow_utils.ArrowIntervalType ): return cudf.core.column.IntervalColumn.from_arrow(array) - elif isinstance(array.type, pa.Decimal128Type): - return cudf.core.column.Decimal64Column.from_arrow(array) result = libcudf.interop.from_arrow(data, data.column_names)[0]["None"] @@ -987,6 +986,11 @@ def as_decimal_column( ) -> Union["cudf.core.column.decimal.DecimalBaseColumn"]: raise NotImplementedError + def as_decimal128_column( + self, dtype: Dtype, **kwargs + ) -> "cudf.core.column.Decimal128Column": + raise NotImplementedError + def as_decimal64_column( self, dtype: Dtype, **kwargs ) -> "cudf.core.column.Decimal64Column": @@ -1481,6 +1485,18 @@ def build_column( null_count=null_count, children=children, ) + elif is_decimal128_dtype(dtype): + if size is None: + raise TypeError("Must specify size") + return cudf.core.column.Decimal128Column( + data=data, + size=size, + offset=offset, + dtype=dtype, + mask=mask, + null_count=null_count, + children=children, + ) elif is_interval_dtype(dtype): return cudf.core.column.IntervalColumn( dtype=dtype, @@ -1838,7 +1854,7 @@ def as_column( else: pyarrow_array = pa.array(arbitrary, from_pandas=nan_as_null) if isinstance(pyarrow_array.type, pa.Decimal128Type): - pyarrow_type = cudf.Decimal64Dtype.from_arrow( + pyarrow_type = cudf.Decimal128Dtype.from_arrow( pyarrow_array.type ) else: @@ -2040,7 +2056,15 @@ def as_column( # https://github.com/apache/arrow/pull/9948 # Hence we should let the exception propagate to # the user. - if isinstance(dtype, cudf.core.dtypes.Decimal64Dtype): + if isinstance(dtype, cudf.core.dtypes.Decimal128Dtype): + data = pa.array( + arbitrary, + type=pa.decimal128( + precision=dtype.precision, scale=dtype.scale + ), + ) + return cudf.core.column.Decimal128Column.from_arrow(data) + elif isinstance(dtype, cudf.core.dtypes.Decimal64Dtype): data = pa.array( arbitrary, type=pa.decimal128( @@ -2048,7 +2072,7 @@ def as_column( ), ) return cudf.core.column.Decimal64Column.from_arrow(data) - if isinstance(dtype, cudf.core.dtypes.Decimal32Dtype): + elif isinstance(dtype, cudf.core.dtypes.Decimal32Dtype): data = pa.array( arbitrary, type=pa.decimal128( @@ -2056,6 +2080,7 @@ def as_column( ), ) return cudf.core.column.Decimal32Column.from_arrow(data) + pa_type = None np_type = None try: @@ -2074,7 +2099,17 @@ def as_column( ) and not isinstance(dtype, cudf.IntervalDtype): data = pa.array(arbitrary, type=dtype.to_arrow()) return as_column(data, nan_as_null=nan_as_null) - if isinstance(dtype, cudf.core.dtypes.Decimal64Dtype): + elif isinstance(dtype, cudf.core.dtypes.Decimal128Dtype): + data = pa.array( + arbitrary, + type=pa.decimal128( + precision=dtype.precision, scale=dtype.scale + ), + ) + return cudf.core.column.Decimal128Column.from_arrow( + data + ) + elif isinstance(dtype, cudf.core.dtypes.Decimal64Dtype): data = pa.array( arbitrary, type=pa.decimal128( @@ -2084,7 +2119,7 @@ def as_column( return cudf.core.column.Decimal64Column.from_arrow( data ) - if isinstance(dtype, cudf.core.dtypes.Decimal32Dtype): + elif isinstance(dtype, cudf.core.dtypes.Decimal32Dtype): data = pa.array( arbitrary, type=pa.decimal128( @@ -2094,6 +2129,7 @@ def as_column( return cudf.core.column.Decimal32Column.from_arrow( data ) + if is_bool_dtype(dtype): # Need this special case handling for bool dtypes, # since 'boolean' & 'pd.BooleanDtype' are not diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index 7037b8e6f36..a17cace3c81 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. from decimal import Decimal from typing import Any, Sequence, Tuple, Union, cast @@ -18,22 +18,27 @@ from cudf.api.types import is_integer_dtype, is_scalar from cudf.core.buffer import Buffer from cudf.core.column import ColumnBase, as_column -from cudf.core.dtypes import Decimal32Dtype, Decimal64Dtype +from cudf.core.dtypes import ( + Decimal32Dtype, + Decimal64Dtype, + Decimal128Dtype, + DecimalDtype, +) from cudf.utils.utils import pa_mask_buffer_to_mask from .numerical_base import NumericalBaseColumn class DecimalBaseColumn(NumericalBaseColumn): - """Base column for decimal64 and decimal32 columns""" + """Base column for decimal32, decimal64 or decimal128 columns""" - dtype: Union[Decimal32Dtype, Decimal64Dtype] + dtype: DecimalDtype def as_decimal_column( self, dtype: Dtype, **kwargs ) -> Union["DecimalBaseColumn"]: if ( - isinstance(dtype, (Decimal64Dtype, Decimal32Dtype)) + isinstance(dtype, cudf.core.dtypes.DecimalDtype) and dtype.scale < self.dtype.scale ): warn( @@ -45,6 +50,126 @@ def as_decimal_column( return self return libcudf.unary.cast(self, dtype) + def as_string_column( + self, dtype: Dtype, format=None, **kwargs + ) -> "cudf.core.column.StringColumn": + if len(self) > 0: + return cpp_from_decimal(self) + else: + return cast( + "cudf.core.column.StringColumn", as_column([], dtype="object") + ) + + def binary_operator(self, op, other, reflect=False): + if reflect: + self, other = other, self + + if not isinstance( + other, + ( + DecimalBaseColumn, + cudf.core.column.NumericalColumn, + cudf.Scalar, + ), + ): + raise TypeError( + f"Operator {op} not supported between" + f"{str(type(self))} and {str(type(other))}" + ) + elif isinstance( + other, cudf.core.column.NumericalColumn + ) and not is_integer_dtype(other.dtype): + raise TypeError( + f"Only decimal and integer column is supported for {op}." + ) + if isinstance(other, cudf.core.column.NumericalColumn): + other = other.as_decimal_column( + self.dtype.__class__(self.dtype.__class__.MAX_PRECISION, 0) + ) + if not isinstance(self.dtype, other.dtype.__class__): + if ( + self.dtype.precision == other.dtype.precision + and self.dtype.scale == other.dtype.scale + ): + other = other.astype(self.dtype) + + # Binary Arithmetics between decimal columns. `Scale` and `precision` + # are computed outside of libcudf + try: + if op in ("add", "sub", "mul", "div"): + output_type = _get_decimal_type(self.dtype, other.dtype, op) + result = libcudf.binaryop.binaryop( + self, other, op, output_type + ) + result.dtype.precision = output_type.precision + elif op in ("eq", "ne", "lt", "gt", "le", "ge"): + result = libcudf.binaryop.binaryop(self, other, op, bool) + except RuntimeError as e: + if "Unsupported operator for these types" in str(e): + raise NotImplementedError( + f"{op} not supported for types with different bit-widths" + ) from e + raise + + return result + + def fillna( + self, value: Any = None, method: str = None, dtype: Dtype = None + ): + """Fill null values with ``value``. + + Returns a copy with null filled. + """ + if isinstance(value, (int, Decimal)): + value = cudf.Scalar(value, dtype=self.dtype) + elif ( + isinstance(value, DecimalBaseColumn) + or isinstance(value, cudf.core.column.NumericalColumn) + and is_integer_dtype(value.dtype) + ): + value = value.astype(self.dtype) + else: + raise TypeError( + "Decimal columns only support using fillna with decimal and " + "integer values" + ) + + result = libcudf.replace.replace_nulls( + input_col=self, replacement=value, method=method, dtype=dtype + ) + return result._with_type_metadata(self.dtype) + + def normalize_binop_value(self, other): + if is_scalar(other) and isinstance(other, (int, np.int, Decimal)): + return cudf.Scalar(Decimal(other)) + elif isinstance(other, cudf.Scalar) and isinstance( + other.dtype, cudf.core.dtypes.DecimalDtype + ): + return other + else: + raise TypeError(f"cannot normalize {type(other)}") + + def _decimal_quantile( + self, q: Union[float, Sequence[float]], interpolation: str, exact: bool + ) -> ColumnBase: + quant = [float(q)] if not isinstance(q, (Sequence, np.ndarray)) else q + # get sorted indices and exclude nulls + sorted_indices = self.as_frame()._get_sorted_inds( + ascending=True, na_position="first" + ) + sorted_indices = sorted_indices[self.null_count :] + + result = cpp_quantile( + self, quant, interpolation, sorted_indices, exact + ) + + return result._with_type_metadata(self.dtype) + + def as_numerical_column( + self, dtype: Dtype, **kwargs + ) -> "cudf.core.column.NumericalColumn": + return libcudf.unary.cast(self, dtype) + class Decimal32Column(DecimalBaseColumn): dtype: Decimal32Dtype @@ -98,6 +223,35 @@ def to_arrow(self): buffers=[mask_buf, data_buf], ) + def _with_type_metadata( + self: "cudf.core.column.Decimal32Column", dtype: Dtype + ) -> "cudf.core.column.Decimal32Column": + if isinstance(dtype, Decimal32Dtype): + self.dtype.precision = dtype.precision + + return self + + +class Decimal128Column(DecimalBaseColumn): + dtype: Decimal128Dtype + + @classmethod + def from_arrow(cls, data: pa.Array): + result = cast(Decimal128Dtype, super().from_arrow(data)) + result.dtype.precision = data.type.precision + return result + + def to_arrow(self): + return super().to_arrow().cast(self.dtype.to_arrow()) + + def _with_type_metadata( + self: "cudf.core.column.Decimal128Column", dtype: Dtype + ) -> "cudf.core.column.Decimal128Column": + if isinstance(dtype, Decimal128Dtype): + self.dtype.precision = dtype.precision + + return self + class Decimal64Column(DecimalBaseColumn): dtype: Decimal64Dtype @@ -156,114 +310,6 @@ def to_arrow(self): buffers=[mask_buf, data_buf], ) - def binary_operator(self, op, other, reflect=False): - if reflect: - self, other = other, self - - # Binary Arithmetics between decimal columns. `Scale` and `precision` - # are computed outside of libcudf - if op in ("add", "sub", "mul", "div"): - scale = _binop_scale(self.dtype, other.dtype, op) - output_type = Decimal64Dtype( - scale=scale, precision=Decimal64Dtype.MAX_PRECISION - ) # precision will be ignored, libcudf has no notion of precision - result = libcudf.binaryop.binaryop(self, other, op, output_type) - result.dtype.precision = _binop_precision( - self.dtype, other.dtype, op - ) - elif op in ("eq", "ne", "lt", "gt", "le", "ge"): - if not isinstance( - other, - ( - Decimal64Column, - cudf.core.column.NumericalColumn, - cudf.Scalar, - ), - ): - raise TypeError( - f"Operator {op} not supported between" - f"{str(type(self))} and {str(type(other))}" - ) - if isinstance( - other, cudf.core.column.NumericalColumn - ) and not is_integer_dtype(other.dtype): - raise TypeError( - f"Only decimal and integer column is supported for {op}." - ) - if isinstance(other, cudf.core.column.NumericalColumn): - other = other.as_decimal_column( - Decimal64Dtype(Decimal64Dtype.MAX_PRECISION, 0) - ) - result = libcudf.binaryop.binaryop(self, other, op, bool) - return result - - def normalize_binop_value(self, other): - if is_scalar(other) and isinstance(other, (int, np.int, Decimal)): - return cudf.Scalar(Decimal(other)) - elif isinstance(other, cudf.Scalar) and isinstance( - other.dtype, cudf.Decimal64Dtype - ): - return other - else: - raise TypeError(f"cannot normalize {type(other)}") - - def _decimal_quantile( - self, q: Union[float, Sequence[float]], interpolation: str, exact: bool - ) -> ColumnBase: - quant = [float(q)] if not isinstance(q, (Sequence, np.ndarray)) else q - # get sorted indices and exclude nulls - sorted_indices = self.as_frame()._get_sorted_inds( - ascending=True, na_position="first" - ) - sorted_indices = sorted_indices[self.null_count :] - - result = cpp_quantile( - self, quant, interpolation, sorted_indices, exact - ) - - return result._with_type_metadata(self.dtype) - - def as_numerical_column( - self, dtype: Dtype, **kwargs - ) -> "cudf.core.column.NumericalColumn": - return libcudf.unary.cast(self, dtype) - - def as_string_column( - self, dtype: Dtype, format=None, **kwargs - ) -> "cudf.core.column.StringColumn": - if len(self) > 0: - return cpp_from_decimal(self) - else: - return cast( - "cudf.core.column.StringColumn", as_column([], dtype="object") - ) - - def fillna( - self, value: Any = None, method: str = None, dtype: Dtype = None - ): - """Fill null values with ``value``. - - Returns a copy with null filled. - """ - if isinstance(value, (int, Decimal)): - value = cudf.Scalar(value, dtype=self.dtype) - elif ( - isinstance(value, Decimal64Column) - or isinstance(value, cudf.core.column.NumericalColumn) - and is_integer_dtype(value.dtype) - ): - value = value.astype(self.dtype) - else: - raise TypeError( - "Decimal columns only support using fillna with decimal and " - "integer values" - ) - - result = libcudf.replace.replace_nulls( - input_col=self, replacement=value, method=method, dtype=dtype - ) - return result._with_type_metadata(self.dtype) - def serialize(self) -> Tuple[dict, list]: header, frames = super().serialize() header["dtype"] = self.dtype.serialize() @@ -291,34 +337,45 @@ def _with_type_metadata( return self -def _binop_scale(l_dtype, r_dtype, op): +def _get_decimal_type(lhs_dtype, rhs_dtype, op): + """ + Returns the resulting decimal type after calculating + precision & scale when performing the binary operation + `op` for the given dtypes. + + For precision & scale calculations see : https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql + """ # noqa: E501 + # This should at some point be hooked up to libcudf's # binary_operation_fixed_point_scale - s1, s2 = l_dtype.scale, r_dtype.scale + + p1, p2 = lhs_dtype.precision, rhs_dtype.precision + s1, s2 = lhs_dtype.scale, rhs_dtype.scale + if op in ("add", "sub"): - return max(s1, s2) + scale = max(s1, s2) + precision = scale + max(p1 - s1, p2 - s2) + 1 elif op == "mul": - return s1 + s2 + scale = s1 + s2 + precision = p1 + p2 + 1 elif op == "div": - return s1 - s2 + scale = max(6, s1 + p2 + 1) + precision = p1 - s1 + s2 + scale else: raise NotImplementedError() + for decimal_type in ( + cudf.Decimal32Dtype, + cudf.Decimal64Dtype, + cudf.Decimal128Dtype, + ): + try: + min_decimal_type = decimal_type(precision=precision, scale=scale) + except ValueError: + # Call to _validate fails, which means we need + # to try the next dtype + pass + else: + return min_decimal_type -def _binop_precision(l_dtype, r_dtype, op): - """ - Returns the result precision when performing the - binary operation `op` for the given dtypes. - - See: https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql - """ # noqa: E501 - p1, p2 = l_dtype.precision, r_dtype.precision - s1, s2 = l_dtype.scale, r_dtype.scale - if op in ("add", "sub"): - result = max(s1, s2) + max(p1 - s1, p2 - s2) + 1 - elif op in ("mul", "div"): - result = p1 + p2 + 1 - else: - raise NotImplementedError() - # TODO - return min(result, cudf.Decimal64Dtype.MAX_PRECISION) + raise OverflowError("Maximum supported decimal type is Decimal128") diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index 8f0a858ee34..2b0d7cfea38 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. from __future__ import annotations @@ -30,7 +30,12 @@ column, string, ) -from cudf.core.dtypes import CategoricalDtype, Decimal32Dtype, Decimal64Dtype +from cudf.core.dtypes import ( + CategoricalDtype, + Decimal32Dtype, + Decimal64Dtype, + Decimal128Dtype, +) from cudf.utils import cudautils, utils from cudf.utils.dtypes import ( NUMERIC_TYPES, @@ -166,16 +171,20 @@ def binary_operator( ( NumericalColumn, cudf.Scalar, - cudf.core.column.Decimal64Column, - cudf.core.column.Decimal32Column, + cudf.core.column.DecimalBaseColumn, ), ) or np.isscalar(rhs) ): msg = "{!r} operator not supported between {} and {}" raise TypeError(msg.format(binop, type(self), type(rhs))) - if isinstance(rhs, cudf.core.column.Decimal64Column): + if isinstance(rhs, cudf.core.column.Decimal128Column): lhs: Union[ScalarLike, ColumnBase] = self.as_decimal_column( + Decimal128Dtype(Decimal128Dtype.MAX_PRECISION, 0) + ) + return lhs.binary_operator(binop, rhs) + elif isinstance(rhs, cudf.core.column.Decimal64Column): + lhs = self.as_decimal_column( Decimal64Dtype(Decimal64Dtype.MAX_PRECISION, 0) ) return lhs.binary_operator(binop, rhs) @@ -291,7 +300,7 @@ def as_timedelta_column( def as_decimal_column( self, dtype: Dtype, **kwargs - ) -> "cudf.core.column.Decimal64Column": + ) -> "cudf.core.column.DecimalBaseColumn": return libcudf.unary.cast(self, dtype) def as_numerical_column(self, dtype: Dtype, **kwargs) -> NumericalColumn: diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index a83110d273c..9b44b4e6831 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2021, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. from __future__ import annotations @@ -5196,7 +5196,7 @@ def as_timedelta_column( def as_decimal_column( self, dtype: Dtype, **kwargs - ) -> "cudf.core.column.Decimal64Column": + ) -> "cudf.core.column.DecimalBaseColumn": return libstrings.to_decimal(self, dtype) def as_string_column( diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 69600426ec0..1dddcb9e3af 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. from __future__ import annotations, division @@ -1583,7 +1583,7 @@ def _concat( if isinstance( col, ( - cudf.core.column.Decimal64Column, + cudf.core.column.DecimalBaseColumn, cudf.core.column.StructColumn, ), ): diff --git a/python/cudf/cudf/core/dtypes.py b/python/cudf/cudf/core/dtypes.py index 5f21e883a4d..3a1c366b429 100644 --- a/python/cudf/cudf/core/dtypes.py +++ b/python/cudf/cudf/core/dtypes.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. import decimal import pickle @@ -355,14 +355,14 @@ def deserialize(cls, header: dict, frames: list): return cls(fields) -class Decimal32Dtype(_BaseDtype): +class DecimalDtype(_BaseDtype): """ Parameters ---------- precision : int The total number of digits in each value of this dtype scale : int, optional - The scale of the Decimal32Dtype. See Notes below. + The scale of the dtype. See Notes below. Notes ----- @@ -379,9 +379,7 @@ class Decimal32Dtype(_BaseDtype): and *not* representable with precision<6 or scale<4. """ - name = "decimal32" _metadata = ("precision", "scale") - MAX_PRECISION = np.floor(np.log10(np.iinfo("int32").max)) def __init__(self, precision, scale=0): self._validate(precision, scale) @@ -389,7 +387,7 @@ def __init__(self, precision, scale=0): @property def str(self): - return f"decimal32({self.precision}, {self.scale})" + return f"{str(self.name)}({self.precision}, {self.scale})" @property def precision(self): @@ -404,6 +402,10 @@ def precision(self, value): def scale(self): return self._typ.scale + @property + def itemsize(self): + return self.ITEMSIZE + @property def type(self): # might need to account for precision and scale here @@ -416,22 +418,15 @@ def to_arrow(self): def from_arrow(cls, typ): return cls(typ.precision, typ.scale) - @property - def itemsize(self): - return 4 - def __repr__(self): return ( f"{self.__class__.__name__}" f"(precision={self.precision}, scale={self.scale})" ) - def __hash__(self): - return hash(self._typ) - @classmethod def _validate(cls, precision, scale=0): - if precision > Decimal32Dtype.MAX_PRECISION: + if precision > cls.MAX_PRECISION: raise ValueError( f"Cannot construct a {cls.__name__}" f" with precision > {cls.MAX_PRECISION}" @@ -462,113 +457,33 @@ def serialize(self) -> Tuple[dict, list]: def deserialize(cls, header: dict, frames: list): return cls(header["precision"], header["scale"]) - -class Decimal64Dtype(_BaseDtype): - """ - Parameters - ---------- - precision : int - The total number of digits in each value of this dtype - scale : int, optional - The scale of the Decimal64Dtype. See Notes below. - - Notes - ----- - When the scale is positive: - - numbers with fractional parts (e.g., 0.0042) can be represented - - the scale is the total number of digits to the right of the - decimal point - When the scale is negative: - - only multiples of powers of 10 (including 10**0) can be - represented (e.g., 1729, 4200, 1000000) - - the scale represents the number of trailing zeros in the value. - For example, 42 is representable with precision=2 and scale=0. - 13.0051 is representable with precision=6 and scale=4, - and *not* representable with precision<6 or scale<4. - """ - - name = "decimal64" - _metadata = ("precision", "scale") - MAX_PRECISION = np.floor(np.log10(np.iinfo("int64").max)) - - def __init__(self, precision, scale=0): - self._validate(precision, scale) - self._typ = pa.decimal128(precision, scale) - - @property - def str(self): - return f"decimal64({self.precision}, {self.scale})" - - @property - def precision(self): - return self._typ.precision - - @precision.setter - def precision(self, value): - self._validate(value, self.scale) - self._typ = pa.decimal128(precision=value, scale=self.scale) - - @property - def scale(self): - return self._typ.scale - - @property - def type(self): - # might need to account for precision and scale here - return decimal.Decimal - - def to_arrow(self): - return self._typ - - @classmethod - def from_arrow(cls, typ): - return cls(typ.precision, typ.scale) - - @property - def itemsize(self): - return 8 - - def __repr__(self): - return ( - f"{self.__class__.__name__}" - f"(precision={self.precision}, scale={self.scale})" - ) + def __eq__(self, other: Dtype) -> bool: + if other is self: + return True + elif not isinstance(other, self.__class__): + return False + return self.precision == other.precision and self.scale == other.scale def __hash__(self): return hash(self._typ) - @classmethod - def _validate(cls, precision, scale=0): - if precision > Decimal64Dtype.MAX_PRECISION: - raise ValueError( - f"Cannot construct a {cls.__name__}" - f" with precision > {cls.MAX_PRECISION}" - ) - if abs(scale) > precision: - raise ValueError(f"scale={scale} exceeds precision={precision}") - @classmethod - def _from_decimal(cls, decimal): - """ - Create a cudf.Decimal64Dtype from a decimal.Decimal object - """ - metadata = decimal.as_tuple() - precision = max(len(metadata.digits), -metadata.exponent) - return cls(precision, -metadata.exponent) +class Decimal32Dtype(DecimalDtype): + name = "decimal32" + MAX_PRECISION = np.floor(np.log10(np.iinfo("int32").max)) + ITEMSIZE = 4 - def serialize(self) -> Tuple[dict, list]: - return ( - { - "type-serialized": pickle.dumps(type(self)), - "precision": self.precision, - "scale": self.scale, - }, - [], - ) - @classmethod - def deserialize(cls, header: dict, frames: list): - return cls(header["precision"], header["scale"]) +class Decimal64Dtype(DecimalDtype): + name = "decimal64" + MAX_PRECISION = np.floor(np.log10(np.iinfo("int64").max)) + ITEMSIZE = 8 + + +class Decimal128Dtype(DecimalDtype): + name = "decimal128" + MAX_PRECISION = 38 + ITEMSIZE = 16 class IntervalDtype(StructDtype): @@ -740,7 +655,11 @@ def is_decimal_dtype(obj): bool Whether or not the array-like or dtype is of the decimal dtype. """ - return is_decimal32_dtype(obj) or is_decimal64_dtype(obj) + return ( + is_decimal32_dtype(obj) + or is_decimal64_dtype(obj) + or is_decimal128_dtype(obj) + ) def is_interval_dtype(obj): @@ -791,3 +710,15 @@ def is_decimal64_dtype(obj): ) or (hasattr(obj, "dtype") and is_decimal64_dtype(obj.dtype)) ) + + +def is_decimal128_dtype(obj): + return ( + type(obj) is cudf.core.dtypes.Decimal128Dtype + or obj is cudf.core.dtypes.Decimal128Dtype + or ( + isinstance(obj, str) + and obj == cudf.core.dtypes.Decimal128Dtype.name + ) + or (hasattr(obj, "dtype") and is_decimal128_dtype(obj.dtype)) + ) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 5b041ba53b9..6da98bf980d 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. import collections import itertools @@ -1003,7 +1003,9 @@ def diff(self, periods=1, axis=0): cudf.core.frame.Frame(value_columns._data) ) grouped = self.obj.__class__._from_data(data, index) - grouped = self._mimic_pandas_order(grouped) + grouped = self._mimic_pandas_order(grouped)._copy_type_metadata( + value_columns + ) result = grouped - self.shift(periods=periods) return result._copy_type_metadata(value_columns) diff --git a/python/cudf/cudf/core/join/_join_helpers.py b/python/cudf/cudf/core/join/_join_helpers.py index 6dec0b10273..ead0cd566d9 100644 --- a/python/cudf/cudf/core/join/_join_helpers.py +++ b/python/cudf/cudf/core/join/_join_helpers.py @@ -1,4 +1,5 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. + from __future__ import annotations import collections @@ -8,7 +9,7 @@ import numpy as np import cudf -from cudf.api.types import is_dtype_equal +from cudf.api.types import is_decimal_dtype, is_dtype_equal from cudf.core.column import CategoricalColumn from cudf.core.dtypes import CategoricalDtype @@ -85,9 +86,7 @@ def _match_join_keys( if is_dtype_equal(ltype, rtype): return lcol, rcol - if isinstance(ltype, cudf.Decimal64Dtype) or isinstance( - rtype, cudf.Decimal64Dtype - ): + if is_decimal_dtype(ltype) or is_decimal_dtype(rtype): raise TypeError( "Decimal columns can only be merged with decimal columns " "of the same precision and scale" diff --git a/python/cudf/cudf/core/scalar.py b/python/cudf/cudf/core/scalar.py index 37bb8e32c5a..b0770b71ca6 100644 --- a/python/cudf/cudf/core/scalar.py +++ b/python/cudf/cudf/core/scalar.py @@ -1,4 +1,5 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. + import decimal import numpy as np @@ -145,12 +146,12 @@ def _preprocess_host_value(self, value, dtype): else: return NA, dtype - if isinstance(dtype, (cudf.Decimal64Dtype, cudf.Decimal32Dtype)): + if isinstance(dtype, cudf.core.dtypes.DecimalDtype): value = pa.scalar( value, type=pa.decimal128(dtype.precision, dtype.scale) ).as_py() if isinstance(value, decimal.Decimal) and dtype is None: - dtype = cudf.Decimal64Dtype._from_decimal(value) + dtype = cudf.Decimal128Dtype._from_decimal(value) value = to_cudf_compatible_scalar(value, dtype=dtype) @@ -171,7 +172,7 @@ def _preprocess_host_value(self, value, dtype): else: dtype = value.dtype - if not isinstance(dtype, (cudf.Decimal64Dtype, cudf.Decimal32Dtype)): + if not isinstance(dtype, cudf.core.dtypes.DecimalDtype): dtype = cudf.dtype(dtype) if not valid: diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index 6842a05a505..2ecee781eb1 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. from __future__ import annotations @@ -133,11 +133,7 @@ def __setitem__(self, key, value): if ( not isinstance( self._frame._column.dtype, - ( - cudf.Decimal64Dtype, - cudf.Decimal32Dtype, - cudf.CategoricalDtype, - ), + (cudf.core.dtypes.DecimalDtype, cudf.CategoricalDtype), ) and hasattr(value, "dtype") and _is_non_decimal_numeric_dtype(value.dtype) @@ -1466,7 +1462,10 @@ def _concat(cls, objs, axis=0, index=True): # Reassign precision for decimal cols & type schema for struct cols if isinstance( col, - (cudf.core.column.Decimal64Column, cudf.core.column.StructColumn), + ( + cudf.core.column.DecimalBaseColumn, + cudf.core.column.StructColumn, + ), ): col = col._with_type_metadata(objs[0].dtype) diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index 9694d19e159..3e73e0c9e3d 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2020, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. import io import json @@ -10,7 +10,6 @@ import fsspec import numpy as np -import pyarrow as pa from pyarrow import dataset as ds, parquet as pq import cudf @@ -627,34 +626,6 @@ def _read_parquet( # Simple helper function to dispatch between # cudf and pyarrow to read parquet data if engine == "cudf": - # Temporary error to probe a parquet file - # and raise decimal128 support error. - if len(filepaths_or_buffers) > 0: - try: - metadata = pq.read_metadata(filepaths_or_buffers[0]) - except TypeError: - # pq.read_metadata only supports reading metadata from - # certain types of file inputs, like str-filepath or file-like - # objects, and errors for the rest of inputs. Hence this is - # to avoid failing on other types of file inputs. - pass - else: - arrow_schema = metadata.schema.to_arrow_schema() - check_cols = arrow_schema.names if columns is None else columns - for col_name, arrow_type in zip( - arrow_schema.names, arrow_schema.types - ): - if col_name not in check_cols: - continue - if isinstance(arrow_type, pa.ListType): - val_field_types = arrow_type.value_field.flatten() - for val_field_type in val_field_types: - _check_decimal128_type(val_field_type.type) - elif isinstance(arrow_type, pa.StructType): - _ = cudf.StructDtype.from_arrow(arrow_type) - else: - _check_decimal128_type(arrow_type) - return libparquet.read_parquet( filepaths_or_buffers, columns=columns, @@ -982,11 +953,3 @@ def __enter__(self): def __exit__(self, *args): self.close() - - -def _check_decimal128_type(arrow_type): - if isinstance(arrow_type, pa.Decimal128Type): - if arrow_type.precision > cudf.Decimal64Dtype.MAX_PRECISION: - raise NotImplementedError( - "Decimal type greater than Decimal64 is not yet supported" - ) diff --git a/python/cudf/cudf/testing/dataset_generator.py b/python/cudf/cudf/testing/dataset_generator.py index f4a80c60ddf..13be158ed78 100644 --- a/python/cudf/cudf/testing/dataset_generator.py +++ b/python/cudf/cudf/testing/dataset_generator.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # This module is for generating "synthetic" datasets. It was originally # designed for testing filtered reading. Generally, it should be useful @@ -384,6 +384,22 @@ def rand_dataframe( dtype=dtype, ) ) + elif dtype == "decimal128": + max_precision = meta.get( + "max_precision", cudf.Decimal128Dtype.MAX_PRECISION + ) + precision = np.random.randint(1, max_precision) + scale = np.random.randint(0, precision) + dtype = cudf.Decimal128Dtype(precision=precision, scale=scale) + column_params.append( + ColumnParameters( + cardinality=cardinality, + null_frequency=null_frequency, + generator=decimal_generator(dtype=dtype, size=cardinality), + is_sorted=False, + dtype=dtype, + ) + ) elif dtype == "category": column_params.append( ColumnParameters( diff --git a/python/cudf/cudf/tests/test_api_types.py b/python/cudf/cudf/tests/test_api_types.py index 98249e761c1..4d104c122d1 100644 --- a/python/cudf/cudf/tests/test_api_types.py +++ b/python/cudf/cudf/tests/test_api_types.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. import numpy as np import pandas as pd @@ -84,13 +84,17 @@ (cudf.CategoricalDtype, True), (cudf.ListDtype, False), (cudf.StructDtype, False), + (cudf.Decimal128Dtype, False), (cudf.Decimal64Dtype, False), + (cudf.Decimal32Dtype, False), (cudf.IntervalDtype, False), # cuDF dtype instances. (cudf.CategoricalDtype("a"), True), (cudf.ListDtype(int), False), (cudf.StructDtype({"a": int}), False), + (cudf.Decimal128Dtype(5, 2), False), (cudf.Decimal64Dtype(5, 2), False), + (cudf.Decimal32Dtype(5, 2), False), (cudf.IntervalDtype(int), False), # cuDF objects (cudf.Series(dtype="bool"), False), @@ -100,7 +104,9 @@ (cudf.Series(dtype="datetime64[s]"), False), (cudf.Series(dtype="timedelta64[s]"), False), (cudf.Series(dtype="category"), True), + (cudf.Series(dtype=cudf.Decimal128Dtype(5, 2)), False), (cudf.Series(dtype=cudf.Decimal64Dtype(5, 2)), False), + (cudf.Series(dtype=cudf.Decimal32Dtype(5, 2)), False), # TODO: Currently creating an empty Series of list type ignores the # provided type and instead makes a float64 Series. (cudf.Series([[1, 2], [3, 4, 5]]), False), @@ -189,13 +195,17 @@ def test_is_categorical_dtype(obj, expect): (cudf.CategoricalDtype, False), (cudf.ListDtype, False), (cudf.StructDtype, False), + (cudf.Decimal128Dtype, True), (cudf.Decimal64Dtype, True), + (cudf.Decimal32Dtype, True), (cudf.IntervalDtype, False), # cuDF dtype instances. (cudf.CategoricalDtype("a"), False), (cudf.ListDtype(int), False), (cudf.StructDtype({"a": int}), False), + (cudf.Decimal128Dtype(5, 2), True), (cudf.Decimal64Dtype(5, 2), True), + (cudf.Decimal32Dtype(5, 2), True), (cudf.IntervalDtype(int), False), # cuDF objects (cudf.Series(dtype="bool"), True), @@ -205,7 +215,9 @@ def test_is_categorical_dtype(obj, expect): (cudf.Series(dtype="datetime64[s]"), False), (cudf.Series(dtype="timedelta64[s]"), False), (cudf.Series(dtype="category"), False), + (cudf.Series(dtype=cudf.Decimal128Dtype(5, 2)), True), (cudf.Series(dtype=cudf.Decimal64Dtype(5, 2)), True), + (cudf.Series(dtype=cudf.Decimal32Dtype(5, 2)), True), (cudf.Series([[1, 2], [3, 4, 5]]), False), (cudf.Series([{"a": 1, "b": 2}, {"c": 3}]), False), (cudf.Series(dtype=cudf.IntervalDtype(int)), False), @@ -290,13 +302,17 @@ def test_is_numeric_dtype(obj, expect): (cudf.CategoricalDtype, False), (cudf.ListDtype, False), (cudf.StructDtype, False), + (cudf.Decimal128Dtype, False), (cudf.Decimal64Dtype, False), + (cudf.Decimal32Dtype, False), (cudf.IntervalDtype, False), # cuDF dtype instances. (cudf.CategoricalDtype("a"), False), (cudf.ListDtype(int), False), (cudf.StructDtype({"a": int}), False), + (cudf.Decimal128Dtype(5, 2), False), (cudf.Decimal64Dtype(5, 2), False), + (cudf.Decimal32Dtype(5, 2), False), (cudf.IntervalDtype(int), False), # cuDF objects (cudf.Series(dtype="bool"), False), @@ -306,7 +322,9 @@ def test_is_numeric_dtype(obj, expect): (cudf.Series(dtype="datetime64[s]"), False), (cudf.Series(dtype="timedelta64[s]"), False), (cudf.Series(dtype="category"), False), + (cudf.Series(dtype=cudf.Decimal128Dtype(5, 2)), False), (cudf.Series(dtype=cudf.Decimal64Dtype(5, 2)), False), + (cudf.Series(dtype=cudf.Decimal32Dtype(5, 2)), False), (cudf.Series([[1, 2], [3, 4, 5]]), False), (cudf.Series([{"a": 1, "b": 2}, {"c": 3}]), False), (cudf.Series(dtype=cudf.IntervalDtype(int)), False), @@ -391,13 +409,17 @@ def test_is_integer_dtype(obj, expect): (cudf.CategoricalDtype, False), (cudf.ListDtype, False), (cudf.StructDtype, False), + (cudf.Decimal128Dtype, False), (cudf.Decimal64Dtype, False), + (cudf.Decimal32Dtype, False), (cudf.IntervalDtype, False), # cuDF dtype instances. (cudf.CategoricalDtype("a"), False), (cudf.ListDtype(int), False), (cudf.StructDtype({"a": int}), False), + (cudf.Decimal128Dtype(5, 2), False), (cudf.Decimal64Dtype(5, 2), False), + (cudf.Decimal32Dtype(5, 2), False), (cudf.IntervalDtype(int), False), # cuDF objects (cudf.Series(dtype="bool"), False), @@ -407,7 +429,9 @@ def test_is_integer_dtype(obj, expect): (cudf.Series(dtype="datetime64[s]"), False), (cudf.Series(dtype="timedelta64[s]"), False), (cudf.Series(dtype="category"), False), + (cudf.Series(dtype=cudf.Decimal128Dtype(5, 2)), False), (cudf.Series(dtype=cudf.Decimal64Dtype(5, 2)), False), + (cudf.Series(dtype=cudf.Decimal32Dtype(5, 2)), False), (cudf.Series([[1, 2], [3, 4, 5]]), False), (cudf.Series([{"a": 1, "b": 2}, {"c": 3}]), False), (cudf.Series(dtype=cudf.IntervalDtype(int)), False), @@ -493,13 +517,17 @@ def test_is_integer(obj, expect): (cudf.CategoricalDtype, False), (cudf.ListDtype, False), (cudf.StructDtype, False), + (cudf.Decimal128Dtype, False), (cudf.Decimal64Dtype, False), + (cudf.Decimal32Dtype, False), (cudf.IntervalDtype, False), # cuDF dtype instances. (cudf.CategoricalDtype("a"), False), (cudf.ListDtype(int), False), (cudf.StructDtype({"a": int}), False), + (cudf.Decimal128Dtype(5, 2), False), (cudf.Decimal64Dtype(5, 2), False), + (cudf.Decimal32Dtype(5, 2), False), (cudf.IntervalDtype(int), False), # cuDF objects (cudf.Series(dtype="bool"), False), @@ -509,7 +537,9 @@ def test_is_integer(obj, expect): (cudf.Series(dtype="datetime64[s]"), False), (cudf.Series(dtype="timedelta64[s]"), False), (cudf.Series(dtype="category"), False), + (cudf.Series(dtype=cudf.Decimal128Dtype(5, 2)), False), (cudf.Series(dtype=cudf.Decimal64Dtype(5, 2)), False), + (cudf.Series(dtype=cudf.Decimal32Dtype(5, 2)), False), (cudf.Series([[1, 2], [3, 4, 5]]), False), (cudf.Series([{"a": 1, "b": 2}, {"c": 3}]), False), (cudf.Series(dtype=cudf.IntervalDtype(int)), False), @@ -594,13 +624,17 @@ def test_is_string_dtype(obj, expect): (cudf.CategoricalDtype, False), (cudf.ListDtype, False), (cudf.StructDtype, False), + (cudf.Decimal128Dtype, False), (cudf.Decimal64Dtype, False), + (cudf.Decimal32Dtype, False), (cudf.IntervalDtype, False), # cuDF dtype instances. (cudf.CategoricalDtype("a"), False), (cudf.ListDtype(int), False), (cudf.StructDtype({"a": int}), False), + (cudf.Decimal128Dtype(5, 2), False), (cudf.Decimal64Dtype(5, 2), False), + (cudf.Decimal32Dtype(5, 2), False), (cudf.IntervalDtype(int), False), # cuDF objects (cudf.Series(dtype="bool"), False), @@ -610,7 +644,9 @@ def test_is_string_dtype(obj, expect): (cudf.Series(dtype="datetime64[s]"), True), (cudf.Series(dtype="timedelta64[s]"), False), (cudf.Series(dtype="category"), False), + (cudf.Series(dtype=cudf.Decimal128Dtype(5, 2)), False), (cudf.Series(dtype=cudf.Decimal64Dtype(5, 2)), False), + (cudf.Series(dtype=cudf.Decimal32Dtype(5, 2)), False), (cudf.Series([[1, 2], [3, 4, 5]]), False), (cudf.Series([{"a": 1, "b": 2}, {"c": 3}]), False), (cudf.Series(dtype=cudf.IntervalDtype(int)), False), @@ -695,13 +731,17 @@ def test_is_datetime_dtype(obj, expect): (cudf.CategoricalDtype, False), (cudf.ListDtype, True), (cudf.StructDtype, False), + (cudf.Decimal128Dtype, False), (cudf.Decimal64Dtype, False), + (cudf.Decimal32Dtype, False), (cudf.IntervalDtype, False), # cuDF dtype instances. (cudf.CategoricalDtype("a"), False), (cudf.ListDtype(int), True), (cudf.StructDtype({"a": int}), False), + (cudf.Decimal128Dtype(5, 2), False), (cudf.Decimal64Dtype(5, 2), False), + (cudf.Decimal32Dtype(5, 2), False), (cudf.IntervalDtype(int), False), # cuDF objects (cudf.Series(dtype="bool"), False), @@ -711,7 +751,9 @@ def test_is_datetime_dtype(obj, expect): (cudf.Series(dtype="datetime64[s]"), False), (cudf.Series(dtype="timedelta64[s]"), False), (cudf.Series(dtype="category"), False), + (cudf.Series(dtype=cudf.Decimal128Dtype(5, 2)), False), (cudf.Series(dtype=cudf.Decimal64Dtype(5, 2)), False), + (cudf.Series(dtype=cudf.Decimal32Dtype(5, 2)), False), (cudf.Series([[1, 2], [3, 4, 5]]), True), (cudf.Series([{"a": 1, "b": 2}, {"c": 3}]), False), (cudf.Series(dtype=cudf.IntervalDtype(int)), False), @@ -796,13 +838,17 @@ def test_is_list_dtype(obj, expect): (cudf.CategoricalDtype, False), (cudf.ListDtype, False), (cudf.StructDtype, True), + (cudf.Decimal128Dtype, False), (cudf.Decimal64Dtype, False), + (cudf.Decimal32Dtype, False), # (cudf.IntervalDtype, False), # cuDF dtype instances. (cudf.CategoricalDtype("a"), False), (cudf.ListDtype(int), False), (cudf.StructDtype({"a": int}), True), + (cudf.Decimal128Dtype(5, 2), False), (cudf.Decimal64Dtype(5, 2), False), + (cudf.Decimal32Dtype(5, 2), False), # (cudf.IntervalDtype(int), False), # cuDF objects (cudf.Series(dtype="bool"), False), @@ -812,7 +858,9 @@ def test_is_list_dtype(obj, expect): (cudf.Series(dtype="datetime64[s]"), False), (cudf.Series(dtype="timedelta64[s]"), False), (cudf.Series(dtype="category"), False), + (cudf.Series(dtype=cudf.Decimal128Dtype(5, 2)), False), (cudf.Series(dtype=cudf.Decimal64Dtype(5, 2)), False), + (cudf.Series(dtype=cudf.Decimal32Dtype(5, 2)), False), (cudf.Series([[1, 2], [3, 4, 5]]), False), (cudf.Series([{"a": 1, "b": 2}, {"c": 3}]), True), # (cudf.Series(dtype=cudf.IntervalDtype(int)), False), @@ -900,13 +948,17 @@ def test_is_struct_dtype(obj, expect): (cudf.CategoricalDtype, False), (cudf.ListDtype, False), (cudf.StructDtype, False), + (cudf.Decimal128Dtype, True), (cudf.Decimal64Dtype, True), + (cudf.Decimal32Dtype, True), (cudf.IntervalDtype, False), # cuDF dtype instances. (cudf.CategoricalDtype("a"), False), (cudf.ListDtype(int), False), (cudf.StructDtype({"a": int}), False), + (cudf.Decimal128Dtype(5, 2), True), (cudf.Decimal64Dtype(5, 2), True), + (cudf.Decimal32Dtype(5, 2), True), (cudf.IntervalDtype(int), False), # cuDF objects (cudf.Series(dtype="bool"), False), @@ -916,7 +968,9 @@ def test_is_struct_dtype(obj, expect): (cudf.Series(dtype="datetime64[s]"), False), (cudf.Series(dtype="timedelta64[s]"), False), (cudf.Series(dtype="category"), False), + (cudf.Series(dtype=cudf.Decimal128Dtype(5, 2)), True), (cudf.Series(dtype=cudf.Decimal64Dtype(5, 2)), True), + (cudf.Series(dtype=cudf.Decimal32Dtype(5, 2)), True), (cudf.Series([[1, 2], [3, 4, 5]]), False), (cudf.Series([{"a": 1, "b": 2}, {"c": 3}]), False), (cudf.Series(dtype=cudf.IntervalDtype(int)), False), diff --git a/python/cudf/cudf/tests/test_binops.py b/python/cudf/cudf/tests/test_binops.py index ba2a6dce369..921f2de38c2 100644 --- a/python/cudf/cudf/tests/test_binops.py +++ b/python/cudf/cudf/tests/test_binops.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. from __future__ import division @@ -1802,7 +1802,7 @@ def test_binops_with_NA_consistent(dtype, op): ["1.5", "2.0"], cudf.Decimal64Dtype(scale=2, precision=3), ["3.0", "4.0"], - cudf.Decimal64Dtype(scale=2, precision=4), + cudf.Decimal32Dtype(scale=2, precision=4), ), ( operator.add, @@ -1811,7 +1811,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["3.75", "3.005"], - cudf.Decimal64Dtype(scale=3, precision=5), + cudf.Decimal32Dtype(scale=3, precision=5), ), ( operator.add, @@ -1820,7 +1820,7 @@ def test_binops_with_NA_consistent(dtype, op): ["0.1", "0.2"], cudf.Decimal64Dtype(scale=3, precision=4), ["100.1", "200.2"], - cudf.Decimal64Dtype(scale=3, precision=18), + cudf.Decimal128Dtype(scale=3, precision=23), ), ( operator.sub, @@ -1829,7 +1829,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", "0.995"], - cudf.Decimal64Dtype(scale=3, precision=5), + cudf.Decimal32Dtype(scale=3, precision=5), ), ( operator.sub, @@ -1838,7 +1838,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", "0.995"], - cudf.Decimal64Dtype(scale=3, precision=5), + cudf.Decimal32Dtype(scale=3, precision=5), ), ( operator.sub, @@ -1847,7 +1847,7 @@ def test_binops_with_NA_consistent(dtype, op): ["0.1", "0.2"], cudf.Decimal64Dtype(scale=6, precision=10), ["99.9", "199.8"], - cudf.Decimal64Dtype(scale=6, precision=18), + cudf.Decimal128Dtype(scale=6, precision=19), ), ( operator.mul, @@ -1856,7 +1856,7 @@ def test_binops_with_NA_consistent(dtype, op): ["1.5", "3.0"], cudf.Decimal64Dtype(scale=3, precision=4), ["2.25", "6.0"], - cudf.Decimal64Dtype(scale=5, precision=8), + cudf.Decimal32Dtype(scale=5, precision=8), ), ( operator.mul, @@ -1865,7 +1865,7 @@ def test_binops_with_NA_consistent(dtype, op): ["0.1", "0.2"], cudf.Decimal64Dtype(scale=3, precision=4), ["10.0", "40.0"], - cudf.Decimal64Dtype(scale=1, precision=8), + cudf.Decimal32Dtype(scale=1, precision=8), ), ( operator.mul, @@ -1874,7 +1874,7 @@ def test_binops_with_NA_consistent(dtype, op): ["0.343", "0.500"], cudf.Decimal64Dtype(scale=3, precision=3), ["343.0", "1000.0"], - cudf.Decimal64Dtype(scale=0, precision=8), + cudf.Decimal32Dtype(scale=0, precision=8), ), ( operator.truediv, @@ -1883,7 +1883,7 @@ def test_binops_with_NA_consistent(dtype, op): ["1.5", "3.0"], cudf.Decimal64Dtype(scale=1, precision=4), ["1.0", "0.6"], - cudf.Decimal64Dtype(scale=1, precision=9), + cudf.Decimal64Dtype(scale=7, precision=10), ), ( operator.truediv, @@ -1892,7 +1892,7 @@ def test_binops_with_NA_consistent(dtype, op): ["0.1", "0.2"], cudf.Decimal64Dtype(scale=2, precision=4), ["1000.0", "1000.0"], - cudf.Decimal64Dtype(scale=-3, precision=8), + cudf.Decimal64Dtype(scale=6, precision=12), ), ( operator.truediv, @@ -1901,7 +1901,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.34", "8.50"], cudf.Decimal64Dtype(scale=2, precision=8), ["56.77", "1.79"], - cudf.Decimal64Dtype(scale=2, precision=18), + cudf.Decimal128Dtype(scale=13, precision=25), ), ( operator.add, @@ -1910,7 +1910,7 @@ def test_binops_with_NA_consistent(dtype, op): ["1.5", None, "2.0"], cudf.Decimal64Dtype(scale=1, precision=2), ["3.0", None, "4.0"], - cudf.Decimal64Dtype(scale=1, precision=3), + cudf.Decimal32Dtype(scale=1, precision=3), ), ( operator.add, @@ -1919,7 +1919,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["3.75", None], - cudf.Decimal64Dtype(scale=3, precision=5), + cudf.Decimal32Dtype(scale=3, precision=5), ), ( operator.sub, @@ -1928,7 +1928,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", None], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", None], - cudf.Decimal64Dtype(scale=3, precision=5), + cudf.Decimal32Dtype(scale=3, precision=5), ), ( operator.sub, @@ -1937,7 +1937,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", None], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", None], - cudf.Decimal64Dtype(scale=3, precision=5), + cudf.Decimal32Dtype(scale=3, precision=5), ), ( operator.mul, @@ -1946,7 +1946,7 @@ def test_binops_with_NA_consistent(dtype, op): ["1.5", None], cudf.Decimal64Dtype(scale=3, precision=4), ["2.25", None], - cudf.Decimal64Dtype(scale=5, precision=8), + cudf.Decimal32Dtype(scale=5, precision=8), ), ( operator.mul, @@ -1955,7 +1955,7 @@ def test_binops_with_NA_consistent(dtype, op): ["0.1", None], cudf.Decimal64Dtype(scale=3, precision=12), ["10.0", None], - cudf.Decimal64Dtype(scale=1, precision=18), + cudf.Decimal128Dtype(scale=1, precision=23), ), ( operator.eq, @@ -2128,7 +2128,10 @@ def test_binops_decimal(args): b = utils._decimal_series(rhs, r_dtype) expect = ( utils._decimal_series(expect, expect_dtype) - if isinstance(expect_dtype, cudf.Decimal64Dtype) + if isinstance( + expect_dtype, + (cudf.Decimal64Dtype, cudf.Decimal32Dtype, cudf.Decimal128Dtype), + ) else cudf.Series(expect, dtype=expect_dtype) ) @@ -2322,7 +2325,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), decimal.Decimal(1), ["101", "201"], - cudf.Decimal64Dtype(scale=0, precision=6), + cudf.Decimal32Dtype(scale=0, precision=6), False, ), ( @@ -2331,7 +2334,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), 1, ["101", "201"], - cudf.Decimal64Dtype(scale=0, precision=6), + cudf.Decimal32Dtype(scale=0, precision=6), False, ), ( @@ -2340,7 +2343,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), decimal.Decimal("1.5"), ["101.5", "201.5"], - cudf.Decimal64Dtype(scale=1, precision=7), + cudf.Decimal32Dtype(scale=1, precision=7), False, ), ( @@ -2349,7 +2352,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), cudf.Scalar(decimal.Decimal("1.5")), ["101.5", "201.5"], - cudf.Decimal64Dtype(scale=1, precision=7), + cudf.Decimal32Dtype(scale=1, precision=7), False, ), ( @@ -2358,7 +2361,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), decimal.Decimal(1), ["101", "201"], - cudf.Decimal64Dtype(scale=0, precision=6), + cudf.Decimal32Dtype(scale=0, precision=6), True, ), ( @@ -2367,7 +2370,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), 1, ["101", "201"], - cudf.Decimal64Dtype(scale=0, precision=6), + cudf.Decimal32Dtype(scale=0, precision=6), True, ), ( @@ -2376,7 +2379,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), decimal.Decimal("1.5"), ["101.5", "201.5"], - cudf.Decimal64Dtype(scale=1, precision=7), + cudf.Decimal32Dtype(scale=1, precision=7), True, ), ( @@ -2385,7 +2388,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), cudf.Scalar(decimal.Decimal("1.5")), ["101.5", "201.5"], - cudf.Decimal64Dtype(scale=1, precision=7), + cudf.Decimal32Dtype(scale=1, precision=7), True, ), ( @@ -2394,7 +2397,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), 1, ["100", "200"], - cudf.Decimal64Dtype(scale=-2, precision=5), + cudf.Decimal32Dtype(scale=-2, precision=5), False, ), ( @@ -2403,7 +2406,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), decimal.Decimal(2), ["200", "400"], - cudf.Decimal64Dtype(scale=-2, precision=5), + cudf.Decimal32Dtype(scale=-2, precision=5), False, ), ( @@ -2412,7 +2415,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), decimal.Decimal("1.5"), ["150", "300"], - cudf.Decimal64Dtype(scale=-1, precision=6), + cudf.Decimal32Dtype(scale=-1, precision=6), False, ), ( @@ -2421,7 +2424,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), cudf.Scalar(decimal.Decimal("1.5")), ["150", "300"], - cudf.Decimal64Dtype(scale=-1, precision=6), + cudf.Decimal32Dtype(scale=-1, precision=6), False, ), ( @@ -2430,7 +2433,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), 1, ["100", "200"], - cudf.Decimal64Dtype(scale=-2, precision=5), + cudf.Decimal32Dtype(scale=-2, precision=5), True, ), ( @@ -2439,7 +2442,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), decimal.Decimal(2), ["200", "400"], - cudf.Decimal64Dtype(scale=-2, precision=5), + cudf.Decimal32Dtype(scale=-2, precision=5), True, ), ( @@ -2448,7 +2451,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), decimal.Decimal("1.5"), ["150", "300"], - cudf.Decimal64Dtype(scale=-1, precision=6), + cudf.Decimal32Dtype(scale=-1, precision=6), True, ), ( @@ -2457,7 +2460,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), cudf.Scalar(decimal.Decimal("1.5")), ["150", "300"], - cudf.Decimal64Dtype(scale=-1, precision=6), + cudf.Decimal32Dtype(scale=-1, precision=6), True, ), ( @@ -2466,7 +2469,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=4), 1, ["1000", "2000"], - cudf.Decimal64Dtype(scale=-2, precision=6), + cudf.Decimal64Dtype(scale=6, precision=12), False, ), ( @@ -2475,7 +2478,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=2, precision=5), decimal.Decimal(2), ["50", "100"], - cudf.Decimal64Dtype(scale=2, precision=7), + cudf.Decimal32Dtype(scale=6, precision=9), False, ), ( @@ -2484,7 +2487,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=2, precision=4), decimal.Decimal("1.5"), ["23.4", "36.6"], - cudf.Decimal64Dtype(scale=1, precision=7), + cudf.Decimal32Dtype(scale=6, precision=9), False, ), ( @@ -2493,7 +2496,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=1, precision=3), cudf.Scalar(decimal.Decimal("1.5")), ["14", "62"], - cudf.Decimal64Dtype(scale=0, precision=6), + cudf.Decimal32Dtype(scale=6, precision=9), False, ), ( @@ -2502,7 +2505,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=2, precision=5), 1, ["0", "0"], - cudf.Decimal64Dtype(scale=-2, precision=7), + cudf.Decimal32Dtype(scale=6, precision=9), True, ), ( @@ -2511,7 +2514,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=1, precision=6), decimal.Decimal(20), ["10", "40"], - cudf.Decimal64Dtype(scale=-1, precision=9), + cudf.Decimal64Dtype(scale=7, precision=10), True, ), ( @@ -2520,7 +2523,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=2, precision=3), decimal.Decimal("8.55"), ["7", "1"], - cudf.Decimal64Dtype(scale=0, precision=7), + cudf.Decimal32Dtype(scale=6, precision=9), True, ), ( @@ -2529,7 +2532,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=1, precision=3), cudf.Scalar(decimal.Decimal("90.84")), ["82.5", "2.1"], - cudf.Decimal64Dtype(scale=1, precision=8), + cudf.Decimal32Dtype(scale=6, precision=9), True, ), ( @@ -2538,7 +2541,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), decimal.Decimal(2), ["98", "198"], - cudf.Decimal64Dtype(scale=0, precision=6), + cudf.Decimal32Dtype(scale=0, precision=6), False, ), ( @@ -2547,7 +2550,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), decimal.Decimal("2.5"), ["97.5", "197.5"], - cudf.Decimal64Dtype(scale=1, precision=7), + cudf.Decimal32Dtype(scale=1, precision=7), False, ), ( @@ -2556,7 +2559,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), 4, ["96", "196"], - cudf.Decimal64Dtype(scale=0, precision=6), + cudf.Decimal32Dtype(scale=0, precision=6), False, ), ( @@ -2565,7 +2568,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), cudf.Scalar(decimal.Decimal("2.5")), ["97.5", "197.5"], - cudf.Decimal64Dtype(scale=1, precision=7), + cudf.Decimal32Dtype(scale=1, precision=7), False, ), ( @@ -2574,7 +2577,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), decimal.Decimal(2), ["-98", "-198"], - cudf.Decimal64Dtype(scale=0, precision=6), + cudf.Decimal32Dtype(scale=0, precision=6), True, ), ( @@ -2583,7 +2586,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), 4, ["-96", "-196"], - cudf.Decimal64Dtype(scale=0, precision=6), + cudf.Decimal32Dtype(scale=0, precision=6), True, ), ( @@ -2592,7 +2595,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), decimal.Decimal("2.5"), ["-97.5", "-197.5"], - cudf.Decimal64Dtype(scale=1, precision=7), + cudf.Decimal32Dtype(scale=1, precision=7), True, ), ( @@ -2601,11 +2604,15 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): cudf.Decimal64Dtype(scale=-2, precision=3), cudf.Scalar(decimal.Decimal("2.5")), ["-97.5", "-197.5"], - cudf.Decimal64Dtype(scale=1, precision=7), + cudf.Decimal32Dtype(scale=1, precision=7), True, ), ], ) +@pytest.mark.xfail( + reason="binop operations not supported for different " + "bit-width decimal types" +) def test_binops_decimal_scalar(args): op, lhs, l_dtype, rhs, expect, expect_dtype, reflect = args @@ -2776,6 +2783,10 @@ def decimal_series(input, dtype): ], ) @pytest.mark.parametrize("reflected", [True, False]) +@pytest.mark.xfail( + reason="binop operations not supported for different bit-width " + "decimal types" +) def test_binops_decimal_scalar_compare(args, reflected): """ Tested compare operations: diff --git a/python/cudf/cudf/tests/test_column.py b/python/cudf/cudf/tests/test_column.py index d2c7c073aa1..e01b952be94 100644 --- a/python/cudf/cudf/tests/test_column.py +++ b/python/cudf/cudf/tests/test_column.py @@ -1,4 +1,5 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. + import cupy as cp import numpy as np import pandas as pd @@ -116,9 +117,13 @@ def test_column_slicing(pandas_input, offset, size): @pytest.mark.parametrize("size", [50, 10, 0]) @pytest.mark.parametrize("precision", [2, 3, 5]) @pytest.mark.parametrize("scale", [0, 1, 2]) -def test_decimal_column_slicing(offset, size, precision, scale): +@pytest.mark.parametrize( + "decimal_type", + [cudf.Decimal128Dtype, cudf.Decimal64Dtype, cudf.Decimal32Dtype], +) +def test_decimal_column_slicing(offset, size, precision, scale, decimal_type): col = cudf.core.column.as_column(pd.Series(np.random.rand(1000))) - col = col.astype(cudf.Decimal64Dtype(precision, scale)) + col = col.astype(decimal_type(precision, scale)) column_slicing_test(col, offset, size, True) @@ -379,7 +384,7 @@ def test_as_column_buffer(data, expected): ( pa.array([100, 200, 300], type=pa.decimal128(3)), cudf.core.column.as_column( - [100, 200, 300], dtype=cudf.core.dtypes.Decimal64Dtype(3, 0) + [100, 200, 300], dtype=cudf.core.dtypes.Decimal128Dtype(3, 0) ), ), ( diff --git a/python/cudf/cudf/tests/test_concat.py b/python/cudf/cudf/tests/test_concat.py index 46707a283af..b8724fe36f5 100644 --- a/python/cudf/cudf/tests/test_concat.py +++ b/python/cudf/cudf/tests/test_concat.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. import re from decimal import Decimal @@ -9,7 +9,7 @@ import cudf as gd from cudf.api.types import is_categorical_dtype -from cudf.core.dtypes import Decimal64Dtype +from cudf.core.dtypes import Decimal32Dtype, Decimal64Dtype, Decimal128Dtype from cudf.testing._utils import assert_eq, assert_exceptions_equal @@ -1357,8 +1357,19 @@ def test_concat_single_object(ignore_index, typ): ) -@pytest.mark.parametrize("ltype", [Decimal64Dtype(3, 1), Decimal64Dtype(7, 2)]) -@pytest.mark.parametrize("rtype", [Decimal64Dtype(3, 2), Decimal64Dtype(8, 4)]) +@pytest.mark.parametrize( + "ltype", + [Decimal64Dtype(3, 1), Decimal64Dtype(7, 2), Decimal64Dtype(8, 4)], +) +@pytest.mark.parametrize( + "rtype", + [ + Decimal64Dtype(3, 2), + Decimal64Dtype(8, 4), + gd.Decimal128Dtype(3, 2), + gd.Decimal32Dtype(8, 4), + ], +) def test_concat_decimal_dataframe(ltype, rtype): gdf1 = gd.DataFrame( {"id": np.random.randint(0, 10, 3), "val": ["22.3", "59.5", "81.1"]} @@ -1381,7 +1392,13 @@ def test_concat_decimal_dataframe(ltype, rtype): @pytest.mark.parametrize("ltype", [Decimal64Dtype(4, 1), Decimal64Dtype(8, 2)]) @pytest.mark.parametrize( - "rtype", [Decimal64Dtype(4, 3), Decimal64Dtype(10, 4)] + "rtype", + [ + Decimal64Dtype(4, 3), + Decimal64Dtype(10, 4), + Decimal32Dtype(8, 3), + Decimal128Dtype(18, 3), + ], ) def test_concat_decimal_series(ltype, rtype): gs1 = gd.Series(["228.3", "559.5", "281.1"]).astype(ltype) @@ -1420,7 +1437,7 @@ def test_concat_decimal_series(ltype, rtype): Decimal("-5"), ] }, - dtype=Decimal64Dtype(7, 4), + dtype=Decimal32Dtype(7, 4), index=[0, 1, 0, 1, 0, 1], ), ), @@ -1442,7 +1459,7 @@ def test_concat_decimal_series(ltype, rtype): Decimal("-48"), ] }, - dtype=Decimal64Dtype(5, 2), + dtype=Decimal32Dtype(5, 2), index=[0, 1, 0, 1, 0, 1], ), ), @@ -1464,7 +1481,7 @@ def test_concat_decimal_series(ltype, rtype): Decimal("-49.25"), ] }, - dtype=Decimal64Dtype(9, 4), + dtype=Decimal32Dtype(9, 4), index=[0, 1, 0, 1, 0, 1], ), ), @@ -1486,7 +1503,29 @@ def test_concat_decimal_series(ltype, rtype): Decimal("-31.945"), ] }, - dtype=Decimal64Dtype(9, 4), + dtype=Decimal32Dtype(9, 4), + index=[0, 1, 0, 1, 0, 1], + ), + ), + ( + gd.DataFrame( + {"val": [Decimal("95633.24"), Decimal("236.633")]}, + dtype=Decimal128Dtype(19, 4), + ), + gd.DataFrame({"val": [5393, -95832]}, dtype="int64"), + gd.DataFrame({"val": [-29.234, -31.945]}, dtype="float64"), + gd.DataFrame( + { + "val": [ + Decimal("95633.24"), + Decimal("236.633"), + Decimal("5393"), + Decimal("-95832"), + Decimal("-29.234"), + Decimal("-31.945"), + ] + }, + dtype=Decimal128Dtype(19, 4), index=[0, 1, 0, 1, 0, 1], ), ), @@ -1538,7 +1577,7 @@ def test_concat_decimal_numeric_dataframe(df1, df2, df3, expected): Decimal("593"), Decimal("-702"), ], - dtype=Decimal64Dtype(5, 2), + dtype=Decimal32Dtype(5, 2), index=[0, 1, 0, 1, 0, 1], ), ), @@ -1558,7 +1597,7 @@ def test_concat_decimal_numeric_dataframe(df1, df2, df3, expected): Decimal("5299.262"), Decimal("-2049.25"), ], - dtype=Decimal64Dtype(9, 4), + dtype=Decimal32Dtype(9, 4), index=[0, 1, 0, 1, 0, 1], ), ), @@ -1578,7 +1617,33 @@ def test_concat_decimal_numeric_dataframe(df1, df2, df3, expected): Decimal("-40.292"), Decimal("49202.953"), ], - dtype=Decimal64Dtype(9, 4), + dtype=Decimal32Dtype(9, 4), + index=[0, 1, 0, 1, 0, 1], + ), + ), + ( + gd.Series( + [Decimal("492.204"), Decimal("-72824.455")], + dtype=Decimal64Dtype(10, 4), + ), + gd.Series( + [Decimal("8438"), Decimal("-27462")], + dtype=Decimal32Dtype(9, 4), + ), + gd.Series( + [Decimal("-40.292"), Decimal("49202.953")], + dtype=Decimal128Dtype(19, 4), + ), + gd.Series( + [ + Decimal("492.204"), + Decimal("-72824.455"), + Decimal("8438"), + Decimal("-27462"), + Decimal("-40.292"), + Decimal("49202.953"), + ], + dtype=Decimal128Dtype(19, 4), index=[0, 1, 0, 1, 0, 1], ), ), diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 6171f20929d..61c3f428019 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. import array as arr import datetime @@ -2176,13 +2176,17 @@ def test_quantile(q, numeric_only): @pytest.mark.parametrize("q", [0.2, 1, 0.001, [0.5], [], [0.005, 0.8, 0.03]]) @pytest.mark.parametrize("interpolation", ["higher", "lower", "nearest"]) -def test_decimal_quantile(q, interpolation): +@pytest.mark.parametrize( + "decimal_type", + [cudf.Decimal32Dtype, cudf.Decimal64Dtype, cudf.Decimal128Dtype], +) +def test_decimal_quantile(q, interpolation, decimal_type): data = ["244.8", "32.24", "2.22", "98.14", "453.23", "5.45"] gdf = cudf.DataFrame( {"id": np.random.randint(0, 10, size=len(data)), "val": data} ) gdf["id"] = gdf["id"].astype("float64") - gdf["val"] = gdf["val"].astype(cudf.Decimal64Dtype(7, 2)) + gdf["val"] = gdf["val"].astype(decimal_type(7, 2)) pdf = gdf.to_pandas() got = gdf.quantile(q, numeric_only=False, interpolation=interpolation) diff --git a/python/cudf/cudf/tests/test_dtypes.py b/python/cudf/cudf/tests/test_dtypes.py index 877cec24afa..356685c976e 100644 --- a/python/cudf/cudf/tests/test_dtypes.py +++ b/python/cudf/cudf/tests/test_dtypes.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. import numpy as np import pandas as pd @@ -9,7 +9,9 @@ from cudf.core.column import ColumnBase from cudf.core.dtypes import ( CategoricalDtype, + Decimal32Dtype, Decimal64Dtype, + Decimal128Dtype, IntervalDtype, ListDtype, StructDtype, @@ -138,16 +140,28 @@ def test_struct_dtype_fields(fields): assert_eq(dt.fields, fields) -def test_decimal_dtype(): - dt = Decimal64Dtype(4, 2) +@pytest.mark.parametrize( + "decimal_type", + [cudf.Decimal32Dtype, cudf.Decimal64Dtype, cudf.Decimal128Dtype], +) +def test_decimal_dtype_arrow_roundtrip(decimal_type): + dt = decimal_type(4, 2) assert dt.to_arrow() == pa.decimal128(4, 2) - assert dt == Decimal64Dtype.from_arrow(pa.decimal128(4, 2)) + assert dt == decimal_type.from_arrow(pa.decimal128(4, 2)) -def test_max_precision(): - Decimal64Dtype(scale=0, precision=18) +@pytest.mark.parametrize( + "decimal_type,max_precision", + [ + (cudf.Decimal32Dtype, 9), + (cudf.Decimal64Dtype, 18), + (cudf.Decimal128Dtype, 38), + ], +) +def test_max_precision(decimal_type, max_precision): + decimal_type(scale=0, precision=max_precision) with pytest.raises(ValueError): - Decimal64Dtype(scale=0, precision=19) + decimal_type(scale=0, precision=max_precision + 1) @pytest.mark.parametrize("fields", ["int64", "int32"]) @@ -180,7 +194,9 @@ def assert_column_array_dtype_equal(column: ColumnBase, array: pa.array): for i, child in enumerate(column.base_children) ] ) - elif isinstance(column.dtype, Decimal64Dtype): + elif isinstance( + column.dtype, (Decimal128Dtype, Decimal64Dtype, Decimal32Dtype) + ): return array.type.equals(column.dtype.to_arrow()) elif isinstance(column.dtype, CategoricalDtype): raise NotImplementedError() diff --git a/python/cudf/cudf/tests/test_joining.py b/python/cudf/cudf/tests/test_joining.py index 2fb7393f5b4..69793dc1828 100644 --- a/python/cudf/cudf/tests/test_joining.py +++ b/python/cudf/cudf/tests/test_joining.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. import numpy as np import pandas as pd @@ -6,7 +6,7 @@ import cudf from cudf.core._compat import PANDAS_GE_120 -from cudf.core.dtypes import CategoricalDtype, Decimal64Dtype +from cudf.core.dtypes import CategoricalDtype, Decimal64Dtype, Decimal128Dtype from cudf.testing._utils import ( INTEGER_TYPES, NUMERIC_TYPES, @@ -1130,7 +1130,12 @@ def test_typecast_on_join_overflow_unsafe(dtypes): @pytest.mark.parametrize( "dtype", - [Decimal64Dtype(5, 2), Decimal64Dtype(7, 5), Decimal64Dtype(12, 7)], + [ + Decimal64Dtype(5, 2), + Decimal64Dtype(7, 5), + Decimal64Dtype(12, 7), + Decimal128Dtype(20, 5), + ], ) def test_decimal_typecast_inner(dtype): other_data = ["a", "b", "c", "d", "e"] @@ -1166,7 +1171,12 @@ def test_decimal_typecast_inner(dtype): @pytest.mark.parametrize( "dtype", - [Decimal64Dtype(7, 3), Decimal64Dtype(9, 5), Decimal64Dtype(14, 10)], + [ + Decimal64Dtype(7, 3), + Decimal64Dtype(9, 5), + Decimal64Dtype(14, 10), + Decimal128Dtype(21, 9), + ], ) def test_decimal_typecast_left(dtype): other_data = ["a", "b", "c", "d"] @@ -1203,7 +1213,12 @@ def test_decimal_typecast_left(dtype): @pytest.mark.parametrize( "dtype", - [Decimal64Dtype(7, 3), Decimal64Dtype(10, 5), Decimal64Dtype(18, 9)], + [ + Decimal64Dtype(7, 3), + Decimal64Dtype(10, 5), + Decimal64Dtype(18, 9), + Decimal128Dtype(22, 8), + ], ) def test_decimal_typecast_outer(dtype): other_data = ["a", "b", "c"] diff --git a/python/cudf/cudf/tests/test_orc.py b/python/cudf/cudf/tests/test_orc.py index dc176992434..44812f5aba4 100644 --- a/python/cudf/cudf/tests/test_orc.py +++ b/python/cudf/cudf/tests/test_orc.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2021, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. import datetime import decimal @@ -15,7 +15,6 @@ import pytest import cudf -from cudf.core.dtypes import Decimal64Dtype from cudf.io.orc import ORCWriter from cudf.testing._utils import ( assert_eq, @@ -528,12 +527,6 @@ def test_orc_decimal_precision_fail(datadir): except pa.ArrowIOError as e: pytest.skip(".orc file is not found: %s" % e) - # Max precision supported is 18 (Decimal64Dtype limit) - # and the data has the precision 19. This test should be removed - # once Decimal128Dtype is introduced. - with pytest.raises(RuntimeError): - cudf.read_orc(file_path) - # Shouldn't cause failure if decimal column is not chosen to be read. pdf = orcfile.read(columns=["int"]).to_pandas() gdf = cudf.read_orc(file_path, columns=["int"]) @@ -790,12 +783,16 @@ def test_empty_string_columns(data): @pytest.mark.parametrize("scale", [-3, 0, 3]) -def test_orc_writer_decimal(tmpdir, scale): +@pytest.mark.parametrize( + "decimal_type", + [cudf.Decimal32Dtype, cudf.Decimal64Dtype, cudf.Decimal128Dtype], +) +def test_orc_writer_decimal(tmpdir, scale, decimal_type): np.random.seed(0) fname = tmpdir / "decimal.orc" expected = cudf.DataFrame({"dec_val": gen_rand_series("i", 100)}) - expected["dec_val"] = expected["dec_val"].astype(Decimal64Dtype(7, scale)) + expected["dec_val"] = expected["dec_val"].astype(decimal_type(7, scale)) expected.to_orc(fname) diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index 016ed1229f1..f239d88992a 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2021, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. import datetime import math @@ -633,29 +633,19 @@ def test_parquet_reader_spark_timestamps(datadir): def test_parquet_reader_spark_decimals(datadir): fname = datadir / "spark_decimal.parquet" - # expect = pd.read_parquet(fname) - with pytest.raises( - NotImplementedError, - match="Decimal type greater than Decimal64 is not yet supported", - ): - cudf.read_parquet(fname) - - # Convert the decimal dtype from PyArrow to float64 for comparison to cuDF - # This is because cuDF returns as float64 as it lacks an equivalent dtype - # expect = expect.apply(pd.to_numeric) + expect = pd.read_parquet(fname) + got = cudf.read_parquet(fname) - # np.testing.assert_allclose(expect, got) - # assert_eq(expect, got) + assert_eq(expect, got) @pytest.mark.parametrize("columns", [["a"], ["b", "a"], None]) -def test_parquet_reader_decimal128_error_validation(datadir, columns): +def test_parquet_reader_decimal128(datadir, columns): fname = datadir / "nested_decimal128_file.parquet" - with pytest.raises( - NotImplementedError, - match="Decimal type greater than Decimal64 is not yet supported", - ): - cudf.read_parquet(fname, columns=columns) + got = cudf.read_parquet(fname, columns=columns) + expect = cudf.read_parquet(fname, columns=columns) + + assert_eq(expect, got) def test_parquet_reader_microsecond_timestamps(datadir): @@ -2264,12 +2254,15 @@ def test_parquet_writer_nested(tmpdir, data): assert_eq(expect, got) -def test_parquet_writer_decimal(tmpdir): - from cudf.core.dtypes import Decimal64Dtype +@pytest.mark.parametrize( + "decimal_type", + [cudf.Decimal32Dtype, cudf.Decimal64Dtype, cudf.Decimal128Dtype], +) +def test_parquet_writer_decimal(tmpdir, decimal_type): gdf = cudf.DataFrame({"val": [0.00, 0.01, 0.02]}) - gdf["dec_val"] = gdf["val"].astype(Decimal64Dtype(7, 2)) + gdf["dec_val"] = gdf["val"].astype(decimal_type(7, 2)) fname = tmpdir.join("test_parquet_writer_decimal.parquet") gdf.to_parquet(fname) @@ -2313,10 +2306,12 @@ def test_parquet_writer_nulls_pandas_read(tmpdir, pdf): assert_eq(gdf.to_pandas(nullable=nullable), got) -def test_parquet_decimal_precision(tmpdir): - df = cudf.DataFrame({"val": ["3.5", "4.2"]}).astype( - cudf.Decimal64Dtype(5, 2) - ) +@pytest.mark.parametrize( + "decimal_type", + [cudf.Decimal32Dtype, cudf.Decimal64Dtype, cudf.Decimal128Dtype], +) +def test_parquet_decimal_precision(tmpdir, decimal_type): + df = cudf.DataFrame({"val": ["3.5", "4.2"]}).astype(decimal_type(5, 2)) assert df.val.dtype.precision == 5 fname = tmpdir.join("decimal_test.parquet") diff --git a/python/cudf/cudf/tests/test_reductions.py b/python/cudf/cudf/tests/test_reductions.py index 4ed6448de50..40add502309 100644 --- a/python/cudf/cudf/tests/test_reductions.py +++ b/python/cudf/cudf/tests/test_reductions.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. from __future__ import division, print_function @@ -12,7 +12,7 @@ import cudf from cudf import Series -from cudf.core.dtypes import Decimal64Dtype +from cudf.core.dtypes import Decimal32Dtype, Decimal64Dtype, Decimal128Dtype from cudf.testing import _utils as utils from cudf.testing._utils import NUMERIC_TYPES, assert_eq, gen_rand @@ -53,10 +53,17 @@ def test_sum_string(): @pytest.mark.parametrize( "dtype", - [Decimal64Dtype(6, 3), Decimal64Dtype(10, 6), Decimal64Dtype(16, 7)], + [ + Decimal64Dtype(6, 3), + Decimal64Dtype(10, 6), + Decimal64Dtype(16, 7), + Decimal32Dtype(6, 3), + Decimal128Dtype(20, 7), + ], ) @pytest.mark.parametrize("nelem", params_sizes) def test_sum_decimal(dtype, nelem): + np.random.seed(0) data = [str(x) for x in gen_rand("int64", nelem) / 100] expected = pd.Series([Decimal(x) for x in data]).sum() @@ -89,9 +96,16 @@ def test_product(dtype, nelem): @pytest.mark.parametrize( "dtype", - [Decimal64Dtype(6, 2), Decimal64Dtype(8, 4), Decimal64Dtype(10, 5)], + [ + Decimal64Dtype(6, 2), + Decimal64Dtype(8, 4), + Decimal64Dtype(10, 5), + Decimal32Dtype(6, 2), + Decimal128Dtype(20, 5), + ], ) def test_product_decimal(dtype): + np.random.seed(0) data = [str(x) for x in gen_rand("int8", 3) / 10] expected = pd.Series([Decimal(x) for x in data]).product() @@ -131,9 +145,16 @@ def test_sum_of_squares(dtype, nelem): @pytest.mark.parametrize( "dtype", - [Decimal64Dtype(6, 2), Decimal64Dtype(8, 4), Decimal64Dtype(10, 5)], + [ + Decimal64Dtype(6, 2), + Decimal64Dtype(8, 4), + Decimal64Dtype(10, 5), + Decimal128Dtype(20, 7), + Decimal32Dtype(6, 2), + ], ) def test_sum_of_squares_decimal(dtype): + np.random.seed(0) data = [str(x) for x in gen_rand("int8", 3) / 10] expected = pd.Series([Decimal(x) for x in data]).pow(2).sum() @@ -156,10 +177,17 @@ def test_min(dtype, nelem): @pytest.mark.parametrize( "dtype", - [Decimal64Dtype(6, 3), Decimal64Dtype(10, 6), Decimal64Dtype(16, 7)], + [ + Decimal64Dtype(6, 3), + Decimal64Dtype(10, 6), + Decimal64Dtype(16, 7), + Decimal32Dtype(6, 3), + Decimal128Dtype(20, 7), + ], ) @pytest.mark.parametrize("nelem", params_sizes) def test_min_decimal(dtype, nelem): + np.random.seed(0) data = [str(x) for x in gen_rand("int64", nelem) / 100] expected = pd.Series([Decimal(x) for x in data]).min() @@ -182,10 +210,17 @@ def test_max(dtype, nelem): @pytest.mark.parametrize( "dtype", - [Decimal64Dtype(6, 3), Decimal64Dtype(10, 6), Decimal64Dtype(16, 7)], + [ + Decimal64Dtype(6, 3), + Decimal64Dtype(10, 6), + Decimal64Dtype(16, 7), + Decimal32Dtype(6, 3), + Decimal128Dtype(20, 7), + ], ) @pytest.mark.parametrize("nelem", params_sizes) def test_max_decimal(dtype, nelem): + np.random.seed(0) data = [str(x) for x in gen_rand("int64", nelem) / 100] expected = pd.Series([Decimal(x) for x in data]).max() diff --git a/python/cudf/cudf/tests/test_replace.py b/python/cudf/cudf/tests/test_replace.py index 2e7936feeae..90429945cc5 100644 --- a/python/cudf/cudf/tests/test_replace.py +++ b/python/cudf/cudf/tests/test_replace.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. import re from decimal import Decimal @@ -8,7 +8,7 @@ import pytest import cudf -from cudf.core.dtypes import Decimal64Dtype +from cudf.core.dtypes import Decimal32Dtype, Decimal64Dtype, Decimal128Dtype from cudf.testing._utils import ( INTEGER_TYPES, NUMERIC_TYPES, @@ -350,7 +350,7 @@ def test_fillna_method_numerical(data, container, data_dtype, method, inplace): Decimal64Dtype(7, 2) ), cudf.Series(["-74.56", None, "-23.73", "34.55", "2.89", None]).astype( - Decimal64Dtype(7, 2) + Decimal32Dtype(7, 2) ), cudf.Series( ["85.955", np.nan, "-3.243", np.nan, "29.492", np.nan] @@ -361,6 +361,9 @@ def test_fillna_method_numerical(data, container, data_dtype, method, inplace): cudf.Series( [np.nan, "55.2498", np.nan, "-5.2965", "-28.9423", np.nan] ).astype(Decimal64Dtype(10, 4)), + cudf.Series( + ["2.964", None, "54347.432", "-989.330", None, "56.444"] + ).astype(Decimal128Dtype(20, 7)), ], ) @pytest.mark.parametrize( diff --git a/python/cudf/cudf/tests/test_scalar.py b/python/cudf/cudf/tests/test_scalar.py index a8b62710e0e..e8382681820 100644 --- a/python/cudf/cudf/tests/test_scalar.py +++ b/python/cudf/cudf/tests/test_scalar.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. + import datetime import datetime as dt import re @@ -22,6 +24,8 @@ cudf.Decimal64Dtype(1, 1), cudf.Decimal64Dtype(4, 2), cudf.Decimal64Dtype(4, -2), + cudf.Decimal32Dtype(3, 1), + cudf.Decimal128Dtype(28, 3), ] SCALAR_VALUES = [ @@ -145,8 +149,12 @@ def test_scalar_device_initialization(value): @pytest.mark.parametrize("value", DECIMAL_VALUES) -def test_scalar_device_initialization_decimal(value): - dtype = cudf.Decimal64Dtype._from_decimal(value) +@pytest.mark.parametrize( + "decimal_type", + [cudf.Decimal32Dtype, cudf.Decimal64Dtype, cudf.Decimal128Dtype], +) +def test_scalar_device_initialization_decimal(value, decimal_type): + dtype = decimal_type._from_decimal(value) column = cudf.Series([str(value)]).astype(dtype)._column dev_slr = get_element(column, 0) @@ -199,7 +207,7 @@ def test_null_scalar(dtype): assert s.value is cudf.NA assert s.dtype == ( cudf.dtype(dtype) - if not isinstance(dtype, cudf.Decimal64Dtype) + if not isinstance(dtype, cudf.core.dtypes.DecimalDtype) else dtype ) assert s.is_valid() is False @@ -250,6 +258,12 @@ def test_scalar_dtype_and_validity(dtype): (Decimal(1), cudf.Decimal64Dtype(1, 0), Decimal("1")), (Decimal("1.1"), cudf.Decimal64Dtype(2, 1), Decimal("1.1")), (Decimal("1.1"), cudf.Decimal64Dtype(4, 3), Decimal("1.100")), + (Decimal("41.123"), cudf.Decimal32Dtype(5, 3), Decimal("41.123")), + ( + Decimal("41345435344353535344373628492731234.123"), + cudf.Decimal128Dtype(38, 3), + Decimal("41345435344353535344373628492731234.123"), + ), (Decimal("1.11"), cudf.Decimal64Dtype(2, 2), pa.lib.ArrowInvalid), ], ) @@ -335,18 +349,25 @@ def test_scalar_invalid_implicit_conversion(cls, dtype): @pytest.mark.parametrize("value", SCALAR_VALUES + DECIMAL_VALUES) -def test_device_scalar_direct_construction(value): +@pytest.mark.parametrize( + "decimal_type", + [cudf.Decimal32Dtype, cudf.Decimal64Dtype, cudf.Decimal128Dtype], +) +def test_device_scalar_direct_construction(value, decimal_type): value = cudf.utils.utils.to_cudf_compatible_scalar(value) + dtype = ( value.dtype if not isinstance(value, Decimal) - else cudf.Decimal64Dtype._from_decimal(value) + else decimal_type._from_decimal(value) ) s = cudf._lib.scalar.DeviceScalar(value, dtype) assert s.value == value or np.isnan(s.value) and np.isnan(value) - if isinstance(dtype, cudf.Decimal64Dtype): + if isinstance( + dtype, (cudf.Decimal64Dtype, cudf.Decimal128Dtype, cudf.Decimal32Dtype) + ): assert s.dtype.precision == dtype.precision assert s.dtype.scale == dtype.scale elif dtype.char == "U": diff --git a/python/cudf/cudf/tests/test_scan.py b/python/cudf/cudf/tests/test_scan.py index 741a9f45d09..4cbc2197cfd 100644 --- a/python/cudf/cudf/tests/test_scan.py +++ b/python/cudf/cudf/tests/test_scan.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. + from itertools import product import numpy as np @@ -5,7 +7,7 @@ import pytest import cudf -from cudf.core.dtypes import Decimal64Dtype +from cudf.core.dtypes import Decimal32Dtype, Decimal64Dtype, Decimal128Dtype from cudf.testing._utils import ( INTEGER_TYPES, NUMERIC_TYPES, @@ -69,7 +71,13 @@ def test_cumsum_masked(): @pytest.mark.parametrize( "dtype", - [Decimal64Dtype(8, 4), Decimal64Dtype(10, 5), Decimal64Dtype(12, 7)], + [ + Decimal64Dtype(8, 4), + Decimal64Dtype(10, 5), + Decimal64Dtype(12, 7), + Decimal32Dtype(8, 5), + Decimal128Dtype(13, 6), + ], ) def test_cumsum_decimal(dtype): data = ["243.32", "48.245", "-7234.298", np.nan, "-467.2"] @@ -126,7 +134,13 @@ def test_cummin_masked(): @pytest.mark.parametrize( "dtype", - [Decimal64Dtype(8, 4), Decimal64Dtype(11, 6), Decimal64Dtype(14, 7)], + [ + Decimal64Dtype(8, 4), + Decimal64Dtype(11, 6), + Decimal64Dtype(14, 7), + Decimal32Dtype(8, 4), + Decimal128Dtype(11, 6), + ], ) def test_cummin_decimal(dtype): data = ["8394.294", np.nan, "-9940.444", np.nan, "-23.928"] @@ -183,7 +197,13 @@ def test_cummax_masked(): @pytest.mark.parametrize( "dtype", - [Decimal64Dtype(8, 4), Decimal64Dtype(11, 6), Decimal64Dtype(14, 7)], + [ + Decimal64Dtype(8, 4), + Decimal64Dtype(11, 6), + Decimal64Dtype(14, 7), + Decimal32Dtype(8, 4), + Decimal128Dtype(11, 6), + ], ) def test_cummax_decimal(dtype): data = [np.nan, "54.203", "8.222", "644.32", "-562.272"] diff --git a/python/cudf/cudf/tests/test_string.py b/python/cudf/cudf/tests/test_string.py index cc7be02a024..75cf2e6c892 100644 --- a/python/cudf/cudf/tests/test_string.py +++ b/python/cudf/cudf/tests/test_string.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. import json import re @@ -229,9 +229,13 @@ def test_string_astype(dtype): ([], 0, 5), ], ) -def test_string_to_decimal(data, scale, precision): +@pytest.mark.parametrize( + "decimal_dtype", + [cudf.Decimal128Dtype, cudf.Decimal64Dtype, cudf.Decimal32Dtype], +) +def test_string_to_decimal(data, scale, precision, decimal_dtype): gs = cudf.Series(data, dtype="str") - fp = gs.astype(cudf.Decimal64Dtype(scale=scale, precision=precision)) + fp = gs.astype(decimal_dtype(scale=scale, precision=precision)) got = fp.astype("str") assert_eq(gs, got) @@ -256,7 +260,11 @@ def test_string_empty_to_decimal(): ([], 0, 5), ], ) -def test_string_from_decimal(data, scale, precision): +@pytest.mark.parametrize( + "decimal_dtype", + [cudf.Decimal128Dtype, cudf.Decimal32Dtype, cudf.Decimal64Dtype], +) +def test_string_from_decimal(data, scale, precision, decimal_dtype): decimal_data = [] for d in data: if d is None: @@ -264,11 +272,10 @@ def test_string_from_decimal(data, scale, precision): else: decimal_data.append(Decimal(d)) fp = cudf.Series( - decimal_data, - dtype=cudf.Decimal64Dtype(scale=scale, precision=precision), + decimal_data, dtype=decimal_dtype(scale=scale, precision=precision), ) gs = fp.astype("str") - got = gs.astype(cudf.Decimal64Dtype(scale=scale, precision=precision)) + got = gs.astype(decimal_dtype(scale=scale, precision=precision)) assert_eq(fp, got) diff --git a/python/cudf/cudf/utils/dtypes.py b/python/cudf/cudf/utils/dtypes.py index 7142d0d710e..44bbb1b493d 100644 --- a/python/cudf/cudf/utils/dtypes.py +++ b/python/cudf/cudf/utils/dtypes.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. import datetime as dt from collections import namedtuple @@ -164,8 +164,20 @@ def _find_common_type_decimal(dtypes): lhs = max([dtype.precision - dtype.scale for dtype in dtypes]) # Combine to get the necessary precision and clip at the maximum # precision - p = min(cudf.Decimal64Dtype.MAX_PRECISION, s + lhs) - return cudf.Decimal64Dtype(p, s) + p = s + lhs + + if p > cudf.Decimal64Dtype.MAX_PRECISION: + return cudf.Decimal128Dtype( + min(cudf.Decimal128Dtype.MAX_PRECISION, p), s + ) + elif p > cudf.Decimal32Dtype.MAX_PRECISION: + return cudf.Decimal64Dtype( + min(cudf.Decimal64Dtype.MAX_PRECISION, p), s + ) + else: + return cudf.Decimal32Dtype( + min(cudf.Decimal32Dtype.MAX_PRECISION, p), s + ) def cudf_dtype_from_pydata_dtype(dtype): @@ -179,6 +191,8 @@ def cudf_dtype_from_pydata_dtype(dtype): return cudf.core.dtypes.Decimal32Dtype elif cudf.api.types.is_decimal64_dtype(dtype): return cudf.core.dtypes.Decimal64Dtype + elif cudf.api.types.is_decimal128_dtype(dtype): + return cudf.core.dtypes.Decimal128Dtype elif dtype in cudf._lib.types.SUPPORTED_NUMPY_TO_LIBCUDF_TYPES: return dtype.type @@ -210,7 +224,7 @@ def cudf_dtype_from_pa_type(typ): elif pa.types.is_struct(typ): return cudf.core.dtypes.StructDtype.from_arrow(typ) elif pa.types.is_decimal(typ): - return cudf.core.dtypes.Decimal64Dtype.from_arrow(typ) + return cudf.core.dtypes.Decimal128Dtype.from_arrow(typ) else: return cudf.api.types.pandas_dtype(typ.to_pandas_dtype()) @@ -586,8 +600,9 @@ def _can_cast(from_dtype, to_dtype): # TODO : Add precision & scale checking for # decimal types in future - if isinstance(from_dtype, cudf.core.dtypes.Decimal64Dtype): - if isinstance(to_dtype, cudf.core.dtypes.Decimal64Dtype): + + if isinstance(from_dtype, cudf.core.dtypes.DecimalDtype): + if isinstance(to_dtype, cudf.core.dtypes.DecimalDtype): return True elif isinstance(to_dtype, np.dtype): if to_dtype.kind in {"i", "f", "u", "U", "O"}: @@ -597,7 +612,7 @@ def _can_cast(from_dtype, to_dtype): elif isinstance(from_dtype, np.dtype): if isinstance(to_dtype, np.dtype): return np.can_cast(from_dtype, to_dtype) - elif isinstance(to_dtype, cudf.core.dtypes.Decimal64Dtype): + elif isinstance(to_dtype, cudf.core.dtypes.DecimalDtype): if from_dtype.kind in {"i", "f", "u", "U", "O"}: return True else: diff --git a/python/cudf/cudf/utils/utils.py b/python/cudf/cudf/utils/utils.py index cea384b9c11..2af7543e600 100644 --- a/python/cudf/cudf/utils/utils.py +++ b/python/cudf/cudf/utils/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. import decimal import functools @@ -58,7 +58,7 @@ def scalar_broadcast_to(scalar, size, dtype=None): if isinstance(scalar, decimal.Decimal): if dtype is None: - dtype = cudf.Decimal64Dtype._from_decimal(scalar) + dtype = cudf.Decimal128Dtype._from_decimal(scalar) out_col = column.column_empty(size, dtype=dtype) if out_col.size != 0: