diff --git a/tests/loader/test_loaders.py b/tests/loader/test_loaders.py index 1d744088..d7a0fb5e 100644 --- a/tests/loader/test_loaders.py +++ b/tests/loader/test_loaders.py @@ -1,6 +1,7 @@ """Test cases for loaders used to load dataframes """ import os +from copy import deepcopy from importlib.util import find_spec from pathlib import Path from typing import cast @@ -86,6 +87,15 @@ def test_if_loaders_are_children_of_base_loader(loader: BaseLoader): def test_has_correct_read_dataframe_func(loader: BaseLoader, read_type: str): """Test if all loaders have a valid read function implemented""" assert callable(loader.read_dataframe) + + # Fix for race condition during parallel testing + if loader.__name__ == "flash": + config = deepcopy(loader._config) + config["core"]["paths"]["data_parquet_dir"] = ( + config["core"]["paths"]["data_parquet_dir"] + f"_{read_type}" + ) + loader = get_loader(loader_name="flash", config=config) + if loader.__name__ != "BaseLoader": assert hasattr(loader, "files") assert hasattr(loader, "supported_file_types") @@ -165,6 +175,15 @@ def test_get_count_rate(loader: BaseLoader): Args: loader (BaseLoader): the loader object to test """ + + # Fix for race condition during parallel testing + if loader.__name__ == "flash": + config = deepcopy(loader._config) + config["core"]["paths"]["data_parquet_dir"] = ( + config["core"]["paths"]["data_parquet_dir"] + "_count_rate" + ) + loader = get_loader(loader_name="flash", config=config) + if loader.__name__ != "BaseLoader": loader_name = get_loader_name_from_loader_object(loader) input_folder = os.path.join(test_data_dir, "loader", loader_name) @@ -176,12 +195,23 @@ def test_get_count_rate(loader: BaseLoader): ) loaded_time, loaded_countrate = loader.get_count_rate() if loaded_time is None and loaded_countrate is None: + if loader.__name__ == "flash": + loader = cast(FlashLoader, loader) + _, parquet_data_dir = loader.initialize_paths() + for file in os.listdir(Path(parquet_data_dir, "buffer")): + os.remove(Path(parquet_data_dir, "buffer", file)) pytest.skip("Not implemented") assert len(loaded_time) == len(loaded_countrate) loaded_time2, loaded_countrate2 = loader.get_count_rate(fids=[0]) assert len(loaded_time2) == len(loaded_countrate2) assert len(loaded_time2) < len(loaded_time) + if loader.__name__ == "flash": + loader = cast(FlashLoader, loader) + _, parquet_data_dir = loader.initialize_paths() + for file in os.listdir(Path(parquet_data_dir, "buffer")): + os.remove(Path(parquet_data_dir, "buffer", file)) + @pytest.mark.parametrize("loader", get_all_loaders()) def test_get_elapsed_time(loader: BaseLoader): @@ -190,6 +220,15 @@ def test_get_elapsed_time(loader: BaseLoader): Args: loader (BaseLoader): the loader object to test """ + + # Fix for race condition during parallel testing + if loader.__name__ == "flash": + config = deepcopy(loader._config) + config["core"]["paths"]["data_parquet_dir"] = ( + config["core"]["paths"]["data_parquet_dir"] + "_elapsed_time" + ) + loader = get_loader(loader_name="flash", config=config) + if loader.__name__ != "BaseLoader": loader_name = get_loader_name_from_loader_object(loader) input_folder = os.path.join(test_data_dir, "loader", loader_name) @@ -201,12 +240,23 @@ def test_get_elapsed_time(loader: BaseLoader): ) elapsed_time = loader.get_elapsed_time() if elapsed_time is None: + if loader.__name__ == "flash": + loader = cast(FlashLoader, loader) + _, parquet_data_dir = loader.initialize_paths() + for file in os.listdir(Path(parquet_data_dir, "buffer")): + os.remove(Path(parquet_data_dir, "buffer", file)) pytest.skip("Not implemented") assert elapsed_time > 0 elapsed_time2 = loader.get_elapsed_time(fids=[0]) assert elapsed_time2 > 0 assert elapsed_time > elapsed_time2 + if loader.__name__ == "flash": + loader = cast(FlashLoader, loader) + _, parquet_data_dir = loader.initialize_paths() + for file in os.listdir(Path(parquet_data_dir, "buffer")): + os.remove(Path(parquet_data_dir, "buffer", file)) + def test_mpes_timestamps(): """Function to test if the timestamps are loaded correctly"""