From 6300e6d6724d6d9a1b25ad0635013db32de6eef7 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 5 May 2022 07:01:37 -0700 Subject: [PATCH] dtypes: Refactor deserialize error checking --- python/cudf/cudf/core/dtypes.py | 87 +++++++++++++++++++++------------ 1 file changed, 55 insertions(+), 32 deletions(-) diff --git a/python/cudf/cudf/core/dtypes.py b/python/cudf/cudf/core/dtypes.py index 42a1184fb43..8e01ec6ea2d 100644 --- a/python/cudf/cudf/core/dtypes.py +++ b/python/cudf/cudf/core/dtypes.py @@ -1,8 +1,9 @@ # Copyright (c) 2020-2022, NVIDIA CORPORATION. import decimal +import operator import pickle -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import numpy as np import pandas as pd @@ -68,6 +69,52 @@ def dtype(arbitrary): ) +def _decode_type( + cls: Type, + header: dict, + frames: list, + is_valid_class: Callable[[Type, Type], bool] = operator.is_, +) -> Tuple[dict, list, Type]: + """Decode metadata-encoded type and check validity + + Parameters + ---------- + cls : type + class performing deserialization + header : dict + metadata for deserialization + frames : list + buffers containing data for deserialization + is_valid_class : Callable + function to call to check if the encoded class type is valid for + serialization by `cls` (default is to check type equality), called + as `is_valid_class(decoded_class, cls)`. + + Returns + ------- + tuple + Tuple of validated headers, frames, and the decoded class + constructor. + + Raises + ------ + AssertionError + if the number of frames doesn't match the count encoded in the + headers, or `is_valid_class` is not true. + """ + if header["frame_count"] != len(frames): + raise AssertionError( + f"Deserialization expected {header['frame_count']} frames, " + f"but received {len(frames)}." + ) + klass = pickle.loads(header["type-serialized"]) + if not is_valid_class(klass, cls): + raise AssertionError( + f"Header-encoded {klass=} does not match decoding {cls=}." + ) + return header, frames, klass + + class _BaseDtype(ExtensionDtype, Serializable): # Base type for all cudf-specific dtypes pass @@ -174,13 +221,7 @@ def serialize(self): @classmethod def deserialize(cls, header, frames): - assert header["frame_count"] == len( - frames - ), "Received unexpected number of frames." - klass = pickle.loads(header["type-serialized"]) - assert ( - klass == cls - ), "Header-encoded type does not match reconstructing type" + header, frames, klass = _decode_type(cls, header, frames) ordered = header["ordered"] categories_header = header["categories"] categories_frames = frames @@ -266,13 +307,7 @@ def serialize(self) -> Tuple[dict, list]: @classmethod def deserialize(cls, header: dict, frames: list): - assert header["frame_count"] == len( - frames - ), "Received unexpected number of frames." - klass = pickle.loads(header["type-serialized"]) - assert ( - klass == cls - ), "Header-encoded type does not match reconstructing type" + header, frames, klass = _decode_type(cls, header, frames) if isinstance(header["element-type"], dict): element_type = pickle.loads( header["element-type"]["type-serialized"] @@ -356,13 +391,7 @@ def serialize(self) -> Tuple[dict, list]: @classmethod def deserialize(cls, header: dict, frames: list): - assert header["frame_count"] == len( - frames - ), "Received unexpected number of frames." - klass = pickle.loads(header["type-serialized"]) - assert ( - klass == cls - ), "Header-encoded type does not match reconstructing type" + header, frames, klass = _decode_type(cls, header, frames) fields = {} for k, dtype in header["fields"].items(): if isinstance(dtype, tuple): @@ -479,9 +508,9 @@ def serialize(self) -> Tuple[dict, list]: @classmethod def deserialize(cls, header: dict, frames: list): - assert header["frame_count"] == len( - frames - ), "Received unexpected number of frames." + header, frames, klass = _decode_type( + cls, header, frames, is_valid_class=issubclass + ) klass = pickle.loads(header["type-serialized"]) return klass(header["precision"], header["scale"]) @@ -566,13 +595,7 @@ def serialize(self) -> Tuple[dict, list]: @classmethod def deserialize(cls, header: dict, frames: list): - assert header["frame_count"] == len( - frames - ), "Received unexpected number of frames." - klass = pickle.loads(header["type-serialized"]) - assert ( - klass == cls - ), "Header-encoded type does not match reconstructing type" + header, frames, klass = _decode_type(cls, header, frames) klass = pickle.loads(header["type-serialized"]) subtype, closed = pickle.loads(header["fields"]) return klass(subtype, closed=closed)