diff --git a/nvflare/app_common/state_persistors/__init__.py b/nvflare/app_common/state_persistors/__init__.py index e69de29bb2d..2b8f6c7e874 100644 --- a/nvflare/app_common/state_persistors/__init__.py +++ b/nvflare/app_common/state_persistors/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/app_common/state_persistors/storage_state_persistor.py b/nvflare/app_common/state_persistors/storage_state_persistor.py index bb8a6e833ae..abadabccb8f 100644 --- a/nvflare/app_common/state_persistors/storage_state_persistor.py +++ b/nvflare/app_common/state_persistors/storage_state_persistor.py @@ -24,7 +24,6 @@ class StorageStatePersistor(StatePersistor): def __init__(self, storage: StorageSpec, location: str): self.storage = storage - self.location = location if not os.path.isabs(location): @@ -36,7 +35,6 @@ def save(self, snapshot: FLSnapshot) -> str: snapshot: FLSnapshot object Returns: storage location """ - # snapshot_uri_timestamp = "snapshot-" + datetime.datetime.now().strftime("%Y-%m-%d,%H:%M:%S") self.storage.create_object( uri=self.location, data=pickle.dumps(snapshot), meta={}, overwrite_existing=True ) diff --git a/nvflare/app_common/storages/filesystem_storage.py b/nvflare/app_common/storages/filesystem_storage.py index c1f97d1f55c..808b51f0e79 100644 --- a/nvflare/app_common/storages/filesystem_storage.py +++ b/nvflare/app_common/storages/filesystem_storage.py @@ -56,20 +56,20 @@ def _write(self, path: str, content): try: Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True) with open(path + "_tmp", "wb") as f: - f.write(pickle.dumps(content)) + f.write(content) f.flush() os.fsync(f.fileno()) - if os.path.isfile(path): - os.remove(path) except Exception as e: + if os.path.isfile(path + "_tmp"): + os.remove(path + "_tmp") raise IOError("failed to write content: {}".format(e)) os.rename(path + "_tmp", path) - def _read(self, path: str) -> object: + def _read(self, path: str) -> bytes: try: with open(path, "rb") as f: - content = pickle.load(f) + content = f.read() except Exception as e: raise IOError("failed to read content: {}".format(e)) @@ -114,7 +114,11 @@ def create_object(self, uri: str, data: ByteString, meta: dict, overwrite_existi meta_path = os.path.join(full_uri, "meta") self._write(data_path + "_tmp", data) - self._write(meta_path, meta) + try: + self._write(meta_path, pickle.dumps(meta)) + except Exception as e: + os.remove(data_path + "_tmp") + raise e os.rename(data_path + "_tmp", data_path) def update_meta(self, uri: str, meta: dict, replace: bool): @@ -140,11 +144,11 @@ def update_meta(self, uri: str, meta: dict, replace: bool): raise Exception("object {} does not exist".format(uri)) if replace: - self._write(os.path.join(full_uri, "meta"), meta) + self._write(os.path.join(full_uri, "meta"), pickle.dumps(meta)) else: prev_meta = self.get_meta(uri) prev_meta.update(meta) - self._write(os.path.join(full_uri, "meta"), prev_meta) + self._write(os.path.join(full_uri, "meta"), pickle.dumps(prev_meta)) def update_data(self, uri: str, data: ByteString): """Update the data info of the specified object @@ -201,7 +205,7 @@ def get_meta(self, uri: str) -> dict: if not self._object_exists(full_uri): raise Exception("object {} does not exist".format(uri)) - return self._read(os.path.join(full_uri, "meta")) + return pickle.loads(self._read(os.path.join(full_uri, "meta"))) def get_full_meta(self, uri: str) -> dict: """Get full meta info of the specified object