Skip to content

Commit

Permalink
remove race condition for flash loader tests and remove additional ge…
Browse files Browse the repository at this point in the history
…nerated buffer files
  • Loading branch information
rettigl committed Oct 12, 2023
1 parent 2644b41 commit 8196329
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions tests/loader/test_loaders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test cases for loaders used to load dataframes
"""
import os
from copy import deepcopy
from importlib.util import find_spec
from pathlib import Path
from typing import cast
Expand Down Expand Up @@ -86,6 +87,15 @@ def test_if_loaders_are_children_of_base_loader(loader: BaseLoader):
def test_has_correct_read_dataframe_func(loader: BaseLoader, read_type: str):
"""Test if all loaders have a valid read function implemented"""
assert callable(loader.read_dataframe)

# Fix for race condition during parallel testing
if loader.__name__ == "flash":
config = deepcopy(loader._config)
config["core"]["paths"]["data_parquet_dir"] = (
config["core"]["paths"]["data_parquet_dir"] + f"_{read_type}"
)
loader = get_loader(loader_name="flash", config=config)

if loader.__name__ != "BaseLoader":
assert hasattr(loader, "files")
assert hasattr(loader, "supported_file_types")
Expand Down Expand Up @@ -165,6 +175,15 @@ def test_get_count_rate(loader: BaseLoader):
Args:
loader (BaseLoader): the loader object to test
"""

# Fix for race condition during parallel testing
if loader.__name__ == "flash":
config = deepcopy(loader._config)
config["core"]["paths"]["data_parquet_dir"] = (
config["core"]["paths"]["data_parquet_dir"] + "_count_rate"
)
loader = get_loader(loader_name="flash", config=config)

if loader.__name__ != "BaseLoader":
loader_name = get_loader_name_from_loader_object(loader)
input_folder = os.path.join(test_data_dir, "loader", loader_name)
Expand All @@ -176,12 +195,23 @@ def test_get_count_rate(loader: BaseLoader):
)
loaded_time, loaded_countrate = loader.get_count_rate()
if loaded_time is None and loaded_countrate is None:
if loader.__name__ == "flash":
loader = cast(FlashLoader, loader)
_, parquet_data_dir = loader.initialize_paths()
for file in os.listdir(Path(parquet_data_dir, "buffer")):
os.remove(Path(parquet_data_dir, "buffer", file))
pytest.skip("Not implemented")
assert len(loaded_time) == len(loaded_countrate)
loaded_time2, loaded_countrate2 = loader.get_count_rate(fids=[0])
assert len(loaded_time2) == len(loaded_countrate2)
assert len(loaded_time2) < len(loaded_time)

if loader.__name__ == "flash":
loader = cast(FlashLoader, loader)
_, parquet_data_dir = loader.initialize_paths()
for file in os.listdir(Path(parquet_data_dir, "buffer")):
os.remove(Path(parquet_data_dir, "buffer", file))


@pytest.mark.parametrize("loader", get_all_loaders())
def test_get_elapsed_time(loader: BaseLoader):
Expand All @@ -190,6 +220,15 @@ def test_get_elapsed_time(loader: BaseLoader):
Args:
loader (BaseLoader): the loader object to test
"""

# Fix for race condition during parallel testing
if loader.__name__ == "flash":
config = deepcopy(loader._config)
config["core"]["paths"]["data_parquet_dir"] = (
config["core"]["paths"]["data_parquet_dir"] + "_elapsed_time"
)
loader = get_loader(loader_name="flash", config=config)

if loader.__name__ != "BaseLoader":
loader_name = get_loader_name_from_loader_object(loader)
input_folder = os.path.join(test_data_dir, "loader", loader_name)
Expand All @@ -201,12 +240,23 @@ def test_get_elapsed_time(loader: BaseLoader):
)
elapsed_time = loader.get_elapsed_time()
if elapsed_time is None:
if loader.__name__ == "flash":
loader = cast(FlashLoader, loader)
_, parquet_data_dir = loader.initialize_paths()
for file in os.listdir(Path(parquet_data_dir, "buffer")):
os.remove(Path(parquet_data_dir, "buffer", file))
pytest.skip("Not implemented")
assert elapsed_time > 0
elapsed_time2 = loader.get_elapsed_time(fids=[0])
assert elapsed_time2 > 0
assert elapsed_time > elapsed_time2

if loader.__name__ == "flash":
loader = cast(FlashLoader, loader)
_, parquet_data_dir = loader.initialize_paths()
for file in os.listdir(Path(parquet_data_dir, "buffer")):
os.remove(Path(parquet_data_dir, "buffer", file))


def test_mpes_timestamps():
"""Function to test if the timestamps are loaded correctly"""
Expand Down

0 comments on commit 8196329

Please sign in to comment.