diff --git a/python/cudf/cudf/core/buffer.py b/python/cudf/cudf/core/buffer.py index 0658927975f..63e99f34803 100644 --- a/python/cudf/cudf/core/buffer.py +++ b/python/cudf/cudf/core/buffer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. from __future__ import annotations import functools @@ -123,16 +123,20 @@ def serialize(self) -> Tuple[dict, list]: header["constructor-kwargs"] = {} header["desc"] = self.__cuda_array_interface__.copy() header["desc"]["strides"] = (1,) + header["frame_count"] = 1 frames = [self] return header, frames @classmethod def deserialize(cls, header: dict, frames: list) -> Buffer: + assert ( + header["frame_count"] == 1 + ), "Only expecting to deserialize Buffer with a single frame." buf = cls(frames[0], **header["constructor-kwargs"]) if header["desc"]["shape"] != buf.__cuda_array_interface__["shape"]: raise ValueError( - f"Recieved a `Buffer` with the wrong size." + f"Received a `Buffer` with the wrong size." f" Expected {header['desc']['shape']}, " f"but got {buf.__cuda_array_interface__['shape']}" ) diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index f9bb7ea2f1a..56fc75f0451 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -2,19 +2,9 @@ from __future__ import annotations -import pickle from collections import abc from functools import cached_property -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Mapping, - Optional, - Sequence, - Tuple, - cast, -) +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Tuple, cast import numpy as np import pandas as pd @@ -685,53 +675,6 @@ def __contains__(self, item: ScalarLike) -> bool: return False return self._encode(item) in self.as_numerical - def serialize(self) -> Tuple[dict, list]: - header: Dict[Any, Any] = {} - frames = [] - header["type-serialized"] = pickle.dumps(type(self)) - header["dtype"], dtype_frames = self.dtype.serialize() - header["dtype_frames_count"] = len(dtype_frames) - frames.extend(dtype_frames) - header["data"], data_frames = self.codes.serialize() - header["data_frames_count"] = len(data_frames) - frames.extend(data_frames) - if self.mask is not None: - mask_header, mask_frames = self.mask.serialize() - header["mask"] = mask_header - frames.extend(mask_frames) - header["frame_count"] = len(frames) - return header, frames - - @classmethod - def deserialize(cls, header: dict, frames: list) -> CategoricalColumn: - n_dtype_frames = header["dtype_frames_count"] - dtype = CategoricalDtype.deserialize( - header["dtype"], frames[:n_dtype_frames] - ) - n_data_frames = header["data_frames_count"] - - column_type = pickle.loads(header["data"]["type-serialized"]) - data = column_type.deserialize( - header["data"], - frames[n_dtype_frames : n_dtype_frames + n_data_frames], - ) - mask = None - if "mask" in header: - mask = Buffer.deserialize( - header["mask"], [frames[n_dtype_frames + n_data_frames]] - ) - return cast( - CategoricalColumn, - column.build_column( - data=None, - dtype=dtype, - mask=mask, - children=( - column.build_column(data.base_data, dtype=data.dtype), - ), - ), - ) - def set_base_data(self, value): if value is not None: raise RuntimeError( diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 3fb71173178..e1d91e6d0c0 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -5,6 +5,7 @@ import pickle import warnings from functools import cached_property +from itertools import chain from types import SimpleNamespace from typing import ( Any, @@ -1037,10 +1038,29 @@ def unique(self) -> ColumnBase: return drop_duplicates([self], keep="first")[0] def serialize(self) -> Tuple[dict, list]: + # data model: + + # Serialization produces a nested metadata "header" and a flattened + # list of memoryviews/buffers that reference data (frames). Each + # header advertises a frame_count slot which indicates how many + # frames deserialization will consume. The class used to construct + # an object is named under the key "type-serialized" to match with + # Dask's serialization protocol (see + # distributed.protocol.serialize). Since column dtypes may either be + # cudf native or foreign some special-casing is required here for + # serialization. + header: Dict[Any, Any] = {} frames = [] header["type-serialized"] = pickle.dumps(type(self)) - header["dtype"] = self.dtype.str + try: + dtype, dtype_frames = self.dtype.serialize() + header["dtype"] = dtype + frames.extend(dtype_frames) + header["dtype-is-cudf-serialized"] = True + except AttributeError: + header["dtype"] = pickle.dumps(self.dtype) + header["dtype-is-cudf-serialized"] = False if self.data is not None: data_header, data_frames = self.data.serialize() @@ -1051,19 +1071,52 @@ def serialize(self) -> Tuple[dict, list]: mask_header, mask_frames = self.mask.serialize() header["mask"] = mask_header frames.extend(mask_frames) - + if self.children: + child_headers, child_frames = zip( + *(c.serialize() for c in self.children) + ) + header["subheaders"] = list(child_headers) + frames.extend(chain(*child_frames)) + header["size"] = self.size header["frame_count"] = len(frames) return header, frames @classmethod def deserialize(cls, header: dict, frames: list) -> ColumnBase: - dtype = header["dtype"] - data = Buffer.deserialize(header["data"], [frames[0]]) - mask = None + def unpack(header, frames) -> Tuple[Any, list]: + count = header["frame_count"] + klass = pickle.loads(header["type-serialized"]) + obj = klass.deserialize(header, frames[:count]) + return obj, frames[count:] + + assert header["frame_count"] == len(frames), ( + f"Deserialization expected {header['frame_count']} frames, " + f"but received {len(frames)}" + ) + if header["dtype-is-cudf-serialized"]: + dtype, frames = unpack(header["dtype"], frames) + else: + dtype = pickle.loads(header["dtype"]) + if "data" in header: + data, frames = unpack(header["data"], frames) + else: + data = None if "mask" in header: - mask = Buffer.deserialize(header["mask"], [frames[1]]) + mask, frames = unpack(header["mask"], frames) + else: + mask = None + children = [] + if "subheaders" in header: + for h in header["subheaders"]: + child, frames = unpack(h, frames) + children.append(child) + assert len(frames) == 0, "Deserialization did not consume all frames" return build_column( - data=data, dtype=dtype, mask=mask, size=header.get("size", None) + data=data, + dtype=dtype, + mask=mask, + size=header.get("size", None), + children=tuple(children), ) def unary_operator(self, unaryop: str): diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index d8ddb3d8d1a..69009106d15 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -2,7 +2,7 @@ import warnings from decimal import Decimal -from typing import Any, Sequence, Tuple, Union, cast +from typing import Any, Sequence, Union, cast import cupy as cp import numpy as np @@ -321,18 +321,6 @@ def to_arrow(self): buffers=[mask_buf, data_buf], ) - def serialize(self) -> Tuple[dict, list]: - header, frames = super().serialize() - header["dtype"] = self.dtype.serialize() - header["size"] = self.size - return header, frames - - @classmethod - def deserialize(cls, header: dict, frames: list) -> ColumnBase: - dtype = cudf.Decimal64Dtype.deserialize(*header["dtype"]) - header["dtype"] = dtype - return super().deserialize(header, frames) - @property def __cuda_array_interface__(self): raise NotImplementedError( diff --git a/python/cudf/cudf/core/column/lists.py b/python/cudf/cudf/core/column/lists.py index 2964378d114..30e418f0825 100644 --- a/python/cudf/cudf/core/column/lists.py +++ b/python/cudf/cudf/core/column/lists.py @@ -1,6 +1,5 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. -import pickle from functools import cached_property from typing import List, Optional, Sequence, Union @@ -28,7 +27,6 @@ is_list_dtype, is_scalar, ) -from cudf.core.buffer import Buffer from cudf.core.column import ColumnBase, as_column, column from cudf.core.column.methods import ColumnMethods, ParentType from cudf.core.dtypes import ListDtype @@ -166,64 +164,6 @@ def set_base_data(self, value): else: super().set_base_data(value) - def serialize(self): - header = {} - frames = [] - header["type-serialized"] = pickle.dumps(type(self)) - header["null_count"] = self.null_count - header["size"] = self.size - header["dtype"], dtype_frames = self.dtype.serialize() - header["dtype_frames_count"] = len(dtype_frames) - frames.extend(dtype_frames) - - sub_headers = [] - - for item in self.children: - sheader, sframes = item.serialize() - sub_headers.append(sheader) - frames.extend(sframes) - - if self.null_count > 0: - frames.append(self.mask) - - header["subheaders"] = sub_headers - header["frame_count"] = len(frames) - - return header, frames - - @classmethod - def deserialize(cls, header, frames): - - # Get null mask - if header["null_count"] > 0: - mask = Buffer(frames[-1]) - else: - mask = None - - # Deserialize dtype - dtype = pickle.loads(header["dtype"]["type-serialized"]).deserialize( - header["dtype"], frames[: header["dtype_frames_count"]] - ) - - # Deserialize child columns - children = [] - f = header["dtype_frames_count"] - for h in header["subheaders"]: - fcount = h["frame_count"] - child_frames = frames[f : f + fcount] - column_type = pickle.loads(h["type-serialized"]) - children.append(column_type.deserialize(h, child_frames)) - f += fcount - - # Materialize list column - return column.build_column( - data=None, - dtype=dtype, - mask=mask, - children=tuple(children), - size=header["size"], - ) - @property def __cuda_array_interface__(self): raise NotImplementedError( diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 0db7e7d9a27..70097f15372 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -2,14 +2,12 @@ from __future__ import annotations -import pickle import re import warnings from functools import cached_property from typing import ( TYPE_CHECKING, Any, - Dict, Optional, Sequence, Tuple, @@ -5336,56 +5334,6 @@ def to_pandas( pd_series.index = index return pd_series - def serialize(self) -> Tuple[dict, list]: - header: Dict[Any, Any] = {"null_count": self.null_count} - header["type-serialized"] = pickle.dumps(type(self)) - header["size"] = self.size - - frames = [] - sub_headers = [] - - for item in self.children: - sheader, sframes = item.serialize() - sub_headers.append(sheader) - frames.extend(sframes) - - if self.null_count > 0: - frames.append(self.mask) - - header["subheaders"] = sub_headers - header["frame_count"] = len(frames) - return header, frames - - @classmethod - def deserialize(cls, header: dict, frames: list) -> StringColumn: - size = header["size"] - if not isinstance(size, int): - size = pickle.loads(size) - - # Deserialize the mask, value, and offset frames - buffers = [Buffer(each_frame) for each_frame in frames] - - nbuf = None - if header["null_count"] > 0: - nbuf = buffers[2] - - children = [] - for h, b in zip(header["subheaders"], buffers[:2]): - column_type = pickle.loads(h["type-serialized"]) - children.append(column_type.deserialize(h, [b])) - - col = cast( - StringColumn, - column.build_column( - data=None, - dtype="str", - mask=nbuf, - children=tuple(children), - size=size, - ), - ) - return col - def can_cast_safely(self, to_dtype: Dtype) -> bool: to_dtype = cudf.dtype(to_dtype) diff --git a/python/cudf/cudf/core/dtypes.py b/python/cudf/cudf/core/dtypes.py index 585e8b94e80..9991bad5a9e 100644 --- a/python/cudf/cudf/core/dtypes.py +++ b/python/cudf/cudf/core/dtypes.py @@ -1,8 +1,9 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. import decimal +import operator import pickle -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import numpy as np import pandas as pd @@ -68,6 +69,50 @@ def dtype(arbitrary): ) +def _decode_type( + cls: Type, + header: dict, + frames: list, + is_valid_class: Callable[[Type, Type], bool] = operator.is_, +) -> Tuple[dict, list, Type]: + """Decode metadata-encoded type and check validity + + Parameters + ---------- + cls : type + class performing deserialization + header : dict + metadata for deserialization + frames : list + buffers containing data for deserialization + is_valid_class : Callable + function to call to check if the encoded class type is valid for + serialization by `cls` (default is to check type equality), called + as `is_valid_class(decoded_class, cls)`. + + Returns + ------- + tuple + Tuple of validated headers, frames, and the decoded class + constructor. + + Raises + ------ + AssertionError + if the number of frames doesn't match the count encoded in the + headers, or `is_valid_class` is not true. + """ + assert header["frame_count"] == len(frames), ( + f"Deserialization expected {header['frame_count']} frames, " + f"but received {len(frames)}." + ) + klass = pickle.loads(header["type-serialized"]) + assert is_valid_class( + klass, cls + ), f"Header-encoded {klass=} does not match decoding {cls=}." + return header, frames, klass + + class _BaseDtype(ExtensionDtype, Serializable): # Base type for all cudf-specific dtypes pass @@ -169,11 +214,12 @@ def serialize(self): categories_header, categories_frames = self.categories.serialize() header["categories"] = categories_header frames.extend(categories_frames) - + header["frame_count"] = len(frames) return header, frames @classmethod def deserialize(cls, header, frames): + header, frames, klass = _decode_type(cls, header, frames) ordered = header["ordered"] categories_header = header["categories"] categories_frames = frames @@ -181,7 +227,7 @@ def deserialize(cls, header, frames): categories = categories_type.deserialize( categories_header, categories_frames ) - return cls(categories=categories, ordered=ordered) + return klass(categories=categories, ordered=ordered) class ListDtype(_BaseDtype): @@ -254,19 +300,19 @@ def serialize(self) -> Tuple[dict, list]: header["element-type"], frames = self.element_type.serialize() else: header["element-type"] = self.element_type - + header["frame_count"] = len(frames) return header, frames @classmethod def deserialize(cls, header: dict, frames: list): + header, frames, klass = _decode_type(cls, header, frames) if isinstance(header["element-type"], dict): element_type = pickle.loads( header["element-type"]["type-serialized"] ).deserialize(header["element-type"], frames) else: element_type = header["element-type"] - - return cls(element_type=element_type) + return klass(element_type=element_type) class StructDtype(_BaseDtype): @@ -325,7 +371,7 @@ def serialize(self) -> Tuple[dict, list]: frames: List[Buffer] = [] - fields = {} + fields: Dict[str, Union[bytes, Tuple[Any, Tuple[int, int]]]] = {} for k, dtype in self.fields.items(): if isinstance(dtype, _BaseDtype): @@ -336,13 +382,14 @@ def serialize(self) -> Tuple[dict, list]: ) frames.extend(dtype_frames) else: - fields[k] = dtype + fields[k] = pickle.dumps(dtype) header["fields"] = fields - + header["frame_count"] = len(frames) return header, frames @classmethod def deserialize(cls, header: dict, frames: list): + header, frames, klass = _decode_type(cls, header, frames) fields = {} for k, dtype in header["fields"].items(): if isinstance(dtype, tuple): @@ -354,7 +401,7 @@ def deserialize(cls, header: dict, frames: list): frames[start:stop], ) else: - fields[k] = dtype + fields[k] = pickle.loads(dtype) return cls(fields) @@ -452,13 +499,18 @@ def serialize(self) -> Tuple[dict, list]: "type-serialized": pickle.dumps(type(self)), "precision": self.precision, "scale": self.scale, + "frame_count": 0, }, [], ) @classmethod def deserialize(cls, header: dict, frames: list): - return cls(header["precision"], header["scale"]) + header, frames, klass = _decode_type( + cls, header, frames, is_valid_class=issubclass + ) + klass = pickle.loads(header["type-serialized"]) + return klass(header["precision"], header["scale"]) def __eq__(self, other: Dtype) -> bool: if other is self: @@ -531,6 +583,21 @@ def from_pandas(cls, pd_dtype: pd.IntervalDtype) -> "IntervalDtype": subtype=pd_dtype.subtype ) # TODO: needs `closed` when we upgrade Pandas + def serialize(self) -> Tuple[dict, list]: + header = { + "type-serialized": pickle.dumps(type(self)), + "fields": pickle.dumps((self.subtype, self.closed)), + "frame_count": 0, + } + return header, [] + + @classmethod + def deserialize(cls, header: dict, frames: list): + header, frames, klass = _decode_type(cls, header, frames) + klass = pickle.loads(header["type-serialized"]) + subtype, closed = pickle.loads(header["fields"]) + return klass(subtype, closed=closed) + def is_categorical_dtype(obj): """Check whether an array-like or dtype is of the Categorical dtype. diff --git a/python/cudf/cudf/tests/data/pkl/stringColumnWithRangeIndex_cudf_0.16.pkl b/python/cudf/cudf/tests/data/pkl/stringColumnWithRangeIndex_cudf_0.16.pkl index 30e31487a82..97c745c1dd0 100644 Binary files a/python/cudf/cudf/tests/data/pkl/stringColumnWithRangeIndex_cudf_0.16.pkl and b/python/cudf/cudf/tests/data/pkl/stringColumnWithRangeIndex_cudf_0.16.pkl differ diff --git a/python/cudf/cudf/tests/test_serialize.py b/python/cudf/cudf/tests/test_serialize.py index 440dcf527ca..b7d679e95d5 100644 --- a/python/cudf/cudf/tests/test_serialize.py +++ b/python/cudf/cudf/tests/test_serialize.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018, NVIDIA CORPORATION. +# Copyright (c) 2018-2022, NVIDIA CORPORATION. import pickle @@ -23,6 +23,23 @@ lambda: cudf.Series([1, 2, 3])[:2]._column, lambda: cudf.Series(["a", "bb", "ccc"]), lambda: cudf.Series(["a", None, "ccc"]), + lambda: cudf.Series( + [ + {"a": ({"b": [1, 2, 3], "c": [4, 5, 6]}, {"d": [2, 4, 6]})}, + {"e": ({"b": [0, 2, 4], "c": [-1, -2, -3]}, {"d": [1, 1, 1]})}, + ] + ), + lambda: cudf.Series( + [ + 14.12302, + 97938.2, + np.nan, + 0.0, + -8.302014, + np.nan, + -112.2314, + ] + ).astype(cudf.Decimal64Dtype(7, 2)), lambda: cudf.DataFrame({"x": [1, 2, 3]}), lambda: cudf.DataFrame({"x": [1, 2, 3], "y": [1.0, None, 3.0]}), lambda: cudf.DataFrame( @@ -35,11 +52,47 @@ {"x": ["a", "bb", "ccc"], "y": [1.0, None, 3.0]}, index=[1, None, 3], ), - pd._testing.makeTimeDataFrame, + pd._testing.makeBoolIndex, + pd._testing.makeCategoricalIndex, + lambda: pd._testing.makeCustomDataframe(3, 4), + lambda: pd._testing.makeCustomIndex(2, 5), + pd._testing.makeDataFrame, + pd._testing.makeDateIndex, + pd._testing.makeFloatIndex, + pd._testing.makeFloatSeries, + pd._testing.makeIntIndex, + pd._testing.makeIntervalIndex, + pd._testing.makeMissingDataframe, pd._testing.makeMixedDataFrame, + pd._testing.makeMultiIndex, + lambda: pd._testing.makeNumericIndex(dtype=np.float64), + pd._testing.makeObjectSeries, + pytest.param( + pd._testing.makePeriodFrame, + marks=pytest.mark.xfail( + reason="Periods not supported in cudf", raises=RuntimeError + ), + ), + pytest.param( + pd._testing.makePeriodIndex, + marks=pytest.mark.xfail( + reason="Periods not supported in cudf", raises=RuntimeError + ), + ), + pytest.param( + pd._testing.makePeriodSeries, + marks=pytest.mark.xfail( + reason="Periods not supported in cudf", raises=RuntimeError + ), + ), + pd._testing.makeRangeIndex, + pd._testing.makeStringIndex, + pd._testing.makeStringSeries, pd._testing.makeTimeDataFrame, - # pd._testing.makeMissingDataframe, # Problem in distributed - # pd._testing.makeMultiIndex, # Indices not serialized on device + pd._testing.makeTimeSeries, + pd._testing.makeTimedeltaIndex, + pd._testing.makeUIntIndex, + pd._testing.makeUnicodeIndex, ], ) @pytest.mark.parametrize("to_host", [True, False]) @@ -64,13 +117,25 @@ def test_serialize(df, to_host): elif hasattr(df, "_cols"): assert ndevice >= len(df._data) else: - assert ndevice > 0 + # If there are frames, something should be on the device + assert ndevice > 0 or not frames typ = type(a) b = typ.deserialize(header, frames) assert_eq(a, b) +def test_serialize_dtype_error_checking(): + dtype = cudf.IntervalDtype("float", "right") + header, frames = dtype.serialize() + with pytest.raises(AssertionError): + # Invalid number of frames + type(dtype).deserialize(header, [None] * (header["frame_count"] + 1)) + with pytest.raises(AssertionError): + # mismatching class + cudf.StructDtype.deserialize(header, frames) + + def test_serialize_dataframe(): df = cudf.DataFrame() df["a"] = np.arange(100)