Skip to content

Commit

Permalink
chore: improve save/load methods for encrypted data-frames (#582)
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft authored Apr 4, 2024
1 parent 41673fe commit c1f4b78
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 20 deletions.
14 changes: 7 additions & 7 deletions src/concrete/ml/pandas/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,28 +133,28 @@ def deserialize_elementwise(array: numpy.ndarray) -> numpy.ndarray:
return numpy.vectorize(deserialize_value)(array)


def serialize_evaluation_keys(evaluation_keys: fhe.EvaluationKeys) -> str:
"""Serialize the evaluation keys into a string of hexadecimal numbers.
def serialize_evaluation_keys(evaluation_keys: fhe.EvaluationKeys) -> bytes:
"""Serialize the evaluation keys into bytes.
Args:
evaluation_keys (fhe.EvaluationKeys): The evaluation keys to serialize.
Returns:
str: The serialized evaluation keys as a string of hexadecimal numbers.
bytes: The serialized evaluation keys.
"""
return serialize_value(evaluation_keys)
return evaluation_keys.serialize()


def deserialize_evaluation_keys(serialized_evaluation_keys: str) -> fhe.EvaluationKeys:
def deserialize_evaluation_keys(serialized_evaluation_keys: bytes) -> fhe.EvaluationKeys:
"""Deserialize the evaluation keys.
Args:
serialized_evaluation_keys (str): The evaluation keys to deserialize.
serialized_evaluation_keys (bytes): The evaluation keys to deserialize.
Returns:
fhe.EvaluationKeys: The deserialized evaluation keys.
"""
return fhe.EvaluationKeys.deserialize(bytes.fromhex(serialized_evaluation_keys))
return fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)


def slice_hex_str(hex_str: str, n: int = 10) -> str:
Expand Down
43 changes: 30 additions & 13 deletions src/concrete/ml/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
from pathlib import Path
from typing import Dict, Hashable, List, Optional, Sequence, Tuple, Union
from zipfile import ZIP_STORED, ZipFile

import numpy
import pandas
Expand Down Expand Up @@ -254,11 +255,12 @@ def merge(

return joined_df

def _to_dict(self) -> Dict:
"""Serialize the encrypted data-frame as a dictionary.
def _to_dict_and_eval_keys(self) -> Tuple[Dict, fhe.EvaluationKeys]:
"""Serialize the encrypted data-frame as a dictionary and evaluations keys.
Returns:
Dict: The serialized data-frame.
fhe.EvaluationKeys: The serialized evaluations keys.
"""
# Serialize encrypted values element-wise
encrypted_values = serialize_elementwise(self._encrypted_values)
Expand All @@ -273,20 +275,20 @@ def _to_dict(self) -> Dict:
output_dict = {
"encrypted_values": encrypted_values.tolist(),
"encrypted_nan": encrypted_nan,
"evaluation_keys": evaluation_keys,
"column_names": self._column_names,
"dtype_mappings": self._dtype_mappings,
"api_version": self._api_version,
}

return output_dict
return output_dict, evaluation_keys

@classmethod
def _from_dict(cls, dict_to_load: Dict):
"""Load a serialized encrypted data-frame from a dictionary.
def _from_dict_and_eval_keys(cls, dict_to_load: Dict, evaluation_keys: fhe.EvaluationKeys):
"""Load a serialized encrypted data-frame from a dictionary and evaluations keys.
Args:
dict_to_load (Dict): The serialized encrypted data-frame.
evaluation_keys (fhe.EvaluationKeys): The serialized evaluations keys.
Returns:
EncryptedDataFrame: The loaded encrypted data-frame.
Expand All @@ -295,7 +297,7 @@ def _from_dict(cls, dict_to_load: Dict):
encrypted_values = deserialize_elementwise(dict_to_load["encrypted_values"])
encrypted_nan = deserialize_value(dict_to_load["encrypted_nan"])

evaluation_keys = deserialize_evaluation_keys(dict_to_load["evaluation_keys"])
evaluation_keys = deserialize_evaluation_keys(evaluation_keys)

column_names = dict_to_load["column_names"]
dtype_mappings = dict_to_load["dtype_mappings"]
Expand All @@ -318,9 +320,16 @@ def save(self, path: Union[Path, str]):
"""
path = Path(path)

encrypted_df_dict = self._to_dict()
with path.open("w", encoding="utf-8") as file:
json.dump(encrypted_df_dict, file)
if path.suffix != ".zip":
path = path.with_suffix(".zip")

encrypted_df_dict, evaluation_keys = self._to_dict_and_eval_keys()

encrypted_df_json_bytes = json.dumps(encrypted_df_dict).encode(encoding="utf-8")

with ZipFile(path, "w", compression=ZIP_STORED, allowZip64=True) as zip_file:
zip_file.writestr("encrypted_dataframe.json", encrypted_df_json_bytes)
zip_file.writestr("evaluation_keys", evaluation_keys)

@classmethod
def load(cls, path: Union[Path, str]):
Expand All @@ -334,7 +343,15 @@ def load(cls, path: Union[Path, str]):
"""
path = Path(path)

with path.open("r", encoding="utf-8") as file:
encrypted_df_dict = json.load(file)
if path.suffix != ".zip":
path = path.with_suffix(".zip")

with ZipFile(path, "r", compression=ZIP_STORED, allowZip64=True) as zip_file:
with zip_file.open("encrypted_dataframe.json") as encrypted_df_json_file:
encrypted_df_json_bytes = encrypted_df_json_file.read()
encrypted_df_dict = json.loads(encrypted_df_json_bytes)

with zip_file.open("evaluation_keys") as evaluation_keys_file:
evaluation_keys = evaluation_keys_file.read()

return cls._from_dict(encrypted_df_dict)
return cls._from_dict_and_eval_keys(encrypted_df_dict, evaluation_keys)

0 comments on commit c1f4b78

Please sign in to comment.