Skip to content

Commit

Permalink
Adds serialization of Decimal Columns and dtypes (#8041)
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-b-miller authored Apr 26, 2021
1 parent 3ace5ec commit 94afdda
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 12 deletions.
4 changes: 3 additions & 1 deletion python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,9 @@ def deserialize(cls, header: dict, frames: list) -> ColumnBase:
mask = None
if "mask" in header:
mask = Buffer.deserialize(header["mask"], [frames[1]])
return build_column(data=data, dtype=dtype, mask=mask)
return build_column(
data=data, dtype=dtype, mask=mask, size=header.get("size", None)
)

def binary_operator(
self, op: builtins.str, other: BinaryOperand, reflect: bool = False
Expand Down
18 changes: 15 additions & 3 deletions python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
# Copyright (c) 2021, NVIDIA CORPORATION.

from decimal import Decimal
from typing import cast, Any, Sequence, Union
from numbers import Number
from typing import Any, Sequence, Tuple, Union, cast

import cupy as cp
import numpy as np
import pyarrow as pa
from pandas.api.types import is_integer_dtype
from numbers import Number

import cudf
from cudf import _lib as libcudf
from cudf._lib.quantiles import quantile as cpp_quantile
from cudf._lib.strings.convert.convert_fixed_point import (
from_decimal as cpp_from_decimal,
)
from cudf._lib.quantiles import quantile as cpp_quantile
from cudf._typing import Dtype
from cudf.core.buffer import Buffer
from cudf.core.column import ColumnBase, as_column
Expand Down Expand Up @@ -229,6 +229,18 @@ def fillna(
)
return self._copy_type_metadata(result)

def serialize(self) -> Tuple[dict, list]:
    """Serialize this column into a ``(header, frames)`` pair.

    The parent class produces the base header and frames; this override
    then records the serialized dtype under ``"dtype"`` and the column
    length under ``"size"`` in that header.
    """
    meta, buffers = super().serialize()
    dtype_payload = self.dtype.serialize()
    meta["dtype"] = dtype_payload
    meta["size"] = self.size
    return meta, buffers

@classmethod
def deserialize(cls, header: dict, frames: list) -> ColumnBase:
    """Rebuild a decimal column from a serialized header and its frames.

    Parameters
    ----------
    header : dict
        Column header whose ``"dtype"`` entry holds the serialized dtype
        as a ``(header, frames)`` pair, as written by ``serialize``.
    frames : list
        Frames forwarded unchanged to the parent implementation.

    Returns
    -------
    ColumnBase
        Whatever the parent class's ``deserialize`` builds from the header
        with the dtype entry replaced by a ``Decimal64Dtype`` instance.
    """
    dtype = cudf.Decimal64Dtype.deserialize(*header["dtype"])
    # Work on a shallow copy: overwriting the caller's "dtype" entry in place
    # would leave a dtype object where the serialized pair used to be, so
    # deserializing the same header a second time would fail.
    header = {**header, "dtype": dtype}
    return super().deserialize(header, frames)


def _binop_scale(l_dtype, r_dtype, op):
# This should at some point be hooked up to libcudf's
Expand Down
13 changes: 12 additions & 1 deletion python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import decimal
import pickle
from typing import Any, Optional
from typing import Any, Optional, Tuple

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -268,6 +268,10 @@ def __init__(self, precision, scale=0):
self._validate(precision, scale)
self._typ = pa.decimal128(precision, scale)

@property
def str(self):
    """Return the text form ``decimal64(<precision>, <scale>)`` of this dtype."""
    precision, scale = self.precision, self.scale
    return f"decimal64({precision}, {scale})"

@property
def precision(self):
    """Return the precision stored on the underlying ``_typ`` object."""
    underlying = self._typ
    return underlying.precision
Expand Down Expand Up @@ -325,6 +329,13 @@ def _from_decimal(cls, decimal):
precision = max(len(metadata.digits), -metadata.exponent)
return cls(precision, -metadata.exponent)

def serialize(self) -> Tuple[dict, list]:
    """Serialize this dtype as a ``(header, frames)`` pair.

    The header carries the ``"precision"`` and ``"scale"`` values; the
    frames list is always empty.
    """
    header = {}
    header["precision"] = self.precision
    header["scale"] = self.scale
    return header, []

@classmethod
def deserialize(cls, header: dict, frames: list):
    """Construct an instance from a header produced by ``serialize``.

    Only the ``"precision"`` and ``"scale"`` entries of ``header`` are
    read; ``frames`` is accepted but not used.
    """
    precision = header["precision"]
    scale = header["scale"]
    return cls(precision, scale)


class IntervalDtype(StructDtype):
name = "interval"
Expand Down
7 changes: 1 addition & 6 deletions python/cudf/cudf/tests/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
FLOAT_TYPES,
INTEGER_TYPES,
NUMERIC_TYPES,
_decimal_series,
assert_eq,
)

Expand Down Expand Up @@ -204,12 +205,6 @@ def test_typecast_from_decimal(data, from_dtype, to_dtype):
assert_eq(got.dtype, expected.dtype)


def _decimal_series(input, dtype):
    """Build a ``cudf.Series`` of ``Decimal`` values with the given dtype.

    Each non-``None`` entry of ``input`` is passed through ``Decimal``;
    ``None`` entries are kept as ``None``.
    """
    values = []
    for item in input:
        values.append(None if item is None else Decimal(item))
    return cudf.Series(values, dtype=dtype)


@pytest.mark.parametrize(
"args",
[
Expand Down
40 changes: 39 additions & 1 deletion python/cudf/cudf/tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import cudf
from cudf.tests import utils
from cudf.tests.utils import assert_eq
from cudf.tests.utils import _decimal_series, assert_eq


@pytest.mark.parametrize(
Expand Down Expand Up @@ -289,6 +289,44 @@ def test_serialize_list_columns(data):
assert_eq(recreated, df)


@pytest.mark.parametrize(
    "data",
    [
        {
            "a": _decimal_series(
                ["1", "2", "3"], dtype=cudf.Decimal64Dtype(1, 0)
            )
        },
        {
            "a": _decimal_series(
                ["1", "2", "3"], dtype=cudf.Decimal64Dtype(1, 0)
            ),
            "b": _decimal_series(
                ["1.0", "2.0", "3.0"], dtype=cudf.Decimal64Dtype(2, 1)
            ),
            "c": _decimal_series(
                ["10.1", "20.2", "30.3"], dtype=cudf.Decimal64Dtype(3, 1)
            ),
        },
        {
            "a": _decimal_series(
                ["1", None, "3"], dtype=cudf.Decimal64Dtype(1, 0)
            ),
            "b": _decimal_series(
                ["1.0", "2.0", None], dtype=cudf.Decimal64Dtype(2, 1)
            ),
            "c": _decimal_series(
                [None, "20.2", "30.3"], dtype=cudf.Decimal64Dtype(3, 1)
            ),
        },
    ],
)
def test_serialize_decimal_columns(data):
    """Round-trip decimal DataFrames through serialize/deserialize.

    Covers a single column, several columns with different precision and
    scale, and columns containing ``None`` entries.
    """
    original = cudf.DataFrame(data)
    header, frames = original.serialize()
    roundtripped = original.__class__.deserialize(header, frames)
    assert_eq(roundtripped, original)


def test_deserialize_cudf_0_16(datadir):
fname = datadir / "pkl" / "stringColumnWithRangeIndex_cudf_0.16.pkl"

Expand Down
7 changes: 7 additions & 0 deletions python/cudf/cudf/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
from collections.abc import Mapping, Sequence
from contextlib import contextmanager
from decimal import Decimal

import cupy
import numpy as np
Expand Down Expand Up @@ -296,6 +297,12 @@ def gen_rand_series(dtype, size, **kwargs):
return cudf.Series(values)


def _decimal_series(input, dtype):
    """Build a ``cudf.Series`` of ``Decimal`` values with the given dtype.

    Each non-``None`` entry of ``input`` is passed through ``Decimal``;
    ``None`` entries are kept as ``None``.
    """
    values = []
    for item in input:
        values.append(None if item is None else Decimal(item))
    return cudf.Series(values, dtype=dtype)


@contextmanager
def does_not_raise():
    """Context manager that does nothing and yields ``None``."""
    yield None
Expand Down

0 comments on commit 94afdda

Please sign in to comment.