diff --git a/executorlib/interactive/shared.py b/executorlib/interactive/shared.py
index 3eb79986..35611a4b 100644
--- a/executorlib/interactive/shared.py
+++ b/executorlib/interactive/shared.py
@@ -632,7 +632,16 @@ def _execute_task_with_cache(
         data_dict["output"] = future.result()
         dump(file_name=file_name, data_dict=data_dict)
     else:
-        _, result = get_output(file_name=file_name)
-        future = task_dict["future"]
-        future.set_result(result)
-        future_queue.task_done()
+        exe_flag, result = get_output(file_name=file_name)
+        if exe_flag:
+            future = task_dict["future"]
+            future.set_result(result)
+            future_queue.task_done()
+        else:
+            _execute_task(
+                interface=interface,
+                task_dict=task_dict,
+                future_queue=future_queue,
+            )
+            data_dict["output"] = future.result()
+            dump(file_name=file_name, data_dict=data_dict)
diff --git a/executorlib/standalone/hdf.py b/executorlib/standalone/hdf.py
index 9e8c8798..669c2a09 100644
--- a/executorlib/standalone/hdf.py
+++ b/executorlib/standalone/hdf.py
@@ -22,10 +22,13 @@ def dump(file_name: str, data_dict: dict) -> None:
     with h5py.File(file_name, "a") as fname:
         for data_key, data_value in data_dict.items():
             if data_key in group_dict.keys():
-                fname.create_dataset(
-                    name="/" + group_dict[data_key],
-                    data=np.void(cloudpickle.dumps(data_value)),
-                )
+                try:
+                    fname.create_dataset(
+                        name="/" + group_dict[data_key],
+                        data=np.void(cloudpickle.dumps(data_value)),
+                    )
+                except ValueError:
+                    pass
 
 
 def load(file_name: str) -> dict:
diff --git a/tests/test_executor_backend_mpi.py b/tests/test_executor_backend_mpi.py
index 9a002136..aa4a4b90 100644
--- a/tests/test_executor_backend_mpi.py
+++ b/tests/test_executor_backend_mpi.py
@@ -6,6 +6,12 @@
 from executorlib import Executor
 from executorlib.standalone.serialize import cloudpickle_register
 
+try:
+    import h5py
+
+    skip_h5py_test = False
+except ImportError:
+    skip_h5py_test = True
 
 skip_mpi4py_test = importlib.util.find_spec("mpi4py") is None
 
@@ -14,6 +20,11 @@ def calc(i):
     return i
 
 
+def calc_sleep(i):
+    time.sleep(i)
+    return i
+
+
 def mpi_funct(i):
     from mpi4py import MPI
 
@@ -92,6 +103,40 @@ class TestExecutorBackendCache(unittest.TestCase):
     def tearDown(self):
         shutil.rmtree("./cache")
 
+    @unittest.skipIf(
+        skip_h5py_test, "h5py is not installed, so the h5py tests are skipped."
+    )
+    def test_executor_cache_bypass(self):
+        with Executor(
+            max_workers=2,
+            backend="local",
+            block_allocation=True,
+            cache_directory="./cache",
+        ) as exe:
+            cloudpickle_register(ind=1)
+            time_1 = time.time()
+            fs_1 = exe.submit(calc_sleep, 1)
+            fs_2 = exe.submit(calc_sleep, 1)
+            self.assertEqual(fs_1.result(), 1)
+            self.assertTrue(fs_1.done())
+            time_2 = time.time()
+            self.assertEqual(fs_2.result(), 1)
+            self.assertTrue(fs_2.done())
+            time_3 = time.time()
+            self.assertTrue(time_2 - time_1 > 1)
+            self.assertTrue(time_3 - time_1 > 1)
+            time_4 = time.time()
+            fs_3 = exe.submit(calc_sleep, 1)
+            fs_4 = exe.submit(calc_sleep, 1)
+            self.assertEqual(fs_3.result(), 1)
+            self.assertTrue(fs_3.done())
+            time_5 = time.time()
+            self.assertEqual(fs_4.result(), 1)
+            self.assertTrue(fs_4.done())
+            time_6 = time.time()
+            self.assertTrue(time_5 - time_4 < 1)
+            self.assertTrue(time_6 - time_4 < 1)
+
     @unittest.skipIf(
         skip_mpi4py_test, "mpi4py is not installed, so the mpi4py tests are skipped."
     )
diff --git a/tests/test_local_executor.py b/tests/test_local_executor.py
index 29c5e72b..71aa6991 100644
--- a/tests/test_local_executor.py
+++ b/tests/test_local_executor.py
@@ -1,7 +1,9 @@
+import os
 from concurrent.futures import CancelledError, Future
 import importlib.util
 from queue import Queue
 from time import sleep
+import shutil
 import unittest
 
 import numpy as np
@@ -16,6 +18,12 @@
 from executorlib.standalone.interactive.backend import call_funct
 from executorlib.standalone.serialize import cloudpickle_register
 
+try:
+    import h5py
+
+    skip_h5py_test = False
+except ImportError:
+    skip_h5py_test = True
 
 skip_mpi4py_test = importlib.util.find_spec("mpi4py") is None
 
@@ -32,6 +40,11 @@ def echo_funct(i):
     return i
 
 
+def calc_sleep(i):
+    sleep(i)
+    return i
+
+
 def get_global(memory=None):
     return memory
 
@@ -473,3 +486,41 @@ def test_execute_task_parallel(self):
         )
         self.assertEqual(f.result(), [np.array(4), np.array(4)])
         q.join()
+
+
+class TestFuturePoolCache(unittest.TestCase):
+    def tearDown(self):
+        shutil.rmtree("./cache")
+
+    @unittest.skipIf(
+        skip_h5py_test, "h5py is not installed, so the h5py tests are skipped."
+    )
+    def test_execute_task_cache(self):
+        f1 = Future()
+        f2 = Future()
+        q1 = Queue()
+        q2 = Queue()
+        q1.put({"fn": calc_sleep, "args": (), "kwargs": {"i": 1}, "future": f1})
+        q1.put({"shutdown": True, "wait": True})
+        q2.put({"fn": calc_sleep, "args": (), "kwargs": {"i": 1}, "future": f2})
+        q2.put({"shutdown": True, "wait": True})
+        cloudpickle_register(ind=1)
+        execute_parallel_tasks(
+            future_queue=q1,
+            cores=1,
+            openmpi_oversubscribe=False,
+            spawner=MpiExecSpawner,
+            cache_directory="./cache",
+        )
+        sleep(0.5)
+        execute_parallel_tasks(
+            future_queue=q2,
+            cores=1,
+            openmpi_oversubscribe=False,
+            spawner=MpiExecSpawner,
+            cache_directory="./cache",
+        )
+        self.assertEqual(f1.result(), 1)
+        self.assertEqual(f2.result(), 1)
+        q1.join()
+        q2.join()
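
For context on the first hunk: get_output now returns a flag alongside the result, so _execute_task_with_cache can distinguish a cache entry that already holds a finished output from one that was serialized but never executed, and re-run the task in the second case. The sketch below illustrates that (flag, result) read pattern and the matching ValueError-tolerant write from the hdf.py hunk in a standalone form; the helper names get_output_sketch and dump_sketch and the dataset name "output" are assumptions for illustration, not executorlib's actual implementation.

# Sketch only: the cache read/write pattern used in this diff, with assumed
# names; not the real executorlib helpers.
import cloudpickle
import h5py
import numpy as np


def get_output_sketch(file_name: str):
    # Return (True, result) when the cache file already holds an output,
    # otherwise (False, None) so the caller knows to execute the task itself.
    with h5py.File(file_name, "r") as hdf:
        if "output" in hdf:
            return True, cloudpickle.loads(hdf["output"][()].tobytes())
        return False, None


def dump_sketch(file_name: str, result) -> None:
    # Mirror of the dump() change above: store the cloudpickled result as an
    # opaque binary blob and ignore the ValueError h5py raises when the
    # dataset already exists in the file.
    with h5py.File(file_name, "a") as hdf:
        try:
            hdf.create_dataset(
                name="/output",
                data=np.void(cloudpickle.dumps(result)),
            )
        except ValueError:
            pass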