Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DecimalBaseColumn and move as_decimal_column #9001

Merged
merged 4 commits into from
Aug 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,9 +1008,7 @@ def as_string_column(

def as_decimal_column(
self, dtype: Dtype, **kwargs
) -> Union[
"cudf.core.column.Decimal32Column", "cudf.core.column.Decimal64Column"
]:
) -> Union["cudf.core.column.decimal.DecimalBaseColumn"]:
raise NotImplementedError

def as_decimal64_column(
Expand Down
45 changes: 25 additions & 20 deletions python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,30 @@
from .numerical_base import NumericalBaseColumn


class Decimal32Column(NumericalBaseColumn):
class DecimalBaseColumn(NumericalBaseColumn):
    """Base column for decimal32 and decimal64 columns."""

    # Narrowed by each concrete subclass to its own decimal dtype.
    dtype: Union[Decimal32Dtype, Decimal64Dtype]

    def as_decimal_column(
        self, dtype: Dtype, **kwargs
    ) -> "DecimalBaseColumn":
        """Cast this column to the decimal ``dtype``.

        Parameters
        ----------
        dtype : Dtype
            The target dtype.
        **kwargs
            Accepted for signature compatibility; not used by this method.

        Returns
        -------
        DecimalBaseColumn
            ``self`` (no copy) when ``dtype`` equals the current dtype,
            otherwise the result of ``libcudf.unary.cast(self, dtype)``.

        Warns
        -----
        A warning is emitted when ``dtype`` is a ``Decimal32Dtype`` or
        ``Decimal64Dtype`` whose scale is smaller than this column's scale,
        because the cast truncates rather than rounds.
        """
        # Only decimal targets carry a ``scale``; guard with isinstance
        # before touching the attribute.
        if (
            isinstance(dtype, (Decimal64Dtype, Decimal32Dtype))
            and dtype.scale < self.dtype.scale
        ):
            warn(
                "cuDF truncates when downcasting decimals to a lower scale. "
                "To round, use Series.round() or DataFrame.round()."
            )

        # Identical dtype: nothing to cast, hand back the same column.
        if dtype == self.dtype:
            return self
        return libcudf.unary.cast(self, dtype)


class Decimal32Column(DecimalBaseColumn):
dtype: Decimal32Dtype

@classmethod
Expand Down Expand Up @@ -78,7 +101,7 @@ def to_arrow(self):
)


class Decimal64Column(NumericalBaseColumn):
class Decimal64Column(DecimalBaseColumn):
dtype: Decimal64Dtype

def __truediv__(self, other):
Expand Down Expand Up @@ -202,24 +225,6 @@ def _decimal_quantile(

return result._with_type_metadata(self.dtype)

def as_decimal_column(
self, dtype: Dtype, **kwargs
) -> Union[
"cudf.core.column.Decimal32Column", "cudf.core.column.Decimal64Column"
]:
if (
isinstance(dtype, Decimal64Dtype)
and dtype.scale < self.dtype.scale
):
warn(
"cuDF truncates when downcasting decimals to a lower scale. "
"To round, use Series.round() or DataFrame.round()."
)

if dtype == self.dtype:
return self
return libcudf.unary.cast(self, dtype)

def as_numerical_column(
self, dtype: Dtype, **kwargs
) -> "cudf.core.column.NumericalColumn":
Expand Down
36 changes: 29 additions & 7 deletions python/cudf/cudf/tests/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import cudf
from cudf.core.column import Decimal32Column, Decimal64Column, NumericalColumn
from cudf.core.dtypes import Decimal64Dtype
from cudf.core.dtypes import Decimal32Dtype, Decimal64Dtype
from cudf.testing._utils import (
FLOAT_TYPES,
INTEGER_TYPES,
Expand Down Expand Up @@ -164,21 +164,43 @@ def test_typecast_from_int_to_decimal(data, from_dtype, to_dtype):
)
@pytest.mark.parametrize(
"from_dtype",
[Decimal64Dtype(7, 2), Decimal64Dtype(11, 4), Decimal64Dtype(18, 10)],
[
Decimal64Dtype(7, 2),
Decimal64Dtype(11, 4),
Decimal64Dtype(18, 10),
Decimal32Dtype(7, 2),
Decimal32Dtype(5, 3),
Decimal32Dtype(9, 5),
],
)
@pytest.mark.parametrize(
"to_dtype",
[Decimal64Dtype(7, 2), Decimal64Dtype(18, 10), Decimal64Dtype(11, 4)],
[
Decimal64Dtype(7, 2),
Decimal64Dtype(18, 10),
Decimal64Dtype(11, 4),
Decimal32Dtype(7, 2),
Decimal32Dtype(9, 5),
Decimal32Dtype(5, 3),
],
)
def test_typecast_to_from_decimal(data, from_dtype, to_dtype):
got = data.astype(from_dtype)
if from_dtype.scale > to_dtype.MAX_PRECISION:
pytest.skip(
"This is supposed to overflow because the representation value in "
"the source exceeds the max representable in destination dtype."
)
s = data.astype(from_dtype)

pa_arr = got.to_arrow().cast(
pa_arr = s.to_arrow().cast(
pa.decimal128(to_dtype.precision, to_dtype.scale), safe=False
)
expected = cudf.Series(Decimal64Column.from_arrow(pa_arr))
if isinstance(to_dtype, Decimal32Dtype):
expected = cudf.Series(Decimal32Column.from_arrow(pa_arr))
elif isinstance(to_dtype, Decimal64Dtype):
expected = cudf.Series(Decimal64Column.from_arrow(pa_arr))

got = got.astype(to_dtype)
got = s.astype(to_dtype)

assert_eq(got, expected)

Expand Down