Skip to content

Commit

Permalink
dtypes: Refactor deserialize error checking
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed May 5, 2022
1 parent 573395e commit 6300e6d
Showing 1 changed file with 55 additions and 32 deletions.
87 changes: 55 additions & 32 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.

import decimal
import operator
import pickle
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -68,6 +69,52 @@ def dtype(arbitrary):
)


def _decode_type(
cls: Type,
header: dict,
frames: list,
is_valid_class: Callable[[Type, Type], bool] = operator.is_,
) -> Tuple[dict, list, Type]:
"""Decode metadata-encoded type and check validity
Parameters
----------
cls : type
class performing deserialization
header : dict
metadata for deserialization
frames : list
buffers containing data for deserialization
is_valid_class : Callable
function to call to check if the encoded class type is valid for
serialization by `cls` (default is to check type equality), called
as `is_valid_class(decoded_class, cls)`.
Returns
-------
tuple
Tuple of validated headers, frames, and the decoded class
constructor.
Raises
------
AssertionError
if the number of frames doesn't match the count encoded in the
headers, or `is_valid_class` is not true.
"""
if header["frame_count"] != len(frames):
raise AssertionError(
f"Deserialization expected {header['frame_count']} frames, "
f"but received {len(frames)}."
)
klass = pickle.loads(header["type-serialized"])
if not is_valid_class(klass, cls):
raise AssertionError(
f"Header-encoded {klass=} does not match decoding {cls=}."
)
return header, frames, klass


class _BaseDtype(ExtensionDtype, Serializable):
# Base type for all cudf-specific dtypes
pass
Expand Down Expand Up @@ -174,13 +221,7 @@ def serialize(self):

@classmethod
def deserialize(cls, header, frames):
assert header["frame_count"] == len(
frames
), "Received unexpected number of frames."
klass = pickle.loads(header["type-serialized"])
assert (
klass == cls
), "Header-encoded type does not match reconstructing type"
header, frames, klass = _decode_type(cls, header, frames)
ordered = header["ordered"]
categories_header = header["categories"]
categories_frames = frames
Expand Down Expand Up @@ -266,13 +307,7 @@ def serialize(self) -> Tuple[dict, list]:

@classmethod
def deserialize(cls, header: dict, frames: list):
assert header["frame_count"] == len(
frames
), "Received unexpected number of frames."
klass = pickle.loads(header["type-serialized"])
assert (
klass == cls
), "Header-encoded type does not match reconstructing type"
header, frames, klass = _decode_type(cls, header, frames)
if isinstance(header["element-type"], dict):
element_type = pickle.loads(
header["element-type"]["type-serialized"]
Expand Down Expand Up @@ -356,13 +391,7 @@ def serialize(self) -> Tuple[dict, list]:

@classmethod
def deserialize(cls, header: dict, frames: list):
assert header["frame_count"] == len(
frames
), "Received unexpected number of frames."
klass = pickle.loads(header["type-serialized"])
assert (
klass == cls
), "Header-encoded type does not match reconstructing type"
header, frames, klass = _decode_type(cls, header, frames)
fields = {}
for k, dtype in header["fields"].items():
if isinstance(dtype, tuple):
Expand Down Expand Up @@ -479,9 +508,9 @@ def serialize(self) -> Tuple[dict, list]:

@classmethod
def deserialize(cls, header: dict, frames: list):
assert header["frame_count"] == len(
frames
), "Received unexpected number of frames."
header, frames, klass = _decode_type(
cls, header, frames, is_valid_class=issubclass
)
klass = pickle.loads(header["type-serialized"])
return klass(header["precision"], header["scale"])

Expand Down Expand Up @@ -566,13 +595,7 @@ def serialize(self) -> Tuple[dict, list]:

@classmethod
def deserialize(cls, header: dict, frames: list):
assert header["frame_count"] == len(
frames
), "Received unexpected number of frames."
klass = pickle.loads(header["type-serialized"])
assert (
klass == cls
), "Header-encoded type does not match reconstructing type"
header, frames, klass = _decode_type(cls, header, frames)
klass = pickle.loads(header["type-serialized"])
subtype, closed = pickle.loads(header["fields"])
return klass(subtype, closed=closed)
Expand Down

0 comments on commit 6300e6d

Please sign in to comment.