diff --git a/executorlib/standalone/hdf.py b/executorlib/standalone/hdf.py
index c0dd0609..f6dd702b 100644
--- a/executorlib/standalone/hdf.py
+++ b/executorlib/standalone/hdf.py
@@ -1,10 +1,21 @@
-from typing import Optional, Tuple
+import os
+from typing import Optional, Tuple, List
 
 import cloudpickle
 import h5py
 import numpy as np
 
 
+group_dict = {
+    "fn": "function",
+    "args": "input_args",
+    "kwargs": "input_kwargs",
+    "output": "output",
+    "runtime": "runtime",
+    "queue_id": "queue_id",
+}
+
+
 def dump(file_name: str, data_dict: dict) -> None:
     """
     Dump data dictionary into HDF5 file
@@ -13,14 +24,6 @@ def dump(file_name: str, data_dict: dict) -> None:
         file_name (str): file name of the HDF5 file as absolute path
         data_dict (dict): dictionary containing the python function to be executed {"fn": ..., "args": (), "kwargs": {}}
     """
-    group_dict = {
-        "fn": "function",
-        "args": "input_args",
-        "kwargs": "input_kwargs",
-        "output": "output",
-        "runtime": "runtime",
-        "queue_id": "queue_id",
-    }
     with h5py.File(file_name, "a") as fname:
         for data_key, data_value in data_dict.items():
             if data_key in group_dict.keys():
@@ -97,3 +100,31 @@ def get_queue_id(file_name: str) -> Optional[int]:
             return cloudpickle.loads(np.void(hdf["/queue_id"]))
         else:
             return None
+
+
+def get_cache_data(cache_directory: str) -> List[dict]:
+    """
+    Collect the content of all HDF5 files stored in a cache directory
+
+    Args:
+        cache_directory (str): path of the directory which contains the HDF5 cache files
+
+    Returns:
+        List[dict]: one dictionary per HDF5 file, with the stored groups plus the key "filename"
+    """
+    file_lst = []
+    for file_name in sorted(os.listdir(cache_directory)):
+        file_path = os.path.join(cache_directory, file_name)
+        # Skip sub directories and foreign files, h5py.File() raises an OSError for those.
+        if not h5py.is_hdf5(file_path):
+            continue
+        with h5py.File(file_path, "r") as hdf:
+            # NOTE: cloudpickle.loads() can execute arbitrary code, only read trusted caches.
+            file_content_dict = {
+                key: cloudpickle.loads(np.void(hdf["/" + key]))
+                for key in group_dict.values()
+                if key in hdf
+            }
+        file_content_dict["filename"] = file_name
+        file_lst.append(file_content_dict)
+    return file_lst
diff --git a/tests/test_cache_executor_interactive.py b/tests/test_cache_executor_interactive.py
new file mode 100644
index 00000000..125b3606
--- /dev/null
+++ b/tests/test_cache_executor_interactive.py
@@ -0,0 +1,32 @@
+import os
+import shutil
+import unittest
+
+from executorlib import Executor
+
+try:
+    # An ImportError here (e.g. h5py missing) sets skip_h5py_test, which skips the class below.
+    from executorlib.standalone.hdf import get_cache_data
+
+    skip_h5py_test = False
+except ImportError:
+    skip_h5py_test = True
+
+
+@unittest.skipIf(
+    skip_h5py_test, "h5py is not installed, so the h5io tests are skipped."
+)
+class TestCacheFunctions(unittest.TestCase):
+    def test_cache_data(self):
+        cache_directory = "./cache"
+        with Executor(backend="local", cache_directory=cache_directory) as exe:
+            future_lst = [exe.submit(sum, [i, i]) for i in range(1, 4)]
+            result_lst = [f.result() for f in future_lst]
+        # The cached outputs and the sums of the cached input arguments must both match the results.
+        cache_lst = get_cache_data(cache_directory=cache_directory)
+        self.assertEqual(sum([c["output"] for c in cache_lst]), sum(result_lst))
+        self.assertEqual(sum([sum(c["input_args"][0]) for c in cache_lst]), sum(result_lst))
+
+    def tearDown(self):
+        if os.path.exists("cache"):
+            shutil.rmtree("cache")
\ No newline at end of file