Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add serialization methods for List and StructDtype #8441

Merged
merged 10 commits into from
Jun 21, 2021
15 changes: 11 additions & 4 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,14 @@ def set_base_data(self, value):

def serialize(self):
header = {}
frames = []
header["type-serialized"] = pickle.dumps(type(self))
header["dtype"] = pickle.dumps(self.dtype)
header["null_count"] = self.null_count
header["size"] = self.size
header["dtype"], dtype_frames = self.dtype.serialize()
header["dtype_frames_count"] = len(dtype_frames)
frames.extend(dtype_frames)

frames = []
sub_headers = []

for item in self.children:
Expand All @@ -211,9 +213,14 @@ def deserialize(cls, header, frames):
else:
mask = None

# Deserialize dtype
dtype = pickle.loads(header["dtype"]["type-serialized"]).deserialize(
header["dtype"], frames[: header["dtype_frames_count"]]
)

# Deserialize child columns
children = []
f = 0
f = header["dtype_frames_count"]
for h in header["subheaders"]:
fcount = h["frame_count"]
child_frames = frames[f : f + fcount]
Expand All @@ -224,7 +231,7 @@ def deserialize(cls, header, frames):
# Materialize list column
return column.build_column(
data=None,
dtype=pickle.loads(header["dtype"]),
dtype=dtype,
mask=mask,
children=tuple(children),
size=header["size"],
Expand Down
80 changes: 76 additions & 4 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import decimal
import pickle
from typing import Any, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
Expand All @@ -17,9 +17,11 @@

import cudf
from cudf._typing import Dtype
from cudf.core.abc import Serializable
from cudf.core.buffer import Buffer


class _BaseDtype(ExtensionDtype):
class _BaseDtype(ExtensionDtype, Serializable):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something seems weird about this. We're making all of our extension dtypes serializable, but I believe we end up needing to override serialize and deserialize for all of them (ListDtype, StructDtype, CategoricalDtype). To me that suggests either the parent class needs to be generalized to be able to do at least some of the common work between these child classes, or that this inheritance relationship just isn't quite right.

I am weakly -1 on doing this as part of this PR. Maybe it makes more sense to add the serialize/deserialize methods in this PR and then refactor the common code out, either into Serializable or something that goes in between Serializable and _BaseDtype, in a separate PR.

Copy link
Contributor

@shwina shwina Jun 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Originally, Serializable was an abstract base class, which forced all derived classes to implement serialize and deserialize. For performance reasons, we disabled that and made it a regular class. Now, derived classes must implement serialize and deserialize, but that is "only" by convention.

That being said, there's still very much value in inheriting from Serializable, as we get the methods host_serialize, device_serialize, host_deserialize, device_deserialize "for free" by the inheritance.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I agree with @brandon-b-miller's objection to making this change, but not the reasons.

Serializable declares an interface, but leaves it up to subclasses to implement it. Whether or not certain subclasses (e.g. all dtypes) can share parts (or all) of that implementation isn't really relevant to whether or not the inheritance pattern makes sense. All that inheriting from Serializable does is indicate that if subclasses implement serialize and deserialize, it will be possible to do pickle.dumps(obj).

All of the *_(?:de)?serialize methods just exist to provide hooks into Serializable.__reduce_ex__, the method that actually enables serialization. My issue with using Serializable for dtype objects is that these hooks are all predicated on the assumption that a subclass of Serializable can be decomposed into some header of metadata and a collection of frames, which isn't the case for dtypes. If you look at the contents of the methods implemented by Serializable, they're encoding a bunch of metadata that IMO isn't really appropriate for a dtype, but rather for typed memory buffers (e.g. the length of the array or whether it's stored in device memory).

That being the case, I think that it would be simpler and more appropriate to directly implement the pickling protocol (ideally via __getstate__ and __setstate__, but if not then via __reduce* methods) rather than trying to leverage Serializable. To @brandon-b-miller's point, if some of that logic can be shared between dtypes it would also be great to do that by implementing it at the level of _BaseDtype.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do see some merit to @brandon-b-miller's point of making subclasses of Serializable that generalize some of the common work that's happening in the serialization function, though I haven't really inspected those functions outside of the dtypes to see if there's a lot of intersection there - were you thinking something like SerializableDtype, SerializableFrame, etc...?

To @vyasr's point, I feel like implementing the pickling protocol for the dtypes themselves could result in redundant code, since it would essentially entail copying Serializable.__reduce_ex__ in _BaseDtype. Is there a downside to having host/device deserialization implemented for dtypes other than the fact that those functions aren't really appropriate?

Also feel like that scenario gives more motivation for making subclasses of Serializable, as we could have subclasses that include/exclude the functions we consider inappropriate for their derived classes (such as the host/device serialization).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Serializable is less about making objects picklable and more about serializing objects according to the Dask serialization protocol. The *serialize methods are absolutely required here in order for dtype objects to be able to be sent efficiently across the wire.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's true though that most dtypes really are composed only of metadata. The exception being CategoricalDtype, which for compatibility with Pandas, encapsulates also a column of categories (residing on the device).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having read a little more I'm comfortable 👍 -ing here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, I see now that we're registering Serializable's methods to dask.distributed in cudf/comm/serialize.py. It does seem like we could simplify the specifics of the serialization protocol for dtypes since they are (almost) entirely metadata and not data, but I think moving forward with this approach is fine for now.

# Base type for all cudf-specific dtypes
pass

Expand Down Expand Up @@ -111,12 +113,16 @@ def construct_from_string(self):

def serialize(self):
    """Serialize this categorical dtype to a ``(header, frames)`` pair.

    The header carries the pickled dtype class (so a generic
    deserializer can dispatch to the right class) and the ``ordered``
    flag; when categories are present, their serialized header is
    stored under ``"categories"`` and their frames are appended.
    """
    header = {}
    # Single initialization — the original text carried a stray
    # duplicate ``frames = []`` left over from the diff.
    frames = []
    header["type-serialized"] = pickle.dumps(type(self))
    header["ordered"] = self.ordered

    if self.categories is not None:
        categories_header, categories_frames = self.categories.serialize()
        header["categories"] = categories_header
        frames.extend(categories_frames)

    return header, frames

@classmethod
Expand Down Expand Up @@ -191,6 +197,30 @@ def __repr__(self):
def __hash__(self):
    # Delegate hashing to the wrapped type object (self._typ).
    return hash(self._typ)

def serialize(self) -> Tuple[dict, list]:
    """Serialize this list dtype to a ``(header, frames)`` pair.

    The element type is recursively serialized when it is itself a
    cudf dtype (``_BaseDtype``); otherwise it is stored in the header
    as-is and no frames are produced.
    """
    frames = []
    header: Dict[str, Dtype] = {"type-serialized": pickle.dumps(type(self))}

    element = self.element_type
    if isinstance(element, _BaseDtype):
        header["element-type"], frames = element.serialize()
    else:
        header["element-type"] = element

    return header, frames

@classmethod
def deserialize(cls, header: dict, frames: list):
    """Reconstruct a list dtype from a serialized ``(header, frames)`` pair."""
    element = header["element-type"]
    if isinstance(element, dict):
        # A nested cudf dtype: dispatch to its own class' deserialize.
        element = pickle.loads(element["type-serialized"]).deserialize(
            element, frames
        )
    return cls(element_type=element)


class StructDtype(_BaseDtype):

Expand Down Expand Up @@ -242,6 +272,41 @@ def __repr__(self):
def __hash__(self):
    # Delegate hashing to the wrapped type object (self._typ).
    return hash(self._typ)

def serialize(self) -> Tuple[dict, list]:
    """Serialize this struct dtype to a ``(header, frames)`` pair.

    Fields whose dtype is itself a cudf dtype (``_BaseDtype``) are
    serialized recursively; for each such field the header records the
    nested header together with the half-open ``(start, stop)`` range
    of frame indices it owns. Plain dtypes go into the header as-is.
    """
    header: Dict[str, Any] = {"type-serialized": pickle.dumps(type(self))}

    frames: List[Buffer] = []
    encoded_fields = {}

    for name, field_dtype in self.fields.items():
        if not isinstance(field_dtype, _BaseDtype):
            encoded_fields[name] = field_dtype
        else:
            sub_header, sub_frames = field_dtype.serialize()
            start = len(frames)
            frames.extend(sub_frames)
            encoded_fields[name] = (sub_header, (start, len(frames)))
    header["fields"] = encoded_fields

    return header, frames

@classmethod
def deserialize(cls, header: dict, frames: list):
    """Rebuild a struct dtype from a serialized ``(header, frames)`` pair."""
    fields = {}
    for name, encoded in header["fields"].items():
        if not isinstance(encoded, tuple):
            # Plain dtype stored directly in the header.
            fields[name] = encoded
        else:
            # Nested cudf dtype: unpack its header and its frame range.
            sub_header, (start, stop) = encoded
            sub_cls = pickle.loads(sub_header["type-serialized"])
            fields[name] = sub_cls.deserialize(sub_header, frames[start:stop])
    return cls(fields)


class Decimal64Dtype(_BaseDtype):

Expand Down Expand Up @@ -337,7 +402,14 @@ def _from_decimal(cls, decimal):
return cls(precision, -metadata.exponent)

def serialize(self) -> Tuple[dict, list]:
    """Serialize this decimal dtype to a ``(header, frames)`` pair.

    The header carries the pickled dtype class (so a generic
    deserializer can dispatch to the right class) plus the precision
    and scale; decimal dtypes own no frames. The stray pre-diff
    one-line return (which lacked "type-serialized" and shadowed the
    rest of the body) is removed.
    """
    return (
        {
            "type-serialized": pickle.dumps(type(self)),
            "precision": self.precision,
            "scale": self.scale,
        },
        [],
    )

@classmethod
def deserialize(cls, header: dict, frames: list):
Expand Down
22 changes: 22 additions & 0 deletions python/cudf/cudf/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,3 +804,25 @@ def test_series_construction_with_nulls(input_obj, dtype):
got = cudf.Series(input_obj, dtype="category").to_pandas()

assert_eq(expect, got)


@pytest.mark.parametrize(
    "data",
    [
        {"a": cudf.Series(["a", "b", "c", "a", "c", "b"]).astype("category")},
        {
            "a": cudf.Series(["a", "a", "b", "b"]).astype("category"),
            "b": cudf.Series(["b", "b", "c", "c"]).astype("category"),
            "c": cudf.Series(["c", "c", "a", "a"]).astype("category"),
        },
        {
            "a": cudf.Series(["a", None, "b", "b"]).astype("category"),
            "b": cudf.Series(["b", "b", None, "c"]).astype("category"),
            "c": cudf.Series(["c", "c", "a", None]).astype("category"),
        },
    ],
)
def test_serialize_categorical_columns(data):
    # Round-trip a categorical DataFrame through serialize/deserialize.
    frame = cudf.DataFrame(data)
    restored = type(frame).deserialize(*frame.serialize())
    assert_eq(restored, frame)
38 changes: 38 additions & 0 deletions python/cudf/cudf/tests/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,41 @@ def test_series_construction_with_nulls(input_obj):
got = cudf.Series(input_obj).to_arrow()

assert expect == got


@pytest.mark.parametrize(
    "data",
    [
        {
            "a": _decimal_series(
                ["1", "2", "3"], dtype=cudf.Decimal64Dtype(1, 0)
            )
        },
        {
            "a": _decimal_series(
                ["1", "2", "3"], dtype=cudf.Decimal64Dtype(1, 0)
            ),
            "b": _decimal_series(
                ["1.0", "2.0", "3.0"], dtype=cudf.Decimal64Dtype(2, 1)
            ),
            "c": _decimal_series(
                ["10.1", "20.2", "30.3"], dtype=cudf.Decimal64Dtype(3, 1)
            ),
        },
        {
            "a": _decimal_series(
                ["1", None, "3"], dtype=cudf.Decimal64Dtype(1, 0)
            ),
            "b": _decimal_series(
                ["1.0", "2.0", None], dtype=cudf.Decimal64Dtype(2, 1)
            ),
            "c": _decimal_series(
                [None, "20.2", "30.3"], dtype=cudf.Decimal64Dtype(3, 1)
            ),
        },
    ],
)
def test_serialize_decimal_columns(data):
    # Round-trip a decimal DataFrame through serialize/deserialize.
    frame = cudf.DataFrame(data)
    restored = type(frame).deserialize(*frame.serialize())
    assert_eq(restored, frame)
20 changes: 20 additions & 0 deletions python/cudf/cudf/tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,23 @@ def test_construction_series_with_nulls(input_obj):
got = cudf.Series(input_obj).to_arrow()

assert expect == got


@pytest.mark.parametrize(
    "data",
    [
        {"a": [[]]},
        {"a": [[1, 2, None, 4]]},
        {"a": [["cat", None, "dog"]]},
        {
            "a": [[1, 2, 3, None], [4, None, 5]],
            "b": [None, ["fish", "bird"]],
            "c": [[], []],
        },
        {"a": [[1, 2, 3, None], [4, None, 5], None, [6, 7]]},
    ],
)
def test_serialize_list_columns(data):
    # Round-trip a list-typed DataFrame through serialize/deserialize.
    frame = cudf.DataFrame(data)
    restored = type(frame).deserialize(*frame.serialize())
    assert_eq(restored, frame)
60 changes: 1 addition & 59 deletions python/cudf/cudf/tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import cudf
from cudf.tests import utils
from cudf.tests.utils import _decimal_series, assert_eq
from cudf.tests.utils import assert_eq


@pytest.mark.parametrize(
Expand Down Expand Up @@ -269,64 +269,6 @@ def test_serialize_string_check_buffer_sizes():
assert expect == got


@pytest.mark.parametrize(
    "data",
    [
        {"a": [[]]},
        {"a": [[1, 2, None, 4]]},
        {"a": [["cat", None, "dog"]]},
        {
            "a": [[1, 2, 3, None], [4, None, 5]],
            "b": [None, ["fish", "bird"]],
            "c": [[], []],
        },
        {"a": [[1, 2, 3, None], [4, None, 5], None, [6, 7]]},
    ],
)
def test_serialize_list_columns(data):
    # Serialize then deserialize a list-typed DataFrame and compare.
    source = cudf.DataFrame(data)
    rebuilt = type(source).deserialize(*source.serialize())
    assert_eq(rebuilt, source)


@pytest.mark.parametrize(
    "data",
    [
        {
            "a": _decimal_series(
                ["1", "2", "3"], dtype=cudf.Decimal64Dtype(1, 0)
            )
        },
        {
            "a": _decimal_series(
                ["1", "2", "3"], dtype=cudf.Decimal64Dtype(1, 0)
            ),
            "b": _decimal_series(
                ["1.0", "2.0", "3.0"], dtype=cudf.Decimal64Dtype(2, 1)
            ),
            "c": _decimal_series(
                ["10.1", "20.2", "30.3"], dtype=cudf.Decimal64Dtype(3, 1)
            ),
        },
        {
            "a": _decimal_series(
                ["1", None, "3"], dtype=cudf.Decimal64Dtype(1, 0)
            ),
            "b": _decimal_series(
                ["1.0", "2.0", None], dtype=cudf.Decimal64Dtype(2, 1)
            ),
            "c": _decimal_series(
                [None, "20.2", "30.3"], dtype=cudf.Decimal64Dtype(3, 1)
            ),
        },
    ],
)
def test_serialize_decimal_columns(data):
    # Serialize then deserialize a decimal DataFrame and compare.
    source = cudf.DataFrame(data)
    rebuilt = type(source).deserialize(*source.serialize())
    assert_eq(rebuilt, source)


def test_deserialize_cudf_0_16(datadir):
fname = datadir / "pkl" / "stringColumnWithRangeIndex_cudf_0.16.pkl"

Expand Down
22 changes: 22 additions & 0 deletions python/cudf/cudf/tests/test_struct.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2020, NVIDIA CORPORATION.

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
Expand Down Expand Up @@ -53,3 +54,24 @@ def test_series_construction_with_nulls(input_obj):
got = cudf.Series(input_obj).to_arrow()

assert expect == got


@pytest.mark.parametrize(
    "fields",
    [
        {"a": np.dtype(np.int64)},
        {"a": np.dtype(np.int64), "b": None},
        {
            "a": cudf.ListDtype(np.dtype(np.int64)),
            "b": cudf.Decimal64Dtype(1, 0),
        },
        {
            "a": cudf.ListDtype(cudf.StructDtype({"b": np.dtype(np.int64)})),
            "b": cudf.ListDtype(cudf.ListDtype(np.dtype(np.int64))),
        },
    ],
)
def test_serialize_struct_dtype(fields):
    # Round-trip a StructDtype through serialize/deserialize.
    original = cudf.StructDtype(fields)
    roundtripped = type(original).deserialize(*original.serialize())
    assert roundtripped == original