Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Empty cache #491

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions executorlib/interactive/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,16 @@ def _execute_task_with_cache(
data_dict["output"] = future.result()
dump(file_name=file_name, data_dict=data_dict)
else:
_, result = get_output(file_name=file_name)
future = task_dict["future"]
future.set_result(result)
future_queue.task_done()
exe_flag, result = get_output(file_name=file_name)
if exe_flag:
future = task_dict["future"]
future.set_result(result)
future_queue.task_done()
else:
_execute_task(
interface=interface,
task_dict=task_dict,
future_queue=future_queue,
)
data_dict["output"] = future.result()
dump(file_name=file_name, data_dict=data_dict)
Comment on lines +635 to +647
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add error handling for cache operations and implement atomic file operations.

The cache validation logic needs improvement in several areas:

  1. File operations (dump, get_output) should have proper error handling
  2. Cache file creation/updates should be atomic to prevent race conditions
  3. Invalid cache files should be cleaned up to prevent disk space issues

Consider implementing these improvements:

-        exe_flag, result = get_output(file_name=file_name)
-        if exe_flag:
-            future = task_dict["future"]
-            future.set_result(result)
-            future_queue.task_done()
-        else:
-            _execute_task(
-                interface=interface,
-                task_dict=task_dict,
-                future_queue=future_queue,
-            )
-            data_dict["output"] = future.result()
-            dump(file_name=file_name, data_dict=data_dict)
+        try:
+            exe_flag, result = get_output(file_name=file_name)
+            if exe_flag:
+                future = task_dict["future"]
+                future.set_result(result)
+                future_queue.task_done()
+            else:
+                # Remove invalid cache file
+                os.remove(file_name)
+                _execute_task(
+                    interface=interface,
+                    task_dict=task_dict,
+                    future_queue=future_queue,
+                )
+                data_dict["output"] = future.result()
+                # Use temporary file for atomic write
+                temp_file = file_name + '.tmp'
+                dump(file_name=temp_file, data_dict=data_dict)
+                os.replace(temp_file, file_name)
+        except Exception as e:
+            future.set_exception(e)
+            future_queue.task_done()
+            raise
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
exe_flag, result = get_output(file_name=file_name)
if exe_flag:
future = task_dict["future"]
future.set_result(result)
future_queue.task_done()
else:
_execute_task(
interface=interface,
task_dict=task_dict,
future_queue=future_queue,
)
data_dict["output"] = future.result()
dump(file_name=file_name, data_dict=data_dict)
try:
exe_flag, result = get_output(file_name=file_name)
if exe_flag:
future = task_dict["future"]
future.set_result(result)
future_queue.task_done()
else:
# Remove invalid cache file
os.remove(file_name)
_execute_task(
interface=interface,
task_dict=task_dict,
future_queue=future_queue,
)
data_dict["output"] = future.result()
# Use temporary file for atomic write
temp_file = file_name + '.tmp'
dump(file_name=temp_file, data_dict=data_dict)
os.replace(temp_file, file_name)
except Exception as e:
future.set_exception(e)
future_queue.task_done()
raise

11 changes: 7 additions & 4 deletions executorlib/standalone/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ def dump(file_name: str, data_dict: dict) -> None:
with h5py.File(file_name, "a") as fname:
for data_key, data_value in data_dict.items():
if data_key in group_dict.keys():
fname.create_dataset(
name="/" + group_dict[data_key],
data=np.void(cloudpickle.dumps(data_value)),
)
try:
fname.create_dataset(
name="/" + group_dict[data_key],
data=np.void(cloudpickle.dumps(data_value)),
)
except ValueError:
pass


def load(file_name: str) -> dict:
Expand Down
45 changes: 45 additions & 0 deletions tests/test_executor_backend_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
from executorlib import Executor
from executorlib.standalone.serialize import cloudpickle_register

try:
import h5py

skip_h5py_test = False
except ImportError:
skip_h5py_test = True

skip_mpi4py_test = importlib.util.find_spec("mpi4py") is None

Expand All @@ -14,6 +20,11 @@ def calc(i):
return i


def calc_sleep(i):
time.sleep(i)
return i


def mpi_funct(i):
from mpi4py import MPI

Expand Down Expand Up @@ -92,6 +103,40 @@ class TestExecutorBackendCache(unittest.TestCase):
def tearDown(self):
shutil.rmtree("./cache")

@unittest.skipIf(
skip_h5py_test, "h5py is not installed, so the h5py tests are skipped."
)
def test_executor_cache_bypass(self):
with Executor(
max_workers=2,
backend="local",
block_allocation=True,
cache_directory="./cache",
) as exe:
cloudpickle_register(ind=1)
time_1 = time.time()
fs_1 = exe.submit(calc_sleep, 1)
fs_2 = exe.submit(calc_sleep, 1)
self.assertEqual(fs_1.result(), 1)
self.assertTrue(fs_1.done())
time_2 = time.time()
self.assertEqual(fs_2.result(), 1)
self.assertTrue(fs_2.done())
time_3 = time.time()
self.assertTrue(time_2 - time_1 > 1)
self.assertTrue(time_3 - time_1 > 1)
time_4 = time.time()
fs_3 = exe.submit(calc_sleep, 1)
fs_4 = exe.submit(calc_sleep, 1)
self.assertEqual(fs_3.result(), 1)
self.assertTrue(fs_3.done())
time_5 = time.time()
self.assertEqual(fs_4.result(), 1)
self.assertTrue(fs_4.done())
time_6 = time.time()
self.assertTrue(time_5 - time_4 < 1)
self.assertTrue(time_6 - time_4 < 1)

@unittest.skipIf(
skip_mpi4py_test, "mpi4py is not installed, so the mpi4py tests are skipped."
)
Expand Down
Loading